092_agent_api / conversation_memory.py
anhkhoiphan's picture
Format lại tin nhắn gửi lên buffer Qdrant
f462894
"""
ConversationSummaryBufferMemory — Qdrant-backed.
Giới hạn buffer tính theo số thành viên room (n):
MAX_BUFFER = 10n
SUMMARIZE_COUNT = 6n (số tin cũ nhất được tóm tắt khi vượt ngưỡng)
KEEP_RECENT = 4n (số tin giữ lại trong buffer sau khi tóm tắt)
Fallback n = 20 nếu không kết nối được Supabase hoặc không phải room.
"""
import logging
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
from langchain_core.messages import HumanMessage, SystemMessage
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from src.config import QDRANT_API_KEY, QDRANT_URL, SUPABASE_SERVICE_ROLE_KEY, SUPABASE_URL
from src.llm import llm
from src.tools.utils import format_created_at
logger = logging.getLogger(__name__)
_DEFAULT_N = 20
_COLLECTION = "conversation_memory"
_DUMMY_VECTOR = [0.0]
_qdrant_client: Optional[QdrantClient] = None
_sb_client = None
# ── Qdrant client ─────────────────────────────────────────────────────────────
def _get_qdrant() -> QdrantClient:
global _qdrant_client
if _qdrant_client is None:
_qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
_ensure_collection(_qdrant_client)
return _qdrant_client
def _ensure_collection(client: QdrantClient) -> None:
existing = {c.name for c in client.get_collections().collections}
if _COLLECTION not in existing:
client.create_collection(
collection_name=_COLLECTION,
vectors_config=VectorParams(size=1, distance=Distance.DOT),
)
logger.info("Qdrant: collection '%s' created.", _COLLECTION)
def _point_id(conversation_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_DNS, f"conv:{conversation_id}"))
# ── Supabase client ───────────────────────────────────────────────────────────
def _get_sb():
global _sb_client
if _sb_client is None and SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY:
try:
from supabase import create_client
_sb_client = create_client(SUPABASE_URL, SUPABASE_SERVICE_ROLE_KEY)
except Exception:
logger.exception("[Memory] Không khởi tạo được Supabase client.")
return _sb_client
# ── Dynamic limits ────────────────────────────────────────────────────────────
def _get_member_count(conversation_id: str) -> int:
"""Lấy số thành viên trong room từ Supabase. Fallback về _DEFAULT_N."""
if not conversation_id.startswith("room-"):
return _DEFAULT_N
sb = _get_sb()
if sb is None:
return _DEFAULT_N
room_id = conversation_id.removeprefix("room-")
try:
res = (
sb.table("room_members")
.select("user_id", count="exact")
.eq("room_id", room_id)
.execute()
)
n = res.count or 0
return n if n > 0 else _DEFAULT_N
except Exception:
logger.exception("[Memory] Lỗi lấy số thành viên room '%s'", room_id)
return _DEFAULT_N
def _get_limits(conversation_id: str) -> tuple[int, int, int]:
"""Trả về (max_buffer, summarize_count, keep_recent) theo số thành viên n."""
n = _get_member_count(conversation_id)
return 10 * n, 6 * n, 4 * n
# ── Load / Save ───────────────────────────────────────────────────────────────
def load(conversation_id: str) -> tuple[str, list[dict]]:
"""Trả về (summary, buffer). Trả về ('', []) nếu chưa có."""
if not QDRANT_URL:
return "", []
try:
results = _get_qdrant().retrieve(
collection_name=_COLLECTION,
ids=[_point_id(conversation_id)],
with_payload=True,
)
if results:
payload = results[0].payload or {}
summary = payload.get("summary", "") or ""
buffer = payload.get("buffer", []) or []
return summary, buffer
return "", []
except Exception:
logger.exception("[Memory] Lỗi load conversation_id='%s'", conversation_id)
return "", []
def save(conversation_id: str, summary: str, buffer: list[dict]) -> None:
if not QDRANT_URL:
return
try:
_get_qdrant().upsert(
collection_name=_COLLECTION,
points=[PointStruct(
id=_point_id(conversation_id),
vector=_DUMMY_VECTOR,
payload={
"conversation_id": conversation_id,
"summary": summary,
"buffer": buffer,
},
)],
)
except Exception:
logger.exception("[Memory] Lỗi save conversation_id='%s'", conversation_id)
# ── Summarization ─────────────────────────────────────────────────────────────
def _summarize(existing_summary: str, messages: list[dict]) -> str:
"""Tóm tắt danh sách tin nhắn, kết hợp với summary hiện có."""
history_text = "\n".join(
f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}"
for m in messages
)
prior_block = (
f"Tóm tắt trước đó:\n{existing_summary}\n\n"
if existing_summary
else ""
)
system = SystemMessage(content=(
"Bạn là trợ lý tóm tắt hội thoại. "
"Tóm tắt ngắn gọn, giữ nguyên các thông tin quan trọng, "
"sự kiện, tên gọi, và quyết định đã được đề cập. "
"Chỉ trả về đoạn tóm tắt, không giải thích thêm."
))
human = HumanMessage(content=(
f"{prior_block}"
f"Hội thoại cần tóm tắt:\n{history_text}\n\n"
"Viết tóm tắt:"
))
try:
return llm.invoke([system, human]).content.strip()
except Exception:
logger.exception("[Memory] Lỗi khi summarize")
return existing_summary
# ── Public API ────────────────────────────────────────────────────────────────
def _now_vn() -> str:
_VN_TZ = timezone(timedelta(hours=7))
return datetime.now(timezone.utc).astimezone(_VN_TZ).strftime("%d/%m/%Y %H:%M:%S")
def add_turn(conversation_id: str, sender_id: str, user_msg: str, ai_msg: str) -> None:
"""Thêm 1 lượt user+assistant vào buffer, trigger summarize nếu cần."""
summary, buffer = load(conversation_id)
max_buffer, summarize_count, keep_recent = _get_limits(conversation_id)
ts = _now_vn()
buffer.append({"role": "user", "content": f"[{ts}] {sender_id}: {user_msg}"})
buffer.append({"role": "assistant", "content": f"[{ts}] Assistant: {ai_msg}"})
if len(buffer) > max_buffer:
to_summarize = buffer[:summarize_count]
buffer = buffer[summarize_count:]
if len(buffer) > keep_recent:
buffer = buffer[-keep_recent:]
logger.info(
"[Memory] Buffer vượt %d, tóm tắt %d tin → giữ %d tin.",
max_buffer, len(to_summarize), len(buffer),
)
summary = _summarize(summary, to_summarize)
save(conversation_id, summary, buffer)
def seed_room(conversation_id: str, messages: list[dict]) -> None:
"""
Seed Qdrant buffer từ danh sách tin nhắn Redis thô.
Mỗi message được chuyển thành role='user', content='[ts UTC+7] name: content'.
Nếu vượt max_buffer thì tự động summarize trước khi lưu.
"""
_NAME_FIELDS = ["sender_username", "username", "u_username", "name", "u_name",
"senderName", "displayName", "display_name", "fullName", "sender_id"]
def _get_field(m: dict, fields: list[str]) -> str:
for f in fields:
v = m.get(f)
if v and str(v).strip():
return str(v).strip()
return ""
buffer: list[dict] = []
for msg in messages:
name = _get_field(msg, _NAME_FIELDS) or "unknown"
content = msg.get("content") or msg.get("text") or msg.get("msg", "")
ts = format_created_at(msg.get("created_at") or msg.get("timestamp", ""))
buffer.append({"role": "user", "content": f"[{ts}] {name}: {content}"})
max_buffer, summarize_count, keep_recent = _get_limits(conversation_id)
summary = ""
while len(buffer) > max_buffer:
to_summarize = buffer[:summarize_count]
buffer = buffer[summarize_count:]
if len(buffer) > keep_recent:
buffer = buffer[-keep_recent:]
logger.info(
"[Memory] seed_room: tóm tắt %d tin → giữ %d tin còn lại",
len(to_summarize), len(buffer),
)
summary = _summarize(summary, to_summarize)
save(conversation_id, summary, buffer)
logger.info(
"[Memory] seed_room '%s': lưu buffer=%d tin, summary=%s",
conversation_id, len(buffer), "có" if summary else "không",
)
def get_context(conversation_id: str) -> str:
"""Trả về chuỗi context (summary + buffer) để đưa vào prompt."""
summary, buffer = load(conversation_id)
parts: list[str] = []
if summary:
parts.append(f"[Tóm tắt lịch sử trước đó]\n{summary}")
if buffer:
recent_lines = "\n".join(
f"{'User' if m['role'] == 'user' else 'Assistant'}: {m['content']}"
for m in buffer
)
parts.append(f"[Tin nhắn gần đây]\n{recent_lines}")
return "\n\n".join(parts) if parts else "(Chưa có lịch sử trò chuyện)"