Spaces:
Sleeping
Sleeping
| """ | |
| ConversationSummaryBufferMemory — Qdrant-backed. | |
| Giới hạn buffer tính theo số thành viên room (n): | |
| MAX_BUFFER = 10n | |
| SUMMARIZE_COUNT = 6n (số tin cũ nhất được tóm tắt khi vượt ngưỡng) | |
| KEEP_RECENT = 4n (số tin giữ lại trong buffer sau khi tóm tắt) | |
| Fallback n = 20 nếu không kết nối được Supabase hoặc không phải room. | |
| """ | |
| import logging | |
| import uuid | |
| from datetime import datetime, timezone, timedelta | |
| from typing import Optional | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, PointStruct, VectorParams | |
| from src.config import QDRANT_API_KEY, QDRANT_URL, SUPABASE_SERVICE_ROLE_KEY, SUPABASE_URL | |
| from src.llm import llm | |
| from src.tools.utils import format_created_at | |
| logger = logging.getLogger(__name__) | |
| _DEFAULT_N = 20 | |
| _COLLECTION = "conversation_memory" | |
| _DUMMY_VECTOR = [0.0] | |
| _qdrant_client: Optional[QdrantClient] = None | |
| _sb_client = None | |
| # ── Qdrant client ───────────────────────────────────────────────────────────── | |
| def _get_qdrant() -> QdrantClient: | |
| global _qdrant_client | |
| if _qdrant_client is None: | |
| _qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY) | |
| _ensure_collection(_qdrant_client) | |
| return _qdrant_client | |
| def _ensure_collection(client: QdrantClient) -> None: | |
| existing = {c.name for c in client.get_collections().collections} | |
| if _COLLECTION not in existing: | |
| client.create_collection( | |
| collection_name=_COLLECTION, | |
| vectors_config=VectorParams(size=1, distance=Distance.DOT), | |
| ) | |
| logger.info("Qdrant: collection '%s' created.", _COLLECTION) | |
| def _point_id(conversation_id: str) -> str: | |
| return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"conv:{conversation_id}")) | |
| # ── Supabase client ─────────────────────────────────────────────────────────── | |
| def _get_sb(): | |
| global _sb_client | |
| if _sb_client is None and SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY: | |
| try: | |
| from supabase import create_client | |
| _sb_client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY) | |
| except Exception: | |
| logger.exception("[Memory] Không khởi tạo được Supabase client.") | |
| return _sb_client | |
| # ── Dynamic limits ──────────────────────────────────────────────────────────── | |
| def _get_member_count(conversation_id: str) -> int: | |
| """Lấy số thành viên trong room từ Supabase. Fallback về _DEFAULT_N.""" | |
| if not conversation_id.startswith("room-"): | |
| return _DEFAULT_N | |
| sb = _get_sb() | |
| if sb is None: | |
| return _DEFAULT_N | |
| room_id = conversation_id.removeprefix("room-") | |
| try: | |
| res = ( | |
| sb.table("room_members") | |
| .select("user_id", count="exact") | |
| .eq("room_id", room_id) | |
| .execute() | |
| ) | |
| n = res.count or 0 | |
| return n if n > 0 else _DEFAULT_N | |
| except Exception: | |
| logger.exception("[Memory] Lỗi lấy số thành viên room '%s'", room_id) | |
| return _DEFAULT_N | |
| def _get_limits(conversation_id: str) -> tuple[int, int, int]: | |
| """Trả về (max_buffer, summarize_count, keep_recent) theo số thành viên n.""" | |
| n = _get_member_count(conversation_id) | |
| return 10 * n, 6 * n, 4 * n | |
| # ── Load / Save ─────────────────────────────────────────────────────────────── | |
| def load(conversation_id: str) -> tuple[str, list[dict]]: | |
| """Trả về (summary, buffer). Trả về ('', []) nếu chưa có.""" | |
| if not QDRANT_URL: | |
| return "", [] | |
| try: | |
| results = _get_qdrant().retrieve( | |
| collection_name=_COLLECTION, | |
| ids=[_point_id(conversation_id)], | |
| with_payload=True, | |
| ) | |
| if results: | |
| payload = results[0].payload or {} | |
| summary = payload.get("summary", "") or "" | |
| buffer = payload.get("buffer", []) or [] | |
| return summary, buffer | |
| return "", [] | |
| except Exception: | |
| logger.exception("[Memory] Lỗi load conversation_id='%s'", conversation_id) | |
| return "", [] | |
| def save(conversation_id: str, summary: str, buffer: list[dict]) -> None: | |
| if not QDRANT_URL: | |
| return | |
| try: | |
| _get_qdrant().upsert( | |
| collection_name=_COLLECTION, | |
| points=[PointStruct( | |
| id=_point_id(conversation_id), | |
| vector=_DUMMY_VECTOR, | |
| payload={ | |
| "conversation_id": conversation_id, | |
| "summary": summary, | |
| "buffer": buffer, | |
| }, | |
| )], | |
| ) | |
| except Exception: | |
| logger.exception("[Memory] Lỗi save conversation_id='%s'", conversation_id) | |
| # ── Summarization ───────────────────────────────────────────────────────────── | |
| def _summarize(existing_summary: str, messages: list[dict]) -> str: | |
| """Tóm tắt danh sách tin nhắn, kết hợp với summary hiện có.""" | |
| history_text = "\n".join( | |
| f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}" | |
| for m in messages | |
| ) | |
| prior_block = ( | |
| f"Tóm tắt trước đó:\n{existing_summary}\n\n" | |
| if existing_summary | |
| else "" | |
| ) | |
| system = SystemMessage(content=( | |
| "Bạn là trợ lý tóm tắt hội thoại. " | |
| "Tóm tắt ngắn gọn, giữ nguyên các thông tin quan trọng, " | |
| "sự kiện, tên gọi, và quyết định đã được đề cập. " | |
| "Chỉ trả về đoạn tóm tắt, không giải thích thêm." | |
| )) | |
| human = HumanMessage(content=( | |
| f"{prior_block}" | |
| f"Hội thoại cần tóm tắt:\n{history_text}\n\n" | |
| "Viết tóm tắt:" | |
| )) | |
| try: | |
| return llm.invoke([system, human]).content.strip() | |
| except Exception: | |
| logger.exception("[Memory] Lỗi khi summarize") | |
| return existing_summary | |
| # ── Public API ──────────────────────────────────────────────────────────────── | |
| def _now_vn() -> str: | |
| _VN_TZ = timezone(timedelta(hours=7)) | |
| return datetime.now(timezone.utc).astimezone(_VN_TZ).strftime("%d/%m/%Y %H:%M:%S") | |
| def add_turn(conversation_id: str, sender_id: str, user_msg: str, ai_msg: str) -> None: | |
| """Thêm 1 lượt user+assistant vào buffer, trigger summarize nếu cần.""" | |
| summary, buffer = load(conversation_id) | |
| max_buffer, summarize_count, keep_recent = _get_limits(conversation_id) | |
| ts = _now_vn() | |
| buffer.append({"role": "user", "content": f"[{ts}] {sender_id}: {user_msg}"}) | |
| buffer.append({"role": "assistant", "content": f"[{ts}] Assistant: {ai_msg}"}) | |
| if len(buffer) > max_buffer: | |
| to_summarize = buffer[:summarize_count] | |
| buffer = buffer[summarize_count:] | |
| if len(buffer) > keep_recent: | |
| buffer = buffer[-keep_recent:] | |
| logger.info( | |
| "[Memory] Buffer vượt %d, tóm tắt %d tin → giữ %d tin.", | |
| max_buffer, len(to_summarize), len(buffer), | |
| ) | |
| summary = _summarize(summary, to_summarize) | |
| save(conversation_id, summary, buffer) | |
| def seed_room(conversation_id: str, messages: list[dict]) -> None: | |
| """ | |
| Seed Qdrant buffer từ danh sách tin nhắn Redis thô. | |
| Mỗi message được chuyển thành role='user', content='[ts UTC+7] name: content'. | |
| Nếu vượt max_buffer thì tự động summarize trước khi lưu. | |
| """ | |
| _NAME_FIELDS = ["sender_username", "username", "u_username", "name", "u_name", | |
| "senderName", "displayName", "display_name", "fullName", "sender_id"] | |
| def _get_field(m: dict, fields: list[str]) -> str: | |
| for f in fields: | |
| v = m.get(f) | |
| if v and str(v).strip(): | |
| return str(v).strip() | |
| return "" | |
| buffer: list[dict] = [] | |
| for msg in messages: | |
| name = _get_field(msg, _NAME_FIELDS) or "unknown" | |
| content = msg.get("content") or msg.get("text") or msg.get("msg", "") | |
| ts = format_created_at(msg.get("created_at") or msg.get("timestamp", "")) | |
| buffer.append({"role": "user", "content": f"[{ts}] {name}: {content}"}) | |
| max_buffer, summarize_count, keep_recent = _get_limits(conversation_id) | |
| summary = "" | |
| while len(buffer) > max_buffer: | |
| to_summarize = buffer[:summarize_count] | |
| buffer = buffer[summarize_count:] | |
| if len(buffer) > keep_recent: | |
| buffer = buffer[-keep_recent:] | |
| logger.info( | |
| "[Memory] seed_room: tóm tắt %d tin → giữ %d tin còn lại", | |
| len(to_summarize), len(buffer), | |
| ) | |
| summary = _summarize(summary, to_summarize) | |
| save(conversation_id, summary, buffer) | |
| logger.info( | |
| "[Memory] seed_room '%s': lưu buffer=%d tin, summary=%s", | |
| conversation_id, len(buffer), "có" if summary else "không", | |
| ) | |
| def get_context(conversation_id: str) -> str: | |
| """Trả về chuỗi context (summary + buffer) để đưa vào prompt.""" | |
| summary, buffer = load(conversation_id) | |
| parts: list[str] = [] | |
| if summary: | |
| parts.append(f"[Tóm tắt lịch sử trước đó]\n{summary}") | |
| if buffer: | |
| recent_lines = "\n".join( | |
| f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}" | |
| for m in buffer | |
| ) | |
| parts.append(f"[Tin nhắn gần đây]\n{recent_lines}") | |
| return "\n\n".join(parts) if parts else "(Chưa có lịch sử trò chuyện)" | |