33 lines
859 B
Python
33 lines
859 B
Python
from langgraph.graph import StateGraph, END
|
|
from nodes import *
|
|
from state import State
|
|
|
|
workflow = StateGraph(State)

# Register the graph's nodes as (node name, node callable) pairs, in order.
# The callables are not defined in this file; presumably they come from the
# ``from nodes import *`` star import (confirm against nodes.py).
for _node_name, _node_fn in (
    ("retrieve", retrieve_node),
    ("summarize", summarize_node),
    ("chatbot", call_model_node),
):
    workflow.add_node(_node_name, _node_fn)

# Every run of the graph begins at the "retrieve" node.
workflow.set_entry_point("retrieve")

# Routing logic: check how many messages the conversation holds.
def should_summarize(state: "State", *, max_messages: int = 10) -> str:
    """Choose the node that runs after ``retrieve`` from the history length.

    Args:
        state: Graph state mapping; only its ``"messages"`` entry is read.
        max_messages: Route to summarization once the history holds strictly
            more than this many messages. Defaults to 10, the previously
            hard-coded limit. Keyword-only so the graph can keep calling the
            router with the state alone.

    Returns:
        ``"summarize"`` when the history is longer than ``max_messages``,
        otherwise ``"chatbot"``.
    """
    # Treat a missing or None "messages" entry as an empty history rather
    # than failing with KeyError/TypeError.
    messages = state.get("messages") or []
    if len(messages) > max_messages:
        return "summarize"
    return "chatbot"

# After retrieval, should_summarize decides the next hop. Its return value is
# looked up in this path map, where each route name targets the node of the
# same name ("summarize" or "chatbot").
workflow.add_conditional_edges(
    "retrieve",
    should_summarize,
    {route: route for route in ("summarize", "chatbot")},
)

# Summarization always hands off to the chatbot, and the chatbot finishes the run.
workflow.add_edge("summarize", "chatbot")
workflow.add_edge("chatbot", END)

# Compile the graph into the runnable app. No checkpointer is passed here, so
# graph state is not persisted between invocations. To persist State (e.g. via
# your Postgres implementation), compile with
# ``workflow.compile(checkpointer=postgres_checkpointer)`` instead.
app = workflow.compile()