Files
backend/app/ai/nodes.py
2026-01-30 19:31:38 -08:00

78 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# nodes/graph_nodes.py
import os

from langchain_core.messages import RemoveMessage
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from services.memory_service import search_memories
from services.summary_service import get_rolling_summary
from state import State
async def retrieve_node(state: State):
    """Fetch long-term memories relevant to the newest user message.

    Only the final message in the conversation is used as the search
    query; the hits are stored under ``retrieved_context`` in the state.
    """
    latest = state["messages"][-1]
    hits = await search_memories(latest.content, db_connection=None)
    return {"retrieved_context": hits}
async def smart_retrieve_node(state: "State"):
    """Conditionally retrieve long-term memories for the latest user message.

    A cheap keyword heuristic decides whether the user is asking about
    something that needs background (an LLM router could replace it).
    The vector search only runs when a keyword matches; otherwise an
    empty context is returned and the expensive lookup is skipped.
    """
    last_msg = state["messages"][-1].content
    # BUG FIX: the original list also contained two empty strings, and
    # ``"" in s`` is always True, so the gate matched EVERY message and the
    # conditional retrieval was vacuous.  They were probably characters
    # mangled in transit (e.g. "吗"/"呢") -- TODO confirm intended keywords.
    keywords = ["之前", "记得", "上次", "习惯", "喜欢"]
    if any(k in last_msg for k in keywords):
        # Trigger word found: run the vector search against stored memories.
        memories = await search_memories(last_msg)
        return {"retrieved_context": memories}
    # No trigger word: skip retrieval entirely.
    return {"retrieved_context": ""}
async def summarize_node(state: "State"):
    """Compress older conversation history into a rolling summary.

    Keeps the most recent messages verbatim and folds everything older
    into an updated summary string, emitting ``RemoveMessage`` actions so
    LangGraph drops the compressed messages from the state.
    """
    THRESHOLD = 10  # only start summarizing once history grows past this size
    KEEP_LAST = 6   # number of recent messages preserved verbatim
    if len(state["messages"]) <= THRESHOLD:
        # Not enough history yet -- nothing to change.
        return {}
    # Everything except the last KEEP_LAST messages gets folded into the summary.
    to_summarize = state["messages"][:-KEEP_LAST]
    # BUG FIX: the original passed an undefined name ``model_flash`` (NameError
    # at runtime); use the module-level ``llm`` instead.  If a cheaper "flash"
    # model was intended, define it at module level and swap it in -- TODO confirm.
    new_summary = await get_rolling_summary(llm, state.get("summary", ""), to_summarize)
    # RemoveMessage actions instruct the graph to delete the compressed messages.
    delete_actions = [RemoveMessage(id=m.id) for m in to_summarize if m.id]
    return {
        "summary": new_summary,
        "messages": delete_actions,
    }
# Initialize Gemini.  The API key is read from the GOOGLE_API_KEY environment
# variable.  BUG FIX: the original called ``userdata.get('GOOGLE_API_KEY')``,
# but ``userdata`` is a Google-Colab-only helper that is undefined in this
# module and would raise NameError at import time.
llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-pro",
    temperature=0.7,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)
async def call_model_node(state: State):
    """Final response-generation node.

    Assembles the full context for the LLM -- rolling summary, retrieved
    memories, and the live message history -- then calls Gemini and
    appends its reply to the conversation.
    """
    # 1. Base system prompt.
    parts = ["你是一个贴心的 AI 助手。"]
    # 2. Inject the long-term rolling summary, when one exists.
    summary = state.get("summary")
    if summary:
        parts.append(f"\n这是之前的对话简要背景:{summary}")
    # 3. Inject retrieved memories about the user, when any were found.
    context = state.get("retrieved_context")
    if context:
        parts.append(f"\n这是你记住的关于用户的重要事实:{context}")
    prompt = [SystemMessage(content="".join(parts))] + state["messages"]
    # 4. Call Gemini; the graph merges the response into the message list.
    response = await llm.ainvoke(prompt)
    return {"messages": [response]}