131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
|
||
!pip install -q langgraph-checkpoint-sqlite langchain_google_genai
|
||
|
||
import sqlite3
|
||
from google.colab import userdata
|
||
from typing import Literal, TypedDict, Annotated
|
||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage, AnyMessage
|
||
from langgraph.graph import StateGraph, START, END
|
||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||
from langgraph.graph.message import add_messages
|
||
from langchain_core.runnables import RunnableConfig
|
||
from typing import Union, List, Dict
|
||
|
||
|
||
# --- 1. 状态定义 ---
|
||
class State(TypedDict):
    """Graph state passed between the nodes of the conversation workflow."""

    # Conversation messages; `add_messages` is attached as the annotation
    # that controls how node outputs are merged into this list.
    messages: Annotated[list[AnyMessage], add_messages]
    # Summary text that is persisted in the database.
    summary: str
|
||
|
||
# --- 2. 核心逻辑 ---
|
||
# Chat model shared by the graph nodes below.
# Fill in the API key via `google_api_key` before running.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.7,
    google_api_key='',
)
|
||
|
||
def call_model(state: State, config: RunnableConfig):
    """Chat node: answer using a per-session system prompt plus the stored summary.

    Args:
        state: Current graph state. Reads ``messages`` and the optional
            ``summary`` key.
        config: Runnable config. ``config["configurable"]["system_prompt"]``
            overrides the default system prompt for this session.

    Returns:
        A state update with the model's reply as the only entry under
        ``messages``.
    """
    default_prompt = "你是一个通用的 AI 助手。"

    # Session-specific system prompt. Fall back to the default both when the
    # key is missing and when it is explicitly None: a None value would
    # otherwise break the string concatenation / SystemMessage below.
    configurable = config.get("configurable", {})
    system_base_prompt = configurable.get("system_prompt")
    if system_base_prompt is None:
        system_base_prompt = default_prompt

    # Append the long-term summary (if any) to the system prompt.
    summary = state.get("summary", "")
    if summary:
        system_base_prompt += f"\n\n<context_summary>\n{summary}\n</context_summary>"

    # The system message goes first, followed by the stored conversation.
    messages = [SystemMessage(content=system_base_prompt)] + state["messages"]
    response = llm.invoke(messages)
    return {"messages": [response]}
|
||
|
||
def summarize_conversation(state: State):
    """Summary node: refresh the stored summary and drop the summarized messages.

    Every message except the last one is handed to the model together with
    the existing summary. The model's answer becomes the new summary, and a
    ``RemoveMessage`` is emitted for each summarized message that has an id.

    Args:
        state: Current graph state. Reads ``messages`` and the optional
            ``summary`` key.

    Returns:
        ``{"summary": ...}`` unchanged when there is nothing to summarize,
        otherwise a state update with the new ``summary`` and the list of
        ``RemoveMessage`` objects under ``messages``.
    """
    current_summary = state.get("summary", "")

    # Everything but the most recent message gets folded into the summary.
    stale_messages = state["messages"][:-1]

    # Nothing to fold in yet: return the summary as it is.
    if not stale_messages:
        return {"summary": current_summary}

    instructions = "".join(
        (
            "你是一个记忆管理专家。请更新摘要,合并新旧信息。",
            "1. 保持简练,仅保留事实(姓名、偏好、核心议题)。",
            "2. 如果新消息包含对旧信息的修正,请更新它。",
            "3. 如果对话中包含图片描述,请将图片的关键视觉信息也记录在摘要中",
        )
    )
    request_text = f"现有摘要: {current_summary}\n\n待加入的新信息: {stale_messages}"

    # Ask the model for the new condensed summary.
    prompt = [
        SystemMessage(content=instructions),
        HumanMessage(content=request_text),
    ]
    response = llm.invoke(prompt)

    # Build a removal marker for every summarized message that carries an id.
    removals = []
    for msg in stale_messages:
        if msg.id:
            removals.append(RemoveMessage(id=msg.id))

    return {"summary": response.content, "messages": removals}
|
||
|
||
def should_continue(state: State) -> Literal["summarize", END]:
    """Route to the summary node once more than 3 messages are stored, else end."""
    return "summarize" if len(state["messages"]) > 3 else END
|
||
|
||
# --- 3. 构建图 ---
|
||
# SQLite-backed checkpointer. `check_same_thread=False` lets the connection
# be used from threads other than the one that created it.
db_path = "multi_session_chat.sqlite"
conn = sqlite3.connect(db_path, check_same_thread=False)
memory = SqliteSaver(conn)

# Two nodes: the chat node and the summary node.
workflow = StateGraph(State)
workflow.add_node("chatbot", call_model)
workflow.add_node("summarize", summarize_conversation)

# START -> chatbot -> (summarize | END), and summarize -> END.
workflow.add_edge(START, "chatbot")
workflow.add_conditional_edges("chatbot", should_continue)
workflow.add_edge("summarize", END)

# Compile with the checkpointer attached.
app = workflow.compile(checkpointer=memory)
|
||
|
||
def chat(thread_id: str, system_prompt: str, user_content: Union[str, List[Dict]]):
    """Send one user message on a thread and return the AI reply.

    Args:
        thread_id: Passed to the graph as ``configurable["thread_id"]``.
        system_prompt: Passed to the graph as ``configurable["system_prompt"]``.
        user_content: Content of the ``HumanMessage`` for this turn.

    Returns:
        The content of the last message of the most recent streamed state
        whose last message is of type ``"ai"``, or ``""`` if there was none.
    """
    configurable = {"thread_id": thread_id, "system_prompt": system_prompt}
    config = {"configurable": configurable}

    # Input for this single turn.
    turn_input = {"messages": [HumanMessage(content=user_content)]}

    reply = ""
    # Walk the streamed state values, keeping the latest AI message content.
    for snapshot in app.stream(turn_input, config, stream_mode="values"):
        if "messages" not in snapshot:
            continue
        latest = snapshot["messages"][-1]
        if latest.type == "ai":
            reply = latest.content

    return reply
|
||
|
||
# 使用范例
|
||
# if __name__ == "__main__":
|
||
# tid = "py_expert_001"
|
||
# sys_p = "你是个善解人意的机器人。"
|
||
|
||
# # Call 1: Establish context
|
||
# resp1 = chat(tid, sys_p, "你好,我叫小明。")
|
||
# print(f"Bot: {resp1}")
|
||
|
||
# # Call 2: Test memory (The model should remember the name '小明')
|
||
# resp2 = chat(tid, sys_p, "我今天很开心")
|
||
# print(f"Bot: {resp2}") |