联入SQLite
This commit is contained in:
145
app/ai/ai.py
Normal file
145
app/ai/ai.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
|
||||||
|
# pip install -q langgraph-checkpoint-sqlite
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from google.colab import userdata
|
||||||
|
from typing import Literal, TypedDict, Annotated
|
||||||
|
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage, RemoveMessage, AnyMessage
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
|
||||||
|
# --- 1. 状态定义 ---
|
||||||
|
class State(TypedDict):
    """Conversation state carried through the graph."""

    # Chat history; updates are merged through the add_messages reducer
    # rather than replacing the whole list.
    messages: Annotated[list[AnyMessage], add_messages]
    # Summary text of earlier turns, persisted in the database.
    summary: str
|
||||||
|
|
||||||
|
# --- 2. 核心逻辑 ---
|
||||||
|
# SECURITY: never commit an API key to source control. No key is passed here,
# so ChatGoogleGenerativeAI reads it from the GOOGLE_API_KEY environment
# variable. Any key that has ever been committed must be revoked and rotated.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.7)
|
||||||
|
|
||||||
|
def call_model(state: State, config: RunnableConfig):
    """Chat node: answer using the session's system prompt plus the stored summary.

    Args:
        state: Graph state. ``messages`` is sent to the model; a non-empty
            ``summary`` is appended to the system prompt.
        config: Runnable config. ``configurable["system_prompt"]`` replaces
            the default system prompt when present.

    Returns:
        A state update holding the model reply under ``messages``.
    """
    session_settings = config.get("configurable", {})
    system_text = str(
        session_settings.get("system_prompt", "你是一个通用的 AI 助手。")
    )

    # Fold the long-term summary into the system prompt when one exists.
    earlier_summary = state.get("summary")
    if earlier_summary:
        system_text += f"\n\n[之前的对话核心摘要]: {earlier_summary}"

    reply = llm.invoke([SystemMessage(content=system_text), *state["messages"]])
    return {"messages": [reply]}
|
||||||
|
|
||||||
|
def summarize_conversation(state: State):
    """Summary node: refresh the stored summary and drop the messages it covers.

    Every message except the last two is sent to the model together with an
    instruction that either extends the existing summary or starts a new one.

    Returns:
        A state update with the new ``summary`` and one ``RemoveMessage`` per
        summarized message.
    """
    previous = state.get("summary", "")
    instruction = (
        f"当前的摘要是: {previous}\n\n请结合最近的新消息,生成一份更新后的完整摘要,包含所有关键信息:"
        if previous
        else "请总结目前的对话重点:"
    )

    # The two most recent messages stay verbatim; everything older is condensed.
    stale = state["messages"][:-2]
    result = llm.invoke([SystemMessage(content=instruction), *stale])

    # Removal markers for the messages that are now covered by the summary.
    removals = [RemoveMessage(id=msg.id) for msg in stale]
    return {"summary": result.content, "messages": removals}
|
||||||
|
|
||||||
|
def should_continue(state: State) -> Literal["summarize", END]:
    """Route to the summarize node once more than 10 messages have accumulated."""
    needs_summary = len(state["messages"]) > 10
    return "summarize" if needs_summary else END
|
||||||
|
|
||||||
|
# --- 3. 构建图 ---
|
||||||
|
# Checkpoint storage: a single SQLite file used as the graph's checkpointer.
db_path = "multi_session_chat.sqlite"
# check_same_thread=False turns off sqlite3's same-thread check on this connection.
conn = sqlite3.connect(db_path, check_same_thread=False)
memory = SqliteSaver(conn)

workflow = StateGraph(State)
workflow.add_node("chatbot", call_model)
workflow.add_node("summarize", summarize_conversation)

# START -> chatbot -> (summarize | END); the summarize node always ends the run.
workflow.add_edge(START, "chatbot")
workflow.add_conditional_edges("chatbot", should_continue)
workflow.add_edge("summarize", END)

app = workflow.compile(checkpointer=memory)
|
||||||
|
|
||||||
|
# # --- 4. 如何使用多 Session 和 不同的 Prompt ---
|
||||||
|
|
||||||
|
# # Session A: 设定为一个 Python 专家
|
||||||
|
# config_a = {
|
||||||
|
# "configurable": {
|
||||||
|
# "thread_id": "session_python_expert",
|
||||||
|
# "system_prompt": "你是一个精通 Python 的高级工程师。"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
# # Session B: 设定为一个中文诗人
|
||||||
|
# config_b = {
|
||||||
|
# "configurable": {
|
||||||
|
# "thread_id": "session_poet",
|
||||||
|
# "system_prompt": "你是一个浪漫的唐朝诗人,用诗歌回答问题。"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
|
||||||
|
# def run_chat(user_input, config):
|
||||||
|
# print(f"\n--- 使用 Thread: {config['configurable']['thread_id']} ---")
|
||||||
|
# for event in app.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
|
||||||
|
# if "messages" in event:
|
||||||
|
# last_msg = event["messages"][-1]
|
||||||
|
# if last_msg.type == "ai":
|
||||||
|
# print(f"Bot: {last_msg.content}")
|
||||||
|
|
||||||
|
# # 测试:两个 Session 互不干扰,且各有各的 Prompt
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# run_chat("你好,怎么学习装饰器?", config=config_a)
|
||||||
|
# run_chat("你好,写一首关于大海的诗。", config=config_b)
|
||||||
|
# run_chat("我刚才让你写了什么?", config=config_b)
|
||||||
|
|
||||||
|
def start_chat_session(thread_id: str, system_prompt: str):
    """Run an interactive console chat bound to one conversation thread.

    Args:
        thread_id: Passed to the graph as ``configurable["thread_id"]``.
        system_prompt: Passed to the graph as ``configurable["system_prompt"]``.

    The loop ends when the user types ``exit`` or ``quit``, when stdin is
    closed, or on Ctrl-C.
    """
    config = {
        "configurable": {
            "thread_id": thread_id,
            "system_prompt": system_prompt,
        }
    }

    print(f"\n=== 已进入会话: {thread_id} ===")
    print(f"=== 系统设定: {system_prompt} ===")
    print("(输入 'exit' 或 'quit' 退出当前会话)\n")

    while True:
        try:
            user_input = input("User: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session without a traceback.
            print()
            break

        command = user_input.strip().lower()
        if command in ("exit", "quit"):
            break
        if not command:
            # Do not send blank turns to the model.
            continue

        input_data = {"messages": [HumanMessage(content=user_input)]}

        # stream_mode="values" emits the full state after every step, so the
        # same AI reply is still the last message after the summarize node
        # runs. Remember what was printed to avoid showing it twice.
        last_printed = None
        for event in app.stream(input_data, config, stream_mode="values"):
            messages = event.get("messages")
            if not messages:
                continue
            last_msg = messages[-1]
            if last_msg.type != "ai":
                continue
            marker = getattr(last_msg, "id", None) or id(last_msg)
            if marker == last_printed:
                continue
            print(f"Bot: {last_msg.content}")
            last_printed = marker
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Demo scenario: a Python-expert session. Re-running the program with the
    # same thread_id ('py_expert_001') reads the earlier memory back from SQLite.
    start_chat_session(
        thread_id="py_expert_001",
        system_prompt="你是一个精通 Python 的架构师。",
    )
|
||||||
|
|
||||||
|
# from IPython.display import Image, display
|
||||||
|
# display(Image(app.get_graph().draw_mermaid_png()))
|
||||||
1149
app/ai/ai_with_example_output.ipynb
Normal file
1149
app/ai/ai_with_example_output.ipynb
Normal file
File diff suppressed because one or more lines are too long
@@ -1,33 +0,0 @@
|
|||||||
from langgraph.graph import StateGraph, END
|
|
||||||
from nodes import *
|
|
||||||
from state import State
|
|
||||||
|
|
||||||
workflow = StateGraph(State)

# Register the three nodes of the pipeline.
workflow.add_node("retrieve", retrieve_node)
workflow.add_node("summarize", summarize_node)
workflow.add_node("chatbot", call_model_node)

# Every run enters the graph at the retrieval step.
workflow.set_entry_point("retrieve")
|
|
||||||
|
|
||||||
# 条件逻辑:检查消息数量
|
|
||||||
def should_summarize(state: State):
    """Pick the next node: "summarize" above 10 messages, otherwise "chatbot"."""
    return "summarize" if len(state["messages"]) > 10 else "chatbot"
|
|
||||||
|
|
||||||
# After retrieval, branch on should_summarize's return value.
workflow.add_conditional_edges(
    "retrieve",
    should_summarize,
    {"summarize": "summarize", "chatbot": "chatbot"},
)
workflow.add_edge("summarize", "chatbot")
workflow.add_edge("chatbot", END)

# Compiled without a checkpointer; pass checkpointer=... (for example a
# Postgres-backed implementation) to persist the state.
app = workflow.compile()
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
# nodes/graph_nodes.py
|
|
||||||
from services.memory_service import search_memories
|
|
||||||
from services.summary_service import get_rolling_summary
|
|
||||||
from langchain_core.messages import RemoveMessage
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
from langchain_core.messages import SystemMessage, HumanMessage
|
|
||||||
from state import State
|
|
||||||
|
|
||||||
async def retrieve_node(state: State):
    """Search stored memories using only the most recent message as the query.

    Returns:
        A state update with the search result under ``retrieved_context``.
    """
    latest = state["messages"][-1]
    found = await search_memories(latest.content, db_connection=None)
    return {"retrieved_context": found}
|
|
||||||
|
|
||||||
async def smart_retrieve_node(state: State):
    """Retrieve memories only when the latest message appears to need background.

    A keyword heuristic on the last message decides whether to search; an LLM
    router could replace it.

    Returns:
        A state update whose ``retrieved_context`` is the search result, or
        ``""`` when none of the keywords occurs in the message.
    """
    last_msg = state["messages"][-1].content

    keywords = ("之前", "记得", "上次", "习惯", "喜欢", "谁", "哪")
    if not any(k in last_msg for k in keywords):
        return {"retrieved_context": ""}

    # search_memories requires db_connection; pass None, as retrieve_node does.
    memories = await search_memories(last_msg, db_connection=None)
    return {"retrieved_context": memories}
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_node(state: State):
    """Compress older messages into the rolling summary once history is long.

    Nothing happens while the history holds ``THRESHOLD`` messages or fewer.
    Above that, everything except the last ``KEEP_LAST`` messages is merged
    into the summary and scheduled for removal.

    Returns:
        ``{}`` when no summarization is needed, otherwise a state update with
        the new ``summary`` and the ``RemoveMessage`` actions.
    """
    THRESHOLD = 10  # summarize only when there are more messages than this
    KEEP_LAST = 6  # most recent messages that are left untouched

    messages = state["messages"]
    if len(messages) <= THRESHOLD:
        return {}

    to_summarize = messages[:-KEEP_LAST]
    # Summaries are produced with the module-level `llm`.
    new_summary = await get_rolling_summary(llm, state.get("summary", ""), to_summarize)

    # Messages with a falsy id get no RemoveMessage and therefore stay in state.
    delete_actions = [RemoveMessage(id=m.id) for m in to_summarize if m.id]

    return {
        "summary": new_summary,
        "messages": delete_actions,
    }
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Gemini chat model used by call_model_node. No key is passed explicitly:
# ChatGoogleGenerativeAI reads it from the GOOGLE_API_KEY environment
# variable, so make sure that variable is set before importing this module.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", temperature=0.7)
|
|
||||||
|
|
||||||
async def call_model_node(state: State):
    """Final generation node: reply using summary, retrieved memory and messages.

    The system prompt starts from a fixed base and gains one extra line for
    the stored summary and one for the retrieved context, each only when the
    corresponding state value is non-empty.

    Returns:
        A state update holding the model reply under ``messages``.
    """
    sections = ["你是一个贴心的 AI 助手。"]

    # Long-term summary, if any.
    summary = state.get("summary")
    if summary:
        sections.append(f"\n这是之前的对话简要背景:{summary}")

    # Retrieved memory, if any.
    recalled = state.get("retrieved_context")
    if recalled:
        sections.append(f"\n这是你记住的关于用户的重要事实:{recalled}")

    prompt = [SystemMessage(content="".join(sections)), *state["messages"]]
    answer = await llm.ainvoke(prompt)
    return {"messages": [answer]}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
# services/memory_service.py
|
|
||||||
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
|
||||||
|
|
||||||
# 假设你使用的是 pgvector
|
|
||||||
async def search_memories(query: str, db_connection):
    """Return stored memories relevant to *query* (placeholder implementation).

    Intended flow: embed the query, run a vector-similarity search in the
    database, and return the top-K most relevant memories. At present the
    query is embedded but the result is a fixed sample string, and
    ``db_connection`` is not used.
    """
    embedder = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
    # The vector is not consumed yet; the embedding call itself is kept.
    _query_vector = await embedder.aembed_query(query)

    # TODO: run the real search here, e.g.
    #   SELECT content FROM memories ORDER BY embedding <=> query_vector LIMIT 3
    return "用户此前提到过他在做 Gemini 相关的 Hackathon,倾向于使用 Python。"
|
|
||||||
|
|
||||||
async def save_to_memory(content: str, db_connection):
    """Store *content* as a long-term memory (not implemented yet).

    Meant to be triggered by the "save" button. Planned steps:
      1. Optionally extract the key information from ``content`` (e.g. with an LLM).
      2. Generate an embedding and write it to the database.
    """
    return None
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
# services/summary_service.py
|
|
||||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
async def get_rolling_summary(model: ChatGoogleGenerativeAI, existing_summary: str, messages: list):
    """Merge the existing summary with new conversation turns into a new summary.

    Args:
        model: Chat model whose ``ainvoke`` produces the summary.
        existing_summary: Summary so far; may be empty.
        messages: New messages to fold in; each must expose ``type`` and ``content``.

    Returns:
        ``existing_summary`` unchanged when ``messages`` is empty, otherwise
        the ``content`` of the model's reply.
    """
    if not messages:
        return existing_summary

    msg_content = "\n".join(f"{m.type}: {m.content}" for m in messages)
    prior = existing_summary or "暂无"

    prompt = f"""
    你是一个记忆专家。请根据提供的“现有总结”和“新增对话”,生成一个更全面、精炼的新总结。
    请保留关键事实(如技术偏好、重要决定、用户背景),删除无意义的寒暄。

    [现有总结]: {prior}
    [新增对话]: {msg_content}

    请直接输出新的总结文本,保持中文书写。
    """

    reply = await model.ainvoke([HumanMessage(content=prompt)])
    return reply.content
|
|
||||||
|
|
||||||
# services/memory_service.py
|
|
||||||
|
|
||||||
async def extract_and_save_fact(thread_id: str, messages: list, db_connection, model=None):
    """Extract long-lived user facts from the current conversation.

    Triggered by a front-end button. Only the last 10 messages are used as
    extraction material. Persisting the facts to the vector store is not
    implemented yet, so ``thread_id`` and ``db_connection`` are unused.

    Args:
        thread_id: Conversation identifier (reserved for the future insert).
        messages: Conversation messages; each must expose ``type`` and ``content``.
        db_connection: Database handle (reserved for the future insert).
        model: Chat model with an async ``ainvoke``. When omitted, a
            module-level ``model_flash`` is used if one exists.

    Returns:
        A user-facing status string: either the "nothing to record" notice or
        the list of extracted facts.

    Raises:
        RuntimeError: If no model is passed and no ``model_flash`` is defined.
    """
    if model is None:
        model = globals().get("model_flash")
        if model is None:
            raise RuntimeError(
                "extract_and_save_fact needs a chat model: pass `model=` or "
                "define a module-level `model_flash`."
            )

    # 1. Use only the most recent messages as extraction material.
    context_text = "\n".join(f"{m.type}: {m.content}" for m in messages[-10:])

    # 2. Ask the model for atomic, self-contained facts.
    extraction_prompt = f"""
    从以下对话中提取用户提到的、具有长期保存价值的“个人事实”或“技术偏好”。
    要求:
    - 每一条事实必须是独立的、完整的句子。
    - 不要包含寒暄或临时性的讨论。
    - 如果没有值得记录的事实,请返回 "NONE"。

    对话内容:
    {context_text}

    输出格式示例:
    - 用户正在使用 Python 3.12 进行开发。
    - 用户计划参加 2026 年的 Gemini Hackathon。
    """

    response = await model.ainvoke(extraction_prompt)
    facts_text = response.content.strip()

    # Accept the sentinel even when the model wraps it in quotes, changes its
    # case or appends punctuation; an empty reply also means "nothing found".
    if not facts_text or facts_text.strip("\"'“”。. ").upper() == "NONE":
        return "没有发现值得记录的新事实。"

    # 3. TODO: embed each fact and store it in pgvector, e.g.
    #   for fact in facts_text.split("\n"):
    #       embedding = await get_embedding(fact)
    #       await db_connection.execute("INSERT INTO memories ...", embedding, fact, thread_id)

    return f"已成功记录以下记忆:\n{facts_text}"
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
from typing import Annotated, TypedDict
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
|
|
||||||
class State(TypedDict):
    """Graph state shared by the retrieve, summarize and chatbot nodes."""

    # Chat history; add_messages appends new messages instead of overwriting.
    messages: Annotated[list, add_messages]
    # Current rolling summary, so large histories need not be reloaded.
    summary: str
    # Facts retrieved from long-term memory.
    retrieved_context: str
|
|
||||||
Reference in New Issue
Block a user