Spaces:
Sleeping
Sleeping
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, AsyncIterator, List, Tuple | |
| import gradio as gr | |
| import yaml | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage | |
| from retrac.graph import build_graph | |
def load_config(config_path: str) -> dict:
    """Load a YAML configuration file and return its top-level mapping.

    Args:
        config_path: Path to a UTF-8 encoded YAML file.

    Returns:
        The parsed top-level mapping. An empty (or comment-only) document
        yields ``{}`` rather than ``None``, so the result always matches the
        ``dict`` return annotation.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
        ValueError: If the document's top level is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python types; it never instantiates
        # arbitrary objects from YAML tags.
        config = yaml.safe_load(f)
    if config is None:
        # yaml.safe_load returns None for an empty document.
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"{config_path}: expected a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config
async def stream_graph_states(
    config_path: str,
    question: str,
) -> AsyncIterator[Dict[str, Any]]:
    """Run the graph on one question and yield the merged state after each event.

    The graph is built from the YAML config at ``config_path`` (via
    ``load_config`` and ``build_graph``) and compiled on every call. The run's
    recursion limit comes from the ``RECURSION_LIMIT`` environment variable
    (default ``10000``); a non-integer value raises ``ValueError``.

    Each event from ``compiled_graph.astream`` is treated as a mapping of
    ``{node_name: node_output}``. Every non-None node output is merged into a
    running dict with ``dict.update``, which is shallow and last-write-wins.
    A fresh copy of that dict is yielded once per event.

    Args:
        config_path: Path to the graph's YAML configuration.
        question: The user's question, passed to the graph as ``{"question": ...}``.

    Yields:
        A snapshot ``dict`` of the merged state. It is safe to keep; later
        updates do not modify it.

    NOTE(review): the merge replaces keys wholesale. If the graph's state schema
    uses reducers (e.g. an appending ``messages`` channel), node outputs
    presumably carry only the delta, and this running dict will diverge from
    the graph's real state. Confirm against ``retrac.graph``.
    """
    cfg = load_config(config_path)
    compiled_graph = build_graph(cfg).compile()
    recursion_limit = int(os.getenv("RECURSION_LIMIT", "10000"))
    run_config = {"recursion_limit": recursion_limit}

    inputs: Dict[str, Any] = {"question": question}
    # Running merge of node outputs, kept separate from the dict handed to the
    # graph so the graph's input is never mutated underneath it.
    state: Dict[str, Any] = dict(inputs)
    async for event in compiled_graph.astream(inputs, config=run_config):
        for node_output in event.values():
            if node_output is not None:
                state.update(node_output)
        # Yield a copy: the running dict keeps changing after this point, and
        # handing out the same object would make earlier yields silently
        # reflect later updates.
        yield dict(state)
def _message_role(msg: BaseMessage) -> str:
    """Map a LangChain message to the role label shown in the chat transcript.

    ``HumanMessage`` -> ``"user"``, ``AIMessage`` -> ``"assistant"``,
    ``SystemMessage`` -> ``"system"``, ``ToolMessage`` -> ``"tool"``.
    Any other message falls back to its ``type`` attribute, or to
    ``"assistant"`` when it has none.
    """
    # Matched in order with isinstance, so an instance of a subclass resolves
    # to the role of the first listed base class it derives from.
    role_table = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (ToolMessage, "tool"),
    )
    for message_cls, role_label in role_table:
        if isinstance(msg, message_cls):
            return role_label
    return getattr(msg, "type", "assistant")
def _message_content(msg: BaseMessage) -> str:
    """Return the text to display for a message.

    Resolution order:

    * ``str`` content is returned unchanged.
    * ``list`` content (a sequence of content blocks) is rendered block by
      block and joined with newlines. Plain strings and
      ``{"type": "text", "text": ...}`` blocks contribute their text. Any
      other block falls back to ``str(block)`` so it stays visible.
    * Any other non-None content is converted with ``str``.
    * With no content, a string ``text`` attribute is used, else ``str(msg)``.
    """
    content = getattr(msg, "content", None)
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif (
                isinstance(block, dict)
                and block.get("type") == "text"
                and isinstance(block.get("text"), str)
            ):
                parts.append(block["text"])
            else:
                # Non-text blocks (e.g. images, tool-use payloads) are kept as
                # their repr instead of being silently dropped.
                parts.append(str(block))
        return "\n".join(parts)
    if content is not None:
        return str(content)
    text = getattr(msg, "text", None)
    return text if isinstance(text, str) else str(msg)
def _serialize_messages(messages: List[BaseMessage]) -> List[Dict[str, str]]:
    """Convert LangChain messages into plain ``{"role": ..., "content": ...}`` dicts.

    The role comes from ``_message_role`` and the text from ``_message_content``.
    A ``ToolMessage`` with a truthy ``name`` has its content prefixed with
    ``"[tool:<name>] "`` so tool output is attributable in the transcript.
    """

    def to_entry(message):
        role_label = _message_role(message)
        text = _message_content(message)
        is_named_tool = isinstance(message, ToolMessage) and getattr(message, "name", None)
        if is_named_tool:
            text = f"[tool:{message.name}] {text}"
        return {"role": role_label, "content": text}

    return [to_entry(message) for message in messages]
def _messages_to_rows(messages: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Turn serialized messages into ``(role, content)`` pairs for ``gr.Chatbot``.

    A missing ``"role"`` or ``"content"`` key becomes an empty string.

    NOTE(review): in Gradio's tuple-style chatbot history each pair is rendered
    as two bubbles, so the role label occupies the first bubble. Confirm that
    this is the intended presentation.
    """
    return [(entry.get("role", ""), entry.get("content", "")) for entry in messages]
def _split_incremental(
    prev: List[Dict[str, str]],
    current: List[Dict[str, str]],
) -> Tuple[bool, List[Dict[str, str]]]:
    """Work out which messages are new since the previous snapshot.

    Returns ``(reset, messages)``:

    * ``(False, current)`` when there is no previous snapshot.
    * ``(False, <suffix of current after prev>)`` when ``current`` starts with
      exactly the messages in ``prev``, i.e. it only appended.
    * ``(True, current)`` otherwise: history was rewritten, so the caller
      should rebuild the display from scratch.
    """
    if not prev:
        return False, current
    cut = len(prev)
    # A current list shorter than prev also fails this comparison, because
    # its slice has a different length from prev.
    if current[:cut] != prev:
        return True, current
    return False, current[cut:]
async def chat_once(
    query_text: str,
    prev_serialized: List[Dict[str, str]],
    chat_history: List[Tuple[str, str]],
):
    """Run one query through the graph and stream the transcript into the UI.

    Async generator used as a Gradio event handler. Every yield is a 4-tuple in
    the order ``(chatbot rows, serialized messages, query update, button update)``:

    * chatbot rows: ``(role, content)`` pairs from ``_messages_to_rows``;
    * serialized messages: the latest state's messages, stored in a ``gr.State``
      so the next update can be diffed against them;
    * query / button updates: ``gr.update`` objects for the textbox and button.

    A blank query yields the inputs back unchanged with no-op updates. Otherwise
    the textbox (showing the stripped query) and the button are disabled before
    the run starts and stay disabled afterwards, so only one query is accepted.
    If the run raises, both are re-enabled and the exception is re-raised.
    """
    if not query_text or not query_text.strip():
        yield chat_history or [], prev_serialized or [], gr.update(), gr.update()
        return

    query_text = query_text.strip()
    prev_serialized = prev_serialized or []
    chat_history = chat_history or []
    disable_query = gr.update(value=query_text, interactive=False)
    disable_submit = gr.update(interactive=False)

    base_dir = Path(__file__).resolve().parent
    config_path = str(base_dir / "retrac" / "30B.yaml")

    # Lock the inputs before the run starts. The first graph state can take a
    # while to arrive, and a still-clickable button during that window would
    # let the user start duplicate runs.
    yield chat_history, prev_serialized, disable_query, disable_submit

    try:
        async for state in stream_graph_states(config_path, query_text):
            messages = state.get("messages")
            if messages is None:
                continue
            current_serialized = _serialize_messages(messages)
            reset, incremental = _split_incremental(prev_serialized, current_serialized)
            new_rows = _messages_to_rows(incremental)
            # On reset the history was rewritten, so rebuild the display.
            # Otherwise append only the new rows.
            chat_history = new_rows if reset else list(chat_history) + new_rows
            prev_serialized = current_serialized
            yield chat_history, prev_serialized, disable_query, disable_submit
    except Exception:
        # The inputs were locked up front. Unlock them so a failed run can be
        # retried, then let Gradio report the error.
        yield (
            chat_history,
            prev_serialized,
            gr.update(interactive=True),
            gr.update(interactive=True),
        )
        raise
def build_demo() -> gr.Blocks:
    """Assemble the RE-TRAC chat UI and return the (un-launched) Blocks app.

    Layout: a title, a chatbot transcript, a query textbox and a start button,
    plus a hidden ``gr.State`` holding the last serialized message list.
    Clicking the button and submitting the textbox both run ``chat_once``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("RE-TRAC Gradio UI")
        chat = gr.Chatbot(label="Messages", height=520)
        # Placeholder text (Chinese): "only one query may be entered".
        query = gr.Textbox(label="Query", placeholder="只允许输入一次查询", lines=2)
        # Button label (Chinese): "Start".
        submit = gr.Button("开始")
        prev_serialized = gr.State([])

        # Register the same handler on both triggers, in the same order as
        # two separate calls, with fresh input/output lists for each.
        for trigger in (submit.click, query.submit):
            trigger(
                chat_once,
                inputs=[query, prev_serialized, chat],
                outputs=[chat, prev_serialized, query, submit],
            )
    return demo
if __name__ == "__main__":
    # Build the UI, enable Gradio's request queue, then start the server.
    app = build_demo()
    app.queue()
    app.launch()