Files
backend/app/ai/ai.py
2026-01-30 22:49:12 -08:00

145 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# pip install -q langgraph-checkpoint-sqlite
import sqlite3
from google.colab import userdata
from typing import Literal, TypedDict, Annotated
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage, AnyMessage
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph.message import add_messages
from langchain_core.runnables import RunnableConfig
# --- 1. 状态定义 ---
class State(TypedDict):
messages: Annotated[list[AnyMessage], add_messages]
summary: str # 永久存储在数据库中的摘要内容
# --- 2. 核心逻辑 ---
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature = 0.7, google_api_key='AIzaSyCfkiIgq4FmH5siBp3Iw6MRCml5zeSURnY')
def call_model(state: State, config: RunnableConfig):
"""对话节点:融合了动态 Prompt 和 长期摘要"""
# 获取当前 session 特有的 System Prompt如果没传则使用默认
configurable = config.get("configurable", {})
system_base_prompt = configurable.get("system_prompt", "你是一个通用的 AI 助手。")
# 构造当前上下文
prompt = f"{system_base_prompt}"
if state.get("summary"):
prompt += f"\n\n[之前的对话核心摘要]: {state['summary']}"
messages = [SystemMessage(content=prompt)] + state["messages"]
response = llm.invoke(messages)
return {"messages": [response]}
def summarize_conversation(state: State):
"""总结节点:负责更新摘要并清理过期消息"""
summary = state.get("summary", "")
if summary:
summary_prompt = f"当前的摘要是: {summary}\n\n请结合最近的新消息,生成一份更新后的完整摘要,包含所有关键信息:"
else:
summary_prompt = "请总结目前的对话重点:"
# 获取除了最后两轮之外的所有消息进行总结
messages_to_summarize = state["messages"][:-2]
content = [SystemMessage(content=summary_prompt)] + messages_to_summarize
response = llm.invoke(content)
# 生成删除指令,清除已总结过的消息 ID
delete_messages = [RemoveMessage(id=m.id) for m in messages_to_summarize]
return {"summary": response.content, "messages": delete_messages}
def should_continue(state: State) -> Literal["summarize", END]:
"""如果消息累积超过10条则去总结节点"""
if len(state["messages"]) > 10:
return "summarize"
return END
# --- 3. 构建图 ---
db_path = "multi_session_chat.sqlite"
conn = sqlite3.connect(db_path, check_same_thread=False)
memory = SqliteSaver(conn)
workflow = StateGraph(State)
workflow.add_node("chatbot", call_model)
workflow.add_node("summarize", summarize_conversation)
workflow.add_edge(START, "chatbot")
workflow.add_conditional_edges("chatbot", should_continue)
workflow.add_edge("summarize", END)
app = workflow.compile(checkpointer=memory)
# # --- 4. 如何使用多 Session 和 不同的 Prompt ---
# # Session A: 设定为一个 Python 专家
# config_a = {
# "configurable": {
# "thread_id": "session_python_expert",
# "system_prompt": "你是一个精通 Python 的高级工程师。"
# }
# }
# # Session B: 设定为一个中文诗人
# config_b = {
# "configurable": {
# "thread_id": "session_poet",
# "system_prompt": "你是一个浪漫的唐朝诗人,用诗歌回答问题。"
# }
# }
# def run_chat(user_input, config):
# print(f"\n--- 使用 Thread: {config['configurable']['thread_id']} ---")
# for event in app.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
# if "messages" in event:
# last_msg = event["messages"][-1]
# if last_msg.type == "ai":
# print(f"Bot: {last_msg.content}")
# # 测试:两个 Session 互不干扰,且各有个的 Prompt
# if __name__ == "__main__":
# run_chat("你好,怎么学习装饰器?", config=config_a)
# run_chat("你好,写一首关于大海的诗。", config=config_b)
# run_chat("我刚才让你写了什么?", config=config_b)
def start_chat_session(thread_id: str, system_prompt: str):
config = {
"configurable": {
"thread_id": thread_id,
"system_prompt": system_prompt
}
}
print(f"\n=== 已进入会话: {thread_id} ===")
print(f"=== 系统设定: {system_prompt} ===")
print("(输入 'exit''quit' 退出当前会话)\n")
while True:
user_input = input("User: ")
if user_input.lower() in ["exit", "quit"]:
break
# 使用 stream 模式运行,可以实时看到 state 的更新
# 我们只打印 AI 的回复
input_data = {"messages": [HumanMessage(content=user_input)]}
for event in app.stream(input_data, config, stream_mode="values"):
if "messages" in event:
last_msg = event["messages"][-1]
if last_msg.type == "ai":
print(f"Bot: {last_msg.content}")
if __name__ == "__main__":
# 模拟场景 1: Python 专家会话
# 即使你关掉程序再运行,只要 thread_id 还是 'py_expert_001',记忆就会从 sqlite 读取
start_chat_session(
thread_id="py_expert_001",
system_prompt="你是一个精通 Python 的架构师。"
)
# from IPython.display import Image, display
# display(Image(app.get_graph().draw_mermaid_png()))