AI的框架

This commit is contained in:
renee
2026-01-30 19:31:38 -08:00
parent 4b5b6fb976
commit adab4877ad
5 changed files with 213 additions and 0 deletions

33
app/ai/graph.py Normal file
View File

@@ -0,0 +1,33 @@
from langgraph.graph import StateGraph, END
from nodes import *
from state import State
workflow = StateGraph(State)
# 添加节点
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("summarize", summarize_node)
workflow.add_node("chatbot", call_model_node)
# 设置入口
workflow.set_entry_point("retrieve")
# 条件逻辑:检查消息数量
def should_summarize(state: State):
if len(state["messages"]) > 10:
return "summarize"
return "chatbot"
workflow.add_conditional_edges(
"retrieve",
should_summarize,
{
"summarize": "summarize",
"chatbot": "chatbot"
}
)
workflow.add_edge("summarize", "chatbot")
workflow.add_edge("chatbot", END)
# 编译时加入 Checkpointer (你可以使用你的 Postgres 实现来持久化 State)
app = workflow.compile() #checkpointer=postgres_checkpointer

78
app/ai/nodes.py Normal file
View File

@@ -0,0 +1,78 @@
# nodes/graph_nodes.py
from services.memory_service import search_memories
from services.summary_service import get_rolling_summary
from langchain_core.messages import RemoveMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import SystemMessage, HumanMessage
from state import State
async def retrieve_node(state: State):
# 只针对最后一条用户消息进行检索
user_query = state["messages"][-1].content
memories = await search_memories(user_query, db_connection=None)
return {"retrieved_context": memories}
async def smart_retrieve_node(state: State):
"""
智能检索:先判断用户是否在提问需要背景的事情
"""
last_msg = state["messages"][-1].content
# 一个简单的判断逻辑,也可以用 LLM 做路由
keywords = ["之前", "记得", "上次", "习惯", "喜欢", "", ""]
if any(k in last_msg for k in keywords):
# 执行向量检索
memories = await search_memories(last_msg)
return {"retrieved_context": memories}
return {"retrieved_context": ""}
async def summarize_node(state: State):
# 设定阈值,比如保留最后 6 条,剩下的全部压缩
THRESHOLD = 10
if len(state["messages"]) <= THRESHOLD:
return {}
# 取出除最后 6 条以外的消息进行压缩
to_summarize = state["messages"][:-6]
new_summary = await get_rolling_summary(model_flash, state.get("summary", ""), to_summarize)
# 创建 RemoveMessage 列表来清理 State
delete_actions = [RemoveMessage(id=m.id) for m in to_summarize if m.id]
return {
"summary": new_summary,
"messages": delete_actions
}
# 初始化 Gemini (确保你已经设置了 GOOGLE_API_KEY)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0.7, google_api_key=userdata.get('GOOGLE_API_KEY'))
async def call_model_node(state: State):
"""
这是最终生成对话的节点。
它负责拼接所有的上下文Summary + Memory + Messages
"""
# 1. 构建基础 System Prompt
system_content = "你是一个贴心的 AI 助手。"
# 2. 注入长期摘要 (如果存在)
if state.get("summary"):
system_content += f"\n这是之前的对话简要背景:{state['summary']}"
# 3. 注入检索到的按键记忆 (如果存在)
if state.get("retrieved_context"):
system_content += f"\n这是你记住的关于用户的重要事实:{state['retrieved_context']}"
messages = [SystemMessage(content=system_content)] + state["messages"]
# 4. 调用 Gemini
response = await llm.ainvoke(messages)
# 返回更新后的消息列表
return {"messages": [response]}

View File

@@ -0,0 +1,25 @@
# services/memory_service.py
from langchain_google_genai import GoogleGenerativeAIEmbeddings
# 假设你使用的是 pgvector
async def search_memories(query: str, db_connection):
"""
1. 将 query 转化为 Embedding
2. 在数据库中执行向量相似度搜索
3. 返回最相关的 Top-K 条记忆
"""
# 模拟实现
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
query_vector = await embeddings.aembed_query(query)
# 这里执行 SQL: SELECT content FROM memories ORDER BY embedding <=> query_vector LIMIT 3
results = "用户此前提到过他在做 Gemini 相关的 Hackathon倾向于使用 Python。"
return results
async def save_to_memory(content: str, db_connection):
"""
这个函数由你的 '保存' 按钮触发。
"""
# 1. 提取 content 中的关键信息(可选,可以用 LLM 提取)
# 2. 生成 Embedding 并存入数据库
pass

View File

@@ -0,0 +1,65 @@
# services/summary_service.py
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage
async def get_rolling_summary(model: ChatGoogleGenerativeAI, existing_summary: str, messages: list):
"""
将旧的总结与新的对话内容合并生成新的总结
"""
if not messages:
return existing_summary
msg_content = "\n".join([f"{m.type}: {m.content}" for m in messages])
prompt = f"""
你是一个记忆专家。请根据提供的“现有总结”和“新增对话”,生成一个更全面、精炼的新总结。
请保留关键事实(如技术偏好、重要决定、用户背景),删除无意义的寒暄。
[现有总结]: {existing_summary if existing_summary else "暂无"}
[新增对话]: {msg_content}
请直接输出新的总结文本,保持中文书写。
"""
response = await model.ainvoke([HumanMessage(content=prompt)])
return response.content
# services/memory_service.py
async def extract_and_save_fact(thread_id: str, messages: list, db_connection):
"""
由前端按钮触发:从当前对话上下文提取事实并存入向量库
"""
# 1. 过滤掉无意义的消息,只取最近几条作为提取素材
context_text = "\n".join([f"{m.type}: {m.content}" for m in messages[-10:]])
# 2. 调用小模型 (Flash) 进行原子化事实提取
extraction_prompt = f"""
从以下对话中提取用户提到的、具有长期保存价值的“个人事实”或“技术偏好”。
要求:
- 每一条事实必须是独立的、完整的句子。
- 不要包含寒暄或临时性的讨论。
- 如果没有值得记录的事实,请返回 "NONE"
对话内容:
{context_text}
输出格式示例:
- 用户正在使用 Python 3.12 进行开发。
- 用户计划参加 2026 年的 Gemini Hackathon。
"""
# 这里假设你已经初始化了 model_flash
response = await model_flash.ainvoke(extraction_prompt)
facts_text = response.content.strip()
if facts_text == "NONE":
return "没有发现值得记录的新事实。"
# 3. 将提取到的事实转化为向量并存入 pgvector
# facts = facts_text.split('\n')
# for fact in facts:
# embedding = await get_embedding(fact)
# await db_connection.execute("INSERT INTO memories ...", embedding, fact, thread_id)
return f"已成功记录以下记忆:\n{facts_text}"

12
app/ai/state.py Normal file
View File

@@ -0,0 +1,12 @@
from typing import Annotated, TypedDict
from langgraph.graph.message import add_messages
class State(TypedDict):
# add_messages 会将新消息追加到列表,而不是覆盖
messages: Annotated[list, add_messages]
# 存储当前的总结,避免重复加载大数据量历史
summary: str
# 从 Long-term memory 检索到的事实
retrieved_context: str
# 记录这轮对话是否触发了总结
retrieved_context: str