AI的框架
This commit is contained in:
78
app/ai/nodes.py
Normal file
78
app/ai/nodes.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# nodes/graph_nodes.py
|
||||
from services.memory_service import search_memories
|
||||
from services.summary_service import get_rolling_summary
|
||||
from langchain_core.messages import RemoveMessage
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
from state import State
|
||||
|
||||
async def retrieve_node(state: State):
|
||||
# 只针对最后一条用户消息进行检索
|
||||
user_query = state["messages"][-1].content
|
||||
memories = await search_memories(user_query, db_connection=None)
|
||||
return {"retrieved_context": memories}
|
||||
|
||||
async def smart_retrieve_node(state: State):
|
||||
"""
|
||||
智能检索:先判断用户是否在提问需要背景的事情
|
||||
"""
|
||||
last_msg = state["messages"][-1].content
|
||||
|
||||
# 一个简单的判断逻辑,也可以用 LLM 做路由
|
||||
keywords = ["之前", "记得", "上次", "习惯", "喜欢", "谁", "哪"]
|
||||
if any(k in last_msg for k in keywords):
|
||||
# 执行向量检索
|
||||
memories = await search_memories(last_msg)
|
||||
return {"retrieved_context": memories}
|
||||
|
||||
return {"retrieved_context": ""}
|
||||
|
||||
|
||||
async def summarize_node(state: State):
|
||||
# 设定阈值,比如保留最后 6 条,剩下的全部压缩
|
||||
THRESHOLD = 10
|
||||
if len(state["messages"]) <= THRESHOLD:
|
||||
return {}
|
||||
|
||||
# 取出除最后 6 条以外的消息进行压缩
|
||||
to_summarize = state["messages"][:-6]
|
||||
new_summary = await get_rolling_summary(model_flash, state.get("summary", ""), to_summarize)
|
||||
|
||||
# 创建 RemoveMessage 列表来清理 State
|
||||
delete_actions = [RemoveMessage(id=m.id) for m in to_summarize if m.id]
|
||||
|
||||
return {
|
||||
"summary": new_summary,
|
||||
"messages": delete_actions
|
||||
}
|
||||
|
||||
|
||||
|
||||
# 初始化 Gemini (确保你已经设置了 GOOGLE_API_KEY)
|
||||
|
||||
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0.7, google_api_key=userdata.get('GOOGLE_API_KEY'))
|
||||
|
||||
async def call_model_node(state: State):
|
||||
"""
|
||||
这是最终生成对话的节点。
|
||||
它负责拼接所有的上下文:Summary + Memory + Messages
|
||||
"""
|
||||
|
||||
# 1. 构建基础 System Prompt
|
||||
system_content = "你是一个贴心的 AI 助手。"
|
||||
|
||||
# 2. 注入长期摘要 (如果存在)
|
||||
if state.get("summary"):
|
||||
system_content += f"\n这是之前的对话简要背景:{state['summary']}"
|
||||
|
||||
# 3. 注入检索到的按键记忆 (如果存在)
|
||||
if state.get("retrieved_context"):
|
||||
system_content += f"\n这是你记住的关于用户的重要事实:{state['retrieved_context']}"
|
||||
|
||||
messages = [SystemMessage(content=system_content)] + state["messages"]
|
||||
|
||||
# 4. 调用 Gemini
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
# 返回更新后的消息列表
|
||||
return {"messages": [response]}
|
||||
Reference in New Issue
Block a user