From adab4877adac732c8338eeb57ce6d6f90ab9974c Mon Sep 17 00:00:00 2001 From: renee <50965960+wurenee@users.noreply.github.com> Date: Fri, 30 Jan 2026 19:31:38 -0800 Subject: [PATCH] =?UTF-8?q?AI=E7=9A=84=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/ai/graph.py | 33 +++++++++++++ app/ai/nodes.py | 78 ++++++++++++++++++++++++++++++ app/ai/services/memory_service.py | 25 ++++++++++ app/ai/services/summary_service.py | 65 +++++++++++++++++++++++++ app/ai/state.py | 12 +++++ 5 files changed, 213 insertions(+) create mode 100644 app/ai/graph.py create mode 100644 app/ai/nodes.py create mode 100644 app/ai/services/memory_service.py create mode 100644 app/ai/services/summary_service.py create mode 100644 app/ai/state.py diff --git a/app/ai/graph.py b/app/ai/graph.py new file mode 100644 index 0000000..2189fd1 --- /dev/null +++ b/app/ai/graph.py @@ -0,0 +1,33 @@ +from langgraph.graph import StateGraph, END +from nodes import * +from state import State + +workflow = StateGraph(State) + +# 添加节点 +workflow.add_node("retrieve", retrieve_node) +workflow.add_node("summarize", summarize_node) +workflow.add_node("chatbot", call_model_node) + +# 设置入口 +workflow.set_entry_point("retrieve") + +# 条件逻辑:检查消息数量 +def should_summarize(state: State): + if len(state["messages"]) > 10: + return "summarize" + return "chatbot" + +workflow.add_conditional_edges( + "retrieve", + should_summarize, + { + "summarize": "summarize", + "chatbot": "chatbot" + } +) +workflow.add_edge("summarize", "chatbot") +workflow.add_edge("chatbot", END) + +# 编译时加入 Checkpointer (你可以使用你的 Postgres 实现来持久化 State) +app = workflow.compile() #checkpointer=postgres_checkpointer \ No newline at end of file diff --git a/app/ai/nodes.py b/app/ai/nodes.py new file mode 100644 index 0000000..919b303 --- /dev/null +++ b/app/ai/nodes.py @@ -0,0 +1,78 @@ +# nodes/graph_nodes.py +from services.memory_service import search_memories +from services.summary_service import get_rolling_summary +from langchain_core.messages import RemoveMessage +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_core.messages import SystemMessage, HumanMessage +from state import State + +async def retrieve_node(state: State): + # 只针对最后一条用户消息进行检索 + user_query = state["messages"][-1].content + memories = await search_memories(user_query, db_connection=None) + return {"retrieved_context": memories} + +async def smart_retrieve_node(state: State): + """ + 智能检索:先判断用户是否在提问需要背景的事情 + """ + last_msg = state["messages"][-1].content + + # 一个简单的判断逻辑,也可以用 LLM 做路由 + keywords = ["之前", "记得", "上次", "习惯", "喜欢", "谁", "哪"] + if any(k in last_msg for k in keywords): + # 执行向量检索 + memories = await search_memories(last_msg) + return {"retrieved_context": memories} + + return {"retrieved_context": ""} + + +async def summarize_node(state: State): + # 设定阈值,比如保留最后 6 条,剩下的全部压缩 + THRESHOLD = 10 + if len(state["messages"]) <= THRESHOLD: + return {} + + # 取出除最后 6 条以外的消息进行压缩 + to_summarize = state["messages"][:-6] + new_summary = await get_rolling_summary(model_flash, state.get("summary", ""), to_summarize) + + # 创建 RemoveMessage 列表来清理 State + delete_actions = [RemoveMessage(id=m.id) for m in to_summarize if m.id] + + return { + "summary": new_summary, + "messages": delete_actions + } + + + +# 初始化 Gemini (确保你已经设置了 GOOGLE_API_KEY) + +llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0.7, google_api_key=userdata.get('GOOGLE_API_KEY')) + +async def call_model_node(state: State): + """ + 这是最终生成对话的节点。 + 它负责拼接所有的上下文:Summary + Memory + Messages + """ + + # 1. 构建基础 System Prompt + system_content = "你是一个贴心的 AI 助手。" + + # 2. 注入长期摘要 (如果存在) + if state.get("summary"): + system_content += f"\n这是之前的对话简要背景:{state['summary']}" + + # 3. 注入检索到的按键记忆 (如果存在) + if state.get("retrieved_context"): + system_content += f"\n这是你记住的关于用户的重要事实:{state['retrieved_context']}" + + messages = [SystemMessage(content=system_content)] + state["messages"] + + # 4. 调用 Gemini + response = await llm.ainvoke(messages) + + # 返回更新后的消息列表 + return {"messages": [response]} \ No newline at end of file diff --git a/app/ai/services/memory_service.py b/app/ai/services/memory_service.py new file mode 100644 index 0000000..a30eafb --- /dev/null +++ b/app/ai/services/memory_service.py @@ -0,0 +1,25 @@ +# services/memory_service.py +from langchain_google_genai import GoogleGenerativeAIEmbeddings + +# 假设你使用的是 pgvector +async def search_memories(query: str, db_connection): + """ + 1. 将 query 转化为 Embedding + 2. 在数据库中执行向量相似度搜索 + 3. 返回最相关的 Top-K 条记忆 + """ + # 模拟实现 + embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") + query_vector = await embeddings.aembed_query(query) + + # 这里执行 SQL: SELECT content FROM memories ORDER BY embedding <=> query_vector LIMIT 3 + results = "用户此前提到过他在做 Gemini 相关的 Hackathon,倾向于使用 Python。" + return results + +async def save_to_memory(content: str, db_connection): + """ + 这个函数由你的 '保存' 按钮触发。 + """ + # 1. 提取 content 中的关键信息(可选,可以用 LLM 提取) + # 2. 生成 Embedding 并存入数据库 + pass \ No newline at end of file diff --git a/app/ai/services/summary_service.py b/app/ai/services/summary_service.py new file mode 100644 index 0000000..3bacc26 --- /dev/null +++ b/app/ai/services/summary_service.py @@ -0,0 +1,65 @@ +# services/summary_service.py +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_core.messages import HumanMessage, SystemMessage + +async def get_rolling_summary(model: ChatGoogleGenerativeAI, existing_summary: str, messages: list): + """ + 将旧的总结与新的对话内容合并生成新的总结 + """ + if not messages: + return existing_summary + + msg_content = "\n".join([f"{m.type}: {m.content}" for m in messages]) + + prompt = f""" + 你是一个记忆专家。请根据提供的“现有总结”和“新增对话”,生成一个更全面、精炼的新总结。 + 请保留关键事实(如技术偏好、重要决定、用户背景),删除无意义的寒暄。 + + [现有总结]: {existing_summary if existing_summary else "暂无"} + [新增对话]: {msg_content} + + 请直接输出新的总结文本,保持中文书写。 + """ + + response = await model.ainvoke([HumanMessage(content=prompt)]) + return response.content + +# services/memory_service.py + +async def extract_and_save_fact(thread_id: str, messages: list, db_connection): + """ + 由前端按钮触发:从当前对话上下文提取事实并存入向量库 + """ + # 1. 过滤掉无意义的消息,只取最近几条作为提取素材 + context_text = "\n".join([f"{m.type}: {m.content}" for m in messages[-10:]]) + + # 2. 调用小模型 (Flash) 进行原子化事实提取 + extraction_prompt = f""" + 从以下对话中提取用户提到的、具有长期保存价值的“个人事实”或“技术偏好”。 + 要求: + - 每一条事实必须是独立的、完整的句子。 + - 不要包含寒暄或临时性的讨论。 + - 如果没有值得记录的事实,请返回 "NONE"。 + + 对话内容: + {context_text} + + 输出格式示例: + - 用户正在使用 Python 3.12 进行开发。 + - 用户计划参加 2026 年的 Gemini Hackathon。 + """ + + # 这里假设你已经初始化了 model_flash + response = await model_flash.ainvoke(extraction_prompt) + facts_text = response.content.strip() + + if facts_text == "NONE": + return "没有发现值得记录的新事实。" + + # 3. 将提取到的事实转化为向量并存入 pgvector + # facts = facts_text.split('\n') + # for fact in facts: + # embedding = await get_embedding(fact) + # await db_connection.execute("INSERT INTO memories ...", embedding, fact, thread_id) + + return f"已成功记录以下记忆:\n{facts_text}" \ No newline at end of file diff --git a/app/ai/state.py b/app/ai/state.py new file mode 100644 index 0000000..cfe6a13 --- /dev/null +++ b/app/ai/state.py @@ -0,0 +1,12 @@ +from typing import Annotated, TypedDict +from langgraph.graph.message import add_messages + +class State(TypedDict): + # add_messages 会将新消息追加到列表,而不是覆盖 + messages: Annotated[list, add_messages] + # 存储当前的总结,避免重复加载大数据量历史 + summary: str + # 从 Long-term memory 检索到的事实 + retrieved_context: str + # 记录这轮对话是否触发了总结 + retrieved_context: str \ No newline at end of file