RE-TRAC / example.py
JialiangZhu's picture
init
3e26b02
import os
from pathlib import Path
from typing import Any, Dict, AsyncIterator, List, Tuple
import gradio as gr
import yaml
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
from retrac.graph import build_graph
def load_config(config_path: str) -> dict:
    """Load a YAML configuration file.

    Args:
        config_path: Path of the YAML file to read (decoded as UTF-8).

    Returns:
        The parsed top-level YAML value. An empty document is returned as an
        empty dict instead of ``None`` so the result matches the annotation.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # ``safe_load`` yields None for an empty document; normalize that case so
    # callers never receive None where a dict is promised.
    return {} if config is None else config
async def stream_graph_states(
    config_path: str,
    question: str,
) -> AsyncIterator[Dict[str, Any]]:
    """Run the graph for ``question`` and yield the merged state after each event.

    The graph is built from the YAML config at ``config_path``, compiled, and
    streamed with a recursion limit taken from the ``RECURSION_LIMIT``
    environment variable (default ``10000``).

    Args:
        config_path: Path of the YAML configuration passed to ``build_graph``.
        question: Value stored under the ``"question"`` key of the initial state.

    Yields:
        A shallow snapshot of the accumulated state after each streamed event.
        Each snapshot is an independent dict, so consumers may keep or mutate
        it without affecting later snapshots.

    Raises:
        ValueError: If ``RECURSION_LIMIT`` is set but is not an integer.
    """
    from collections.abc import Mapping

    cfg = load_config(config_path)
    graph = build_graph(cfg)
    compiled_graph = graph.compile()
    recursion_limit = int(os.getenv("RECURSION_LIMIT", "10000"))
    run_config = {"recursion_limit": recursion_limit}

    # Keep the accumulator separate from the dict handed to the graph so the
    # graph's input is never mutated while the stream is still running.
    state: Dict[str, Any] = {"question": question}
    async for event in compiled_graph.astream(dict(state), config=run_config):
        for node_output in event.values():
            # Only mapping payloads can be merged; anything else (None, or a
            # non-dict payload such as an interrupt tuple -- TODO confirm which
            # payload kinds this graph emits) would make dict.update raise.
            if isinstance(node_output, Mapping):
                state.update(node_output)
        # Yield a copy: yielding the live accumulator would make every
        # previously yielded state change retroactively.
        yield dict(state)
def _message_role(msg: BaseMessage) -> str:
    """Map a message object to a chat role name.

    Human, AI, system and tool messages map to ``"user"``, ``"assistant"``,
    ``"system"`` and ``"tool"`` respectively. Any other object falls back to
    its ``type`` attribute, or ``"assistant"`` when it has none.
    """
    # Checked in order; the first matching class wins.
    roles_by_class = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (ToolMessage, "tool"),
    )
    for message_cls, role_name in roles_by_class:
        if isinstance(msg, message_cls):
            return role_name
    return getattr(msg, "type", "assistant")
def _message_content(msg: BaseMessage) -> str:
    """Return the textual content of a message.

    Uses the ``content`` attribute when it is present and not ``None``
    (stringifying it if it is not already a string). Otherwise falls back to
    a string-valued ``text`` attribute, and finally to ``str(msg)``.
    """
    body = getattr(msg, "content", None)
    if body is None:
        # No usable content: try a plain-string ``text`` attribute instead.
        fallback = getattr(msg, "text", None)
        if isinstance(fallback, str):
            return fallback
        return str(msg)
    if isinstance(body, str):
        return body
    return str(body)
def _serialize_messages(messages: List[BaseMessage]) -> List[Dict[str, str]]:
    """Convert message objects into ``{"role": ..., "content": ...}`` dicts.

    Tool messages that carry a truthy ``name`` get their content prefixed
    with ``[tool:<name>] ``.
    """

    def _to_entry(message: BaseMessage) -> Dict[str, str]:
        # Build the role/content pair for a single message.
        role = _message_role(message)
        text = _message_content(message)
        if isinstance(message, ToolMessage):
            tool_name = getattr(message, "name", None)
            if tool_name:
                text = f"[tool:{tool_name}] {text}"
        return {"role": role, "content": text}

    return [_to_entry(message) for message in messages]
def _messages_to_rows(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
rows: List[Tuple[str, str]] = []
for msg in messages:
rows.append((msg.get("role", ""), msg.get("content", "")))
return rows
def _split_incremental(
prev: List[Dict[str, str]],
current: List[Dict[str, str]],
) -> Tuple[bool, List[Dict[str, str]]]:
if not prev:
return False, current
if len(current) >= len(prev) and current[: len(prev)] == prev:
return False, current[len(prev) :]
return True, current
async def chat_once(
    query_text: str,
    prev_serialized: List[Dict[str, str]],
    chat_history: List[Tuple[str, str]],
):
    """Stream one query through the graph and yield UI updates.

    Args:
        query_text: Raw text from the query box.
        prev_serialized: Serialized messages already reflected in the chat.
        chat_history: Rows currently shown in the chat component.

    Yields:
        ``(chat_history, prev_serialized, query_update, submit_update)``.
        For a blank query a single tuple with no-op updates is yielded.
        Otherwise one tuple is yielded per streamed state that contains
        ``"messages"``, with updates that make the query box and the submit
        button non-interactive.
    """
    cleaned_query = query_text.strip() if query_text else ""
    if not cleaned_query:
        # Nothing to run: echo the current values back without changes.
        yield chat_history or [], prev_serialized or [], gr.update(), gr.update()
        return

    seen = prev_serialized or []
    rows = chat_history or []
    locked_query = gr.update(value=cleaned_query, interactive=False)
    locked_submit = gr.update(interactive=False)

    # The graph configuration sits in the "retrac" folder beside this file.
    config_file = Path(__file__).resolve().parent / "retrac" / "30B.yaml"

    async for snapshot in stream_graph_states(str(config_file), cleaned_query):
        raw_messages = snapshot.get("messages")
        if raw_messages is None:
            continue

        latest = _serialize_messages(raw_messages)
        must_reset, fresh = _split_incremental(seen, latest)
        fresh_rows = _messages_to_rows(fresh)
        # On reset the chat is replaced; otherwise new rows are appended.
        rows = fresh_rows if must_reset else list(rows) + fresh_rows
        seen = latest
        yield rows, seen, locked_query, locked_submit
def build_demo() -> gr.Blocks:
    """Assemble the Gradio Blocks UI.

    The layout holds a heading, a chat display, a two-line query box, a
    submit button and a hidden state that stores the serialized messages.
    Clicking the button and submitting the query box both run ``chat_once``.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("RE-TRAC Gradio UI")
        chat = gr.Chatbot(label="Messages", height=520)
        # The placeholder tells the user that only a single query is allowed.
        query = gr.Textbox(label="Query", placeholder="只允许输入一次查询", lines=2)
        submit = gr.Button("开始")
        prev_serialized = gr.State([])

        def _wire(register_event) -> None:
            # Attach chat_once to one trigger with the shared inputs/outputs.
            register_event(
                chat_once,
                inputs=[query, prev_serialized, chat],
                outputs=[chat, prev_serialized, query, submit],
            )

        _wire(submit.click)
        _wire(query.submit)
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, enable queuing, then start the server.
    demo_app = build_demo()
    queued_app = demo_app.queue()
    queued_app.launch()