# backend/app/ai/ai.py
# NOTE(review): the original paste carried web-UI chrome here ("Files",
# "Raw Blame History", a Unicode-ambiguity banner) that is not part of the
# source and would be a syntax error in a .py file; reduced to this comment.
# -*- coding: utf-8 -*-
"""gemini-hackathon ai.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1FyV9Lq9Sxh_dFiaNIqeu1brOl8DAoKUO
"""
# NOTE: `!pip ...` is an IPython/Colab cell magic and is a syntax error in a
# plain Python module. Install the dependencies from a shell instead:
#   pip install -q langgraph-checkpoint-sqlite langchain_google_genai
import os
import sqlite3
from typing import Annotated, Literal, TypedDict

from google.colab import userdata
from langchain_core.messages import AnyMessage, HumanMessage, RemoveMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
# --- 1. State definition ---
class State(TypedDict):
    """Graph state: the running chat history plus a long-term summary."""

    # Chat history managed by the `add_messages` reducer (append/merge by id).
    messages: Annotated[list[AnyMessage], add_messages]
    # Condensed long-term memory, persisted in the checkpoint database.
    summary: str
# --- 2. Core logic ---
# Shared Gemini chat model. The API key is read from the environment
# (GOOGLE_API_KEY) instead of being hard-coded in source control; the ""
# fallback matches the original literal so behavior is unchanged when unset.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.7,
    google_api_key=os.environ.get("GOOGLE_API_KEY", ""),
)
def call_model(state: State, config: RunnableConfig):
    """Chat node: combine the per-session system prompt with long-term memory.

    Reads an optional ``system_prompt`` from the run config, appends the
    stored conversation summary (if any) as extra context, and invokes the
    LLM on the resulting message list.
    """
    # Per-session system prompt; fall back to a generic assistant persona.
    base_prompt = config.get("configurable", {}).get(
        "system_prompt", "你是一个通用的 AI 助手。"
    )

    # Fold the persisted summary into the system prompt as extra context.
    stored_summary = state.get("summary", "")
    if stored_summary:
        base_prompt += f"\n\n<context_summary>\n{stored_summary}\n</context_summary>"

    prompt_messages = [SystemMessage(content=base_prompt)] + state["messages"]
    return {"messages": [llm.invoke(prompt_messages)]}
def summarize_conversation(state: State):
    """Summary node: condense older messages into the summary, then drop them.

    Everything except the most recent message is folded into an updated
    summary via the LLM; the folded messages are then deleted from the
    checkpointed history with ``RemoveMessage``.
    """
    prior_summary = state.get("summary", "")
    # Keep the newest message live; everything older gets summarized away.
    older = state["messages"][:-1]

    # Nothing to condense yet — leave the summary as-is and touch nothing.
    if not older:
        return {"summary": prior_summary}

    instructions = (
        "你是一个记忆管理专家。请更新摘要,合并新旧信息。"
        "1. 保持简练,仅保留事实(姓名、偏好、核心议题)。"
        "2. 如果新消息包含对旧信息的修正,请更新它。"
    )
    payload = f"现有摘要: {prior_summary}\n\n待加入的新信息: {older}"

    condensed = llm.invoke(
        [SystemMessage(content=instructions), HumanMessage(content=payload)]
    )

    # RemoveMessage entries tell the add_messages reducer to delete the
    # summarized messages from the persisted history.
    removals = [RemoveMessage(id=m.id) for m in older if m.id]
    return {"summary": condensed.content, "messages": removals}
def should_continue(state: State) -> Literal["summarize", END]:
    """Route to the summarizer once more than three messages have accumulated."""
    return "summarize" if len(state["messages"]) > 3 else END
# --- 3. Build the graph ---
# Per-thread conversation state (messages + summary) is checkpointed in SQLite,
# so memory survives across process restarts.
db_path = "multi_session_chat.sqlite"
# check_same_thread=False allows use from threads other than the creator
# (e.g. a web server's workers). NOTE(review): sqlite3 connections are not
# themselves thread-safe — confirm SqliteSaver serializes access.
conn = sqlite3.connect(db_path, check_same_thread=False)
memory = SqliteSaver(conn)
# Graph wiring: START -> chatbot -> (summarize | END); summarize -> END.
workflow = StateGraph(State)
workflow.add_node("chatbot", call_model)
workflow.add_node("summarize", summarize_conversation)
workflow.add_edge(START, "chatbot")
# should_continue returns "summarize" or END to pick the next hop.
workflow.add_conditional_edges("chatbot", should_continue)
workflow.add_edge("summarize", END)
# Compiled app; all chat() calls run through this with a per-session thread_id.
app = workflow.compile(checkpointer=memory)
def chat(thread_id: str, system_prompt: str, user_message: str):
    """Run one conversational turn and return the assistant's reply.

    Memory is persisted per ``thread_id`` via the SQLite checkpointer, so
    repeated calls with the same id share history and summary.
    """
    run_config = {
        "configurable": {
            "thread_id": thread_id,
            "system_prompt": system_prompt,
        }
    }

    # Input for this single turn: just the new human message.
    turn_input = {"messages": [HumanMessage(content=user_message)]}

    # Stream state snapshots; remember the content of the last AI message seen.
    reply = ""
    for snapshot in app.stream(turn_input, run_config, stream_mode="values"):
        if "messages" not in snapshot:
            continue
        newest = snapshot["messages"][-1]
        if newest.type == "ai":
            reply = newest.content
    return reply
if __name__ == "__main__":
    thread = "py_expert_001"
    persona = "你是个善解人意的机器人。"

    # Turn 1: introduce a name the bot should remember.
    print(f"Bot: {chat(thread, persona, '你好,我叫小明。')}")

    # Turn 2: the model should recall '小明' from the checkpointed memory.
    print(f"Bot: {chat(thread, persona, '我今天很开心')}")