完成ai
This commit is contained in:
132
app/ai/ai.py
132
app/ai/ai.py
@@ -1,5 +1,13 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""gemini-hackathon ai.ipynb
|
||||||
|
|
||||||
# pip install -q langgraph-checkpoint-sqlite
|
Automatically generated by Colab.
|
||||||
|
|
||||||
|
Original file is located at
|
||||||
|
https://colab.research.google.com/drive/1FyV9Lq9Sxh_dFiaNIqeu1brOl8DAoKUO
|
||||||
|
"""
|
||||||
|
|
||||||
|
!pip install -q langgraph-checkpoint-sqlite langchain_google_genai
|
||||||
|
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from google.colab import userdata
|
from google.colab import userdata
|
||||||
@@ -17,7 +25,7 @@ class State(TypedDict):
|
|||||||
summary: str # 永久存储在数据库中的摘要内容
|
summary: str # 永久存储在数据库中的摘要内容
|
||||||
|
|
||||||
# --- 2. 核心逻辑 ---
|
# --- 2. 核心逻辑 ---
|
||||||
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature = 0.7, google_api_key='AIzaSyCfkiIgq4FmH5siBp3Iw6MRCml5zeSURnY')
|
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature = 0.7, google_api_key='') #在这里倒入API key
|
||||||
|
|
||||||
def call_model(state: State, config: RunnableConfig):
|
def call_model(state: State, config: RunnableConfig):
|
||||||
"""对话节点:融合了动态 Prompt 和 长期摘要"""
|
"""对话节点:融合了动态 Prompt 和 长期摘要"""
|
||||||
@@ -27,11 +35,11 @@ def call_model(state: State, config: RunnableConfig):
|
|||||||
system_base_prompt = configurable.get("system_prompt", "你是一个通用的 AI 助手。")
|
system_base_prompt = configurable.get("system_prompt", "你是一个通用的 AI 助手。")
|
||||||
|
|
||||||
# 构造当前上下文
|
# 构造当前上下文
|
||||||
prompt = f"{system_base_prompt}"
|
summary = state.get("summary", "")
|
||||||
if state.get("summary"):
|
if summary:
|
||||||
prompt += f"\n\n[之前的对话核心摘要]: {state['summary']}"
|
system_base_prompt += f"\n\n<context_summary>\n{summary}\n</context_summary>"
|
||||||
|
|
||||||
messages = [SystemMessage(content=prompt)] + state["messages"]
|
messages = [SystemMessage(content=system_base_prompt)] + state["messages"]
|
||||||
response = llm.invoke(messages)
|
response = llm.invoke(messages)
|
||||||
return {"messages": [response]}
|
return {"messages": [response]}
|
||||||
|
|
||||||
@@ -39,24 +47,37 @@ def summarize_conversation(state: State):
|
|||||||
"""总结节点:负责更新摘要并清理过期消息"""
|
"""总结节点:负责更新摘要并清理过期消息"""
|
||||||
summary = state.get("summary", "")
|
summary = state.get("summary", "")
|
||||||
|
|
||||||
if summary:
|
messages_to_summarize = state["messages"][:-1]
|
||||||
summary_prompt = f"当前的摘要是: {summary}\n\n请结合最近的新消息,生成一份更新后的完整摘要,包含所有关键信息:"
|
|
||||||
else:
|
|
||||||
summary_prompt = "请总结目前的对话重点:"
|
|
||||||
|
|
||||||
# 获取除了最后两轮之外的所有消息进行总结
|
# If there's nothing to summarize yet, just END
|
||||||
messages_to_summarize = state["messages"][:-2]
|
if not messages_to_summarize:
|
||||||
content = [SystemMessage(content=summary_prompt)] + messages_to_summarize
|
return {"summary": summary}
|
||||||
response = llm.invoke(content)
|
|
||||||
|
|
||||||
# 生成删除指令,清除已总结过的消息 ID
|
system_prompt = (
|
||||||
delete_messages = [RemoveMessage(id=m.id) for m in messages_to_summarize]
|
"你是一个记忆管理专家。请更新摘要,合并新旧信息。"
|
||||||
|
"1. 保持简练,仅保留事实(姓名、偏好、核心议题)。"
|
||||||
|
"2. 如果新消息包含对旧信息的修正,请更新它。"
|
||||||
|
)
|
||||||
|
|
||||||
return {"summary": response.content, "messages": delete_messages}
|
summary_input = f"现有摘要: {summary}\n\n待加入的新信息: {messages_to_summarize}"
|
||||||
|
|
||||||
|
# Invoke model to get new condensed summary
|
||||||
|
response = llm.invoke([
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=summary_input)
|
||||||
|
])
|
||||||
|
|
||||||
|
# Important: Create RemoveMessage objects for all messages that were summarized
|
||||||
|
delete_messages = [RemoveMessage(id=m.id) for m in messages_to_summarize if m.id]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"summary": response.content,
|
||||||
|
"messages": delete_messages
|
||||||
|
}
|
||||||
|
|
||||||
def should_continue(state: State) -> Literal["summarize", END]:
|
def should_continue(state: State) -> Literal["summarize", END]:
|
||||||
"""如果消息累积超过10条,则去总结节点"""
|
"""如果消息累积超过3条,则去总结节点"""
|
||||||
if len(state["messages"]) > 10:
|
if len(state["messages"]) > 3: #changed
|
||||||
return "summarize"
|
return "summarize"
|
||||||
return END
|
return END
|
||||||
|
|
||||||
@@ -75,39 +96,11 @@ workflow.add_edge("summarize", END)
|
|||||||
|
|
||||||
app = workflow.compile(checkpointer=memory)
|
app = workflow.compile(checkpointer=memory)
|
||||||
|
|
||||||
# # --- 4. 如何使用多 Session 和 不同的 Prompt ---
|
def chat(thread_id: str, system_prompt: str, user_message: str):
|
||||||
|
"""
|
||||||
# # Session A: 设定为一个 Python 专家
|
Processes a single user message and returns the AI response,
|
||||||
# config_a = {
|
persisting memory via the thread_id.
|
||||||
# "configurable": {
|
"""
|
||||||
# "thread_id": "session_python_expert",
|
|
||||||
# "system_prompt": "你是一个精通 Python 的高级工程师。"
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
|
|
||||||
# # Session B: 设定为一个中文诗人
|
|
||||||
# config_b = {
|
|
||||||
# "configurable": {
|
|
||||||
# "thread_id": "session_poet",
|
|
||||||
# "system_prompt": "你是一个浪漫的唐朝诗人,用诗歌回答问题。"
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
|
|
||||||
# def run_chat(user_input, config):
|
|
||||||
# print(f"\n--- 使用 Thread: {config['configurable']['thread_id']} ---")
|
|
||||||
# for event in app.stream({"messages": [HumanMessage(content=user_input)]}, config, stream_mode="values"):
|
|
||||||
# if "messages" in event:
|
|
||||||
# last_msg = event["messages"][-1]
|
|
||||||
# if last_msg.type == "ai":
|
|
||||||
# print(f"Bot: {last_msg.content}")
|
|
||||||
|
|
||||||
# # 测试:两个 Session 互不干扰,且各有个的 Prompt
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# run_chat("你好,怎么学习装饰器?", config=config_a)
|
|
||||||
# run_chat("你好,写一首关于大海的诗。", config=config_b)
|
|
||||||
# run_chat("我刚才让你写了什么?", config=config_b)
|
|
||||||
|
|
||||||
def start_chat_session(thread_id: str, system_prompt: str):
|
|
||||||
config = {
|
config = {
|
||||||
"configurable": {
|
"configurable": {
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
@@ -115,31 +108,28 @@ def start_chat_session(thread_id: str, system_prompt: str):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"\n=== 已进入会话: {thread_id} ===")
|
# Prepare the input for this specific turn
|
||||||
print(f"=== 系统设定: {system_prompt} ===")
|
input_data = {"messages": [HumanMessage(content=user_message)]}
|
||||||
print("(输入 'exit' 或 'quit' 退出当前会话)\n")
|
|
||||||
|
|
||||||
while True:
|
ai_response = ""
|
||||||
user_input = input("User: ")
|
|
||||||
if user_input.lower() in ["exit", "quit"]:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 使用 stream 模式运行,可以实时看到 state 的更新
|
# Stream the values to get the final AI message
|
||||||
# 我们只打印 AI 的回复
|
|
||||||
input_data = {"messages": [HumanMessage(content=user_input)]}
|
|
||||||
for event in app.stream(input_data, config, stream_mode="values"):
|
for event in app.stream(input_data, config, stream_mode="values"):
|
||||||
if "messages" in event:
|
if "messages" in event:
|
||||||
last_msg = event["messages"][-1]
|
last_msg = event["messages"][-1]
|
||||||
if last_msg.type == "ai":
|
if last_msg.type == "ai":
|
||||||
print(f"Bot: {last_msg.content}")
|
ai_response = last_msg.content
|
||||||
|
|
||||||
|
return ai_response
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 模拟场景 1: Python 专家会话
|
tid = "py_expert_001"
|
||||||
# 即使你关掉程序再运行,只要 thread_id 还是 'py_expert_001',记忆就会从 sqlite 读取
|
sys_p = "你是个善解人意的机器人。"
|
||||||
start_chat_session(
|
|
||||||
thread_id="py_expert_001",
|
|
||||||
system_prompt="你是一个精通 Python 的架构师。"
|
|
||||||
)
|
|
||||||
|
|
||||||
# from IPython.display import Image, display
|
# Call 1: Establish context
|
||||||
# display(Image(app.get_graph().draw_mermaid_png()))
|
resp1 = chat(tid, sys_p, "你好,我叫小明。")
|
||||||
|
print(f"Bot: {resp1}")
|
||||||
|
|
||||||
|
# Call 2: Test memory (The model should remember the name '小明')
|
||||||
|
resp2 = chat(tid, sys_p, "我今天很开心")
|
||||||
|
print(f"Bot: {resp2}")
|
||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user