092_agent_api / qdrant_store.py
anhkhoiphan's picture
Thêm tính năng custom prompt
b784540
import logging
import uuid
from typing import Optional
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
from src.config import QDRANT_API_KEY, QDRANT_URL
# Name of the Qdrant collection that stores one point per user.
_COLLECTION = "user_prompts"

# The collection is used as a key-value store (points are fetched by id, never
# searched by similarity), so every point carries the same one-dimensional
# placeholder vector. Its length must match the vector size the collection is
# created with (size=1 in _ensure_collection).
_DUMMY_VECTOR = [0.0]

# Lazily created shared client; set on first use by _get_client().
_client: Optional[QdrantClient] = None

# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)
def _get_client() -> QdrantClient:
    """Return the module-wide Qdrant client, creating it on first use.

    On first call this connects using ``QDRANT_URL`` / ``QDRANT_API_KEY`` and
    makes sure the prompt collection exists. The client is cached only after
    that check succeeds. A failure during collection setup therefore propagates
    to the caller and is retried on the next call, rather than leaving a
    cached client whose collection may never have been created.

    Raises:
        Any exception raised while constructing the client or by
        ``_ensure_collection``; the public functions in this module catch
        and log these.
    """
    global _client
    # NOTE(review): no lock here; concurrent first calls may each build a
    # client. The last one assigned wins, which is harmless but wasteful.
    if _client is None:
        client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
        _ensure_collection(client)  # may raise; do not cache on failure
        _client = client
    return _client
def _ensure_collection(client: QdrantClient) -> None:
    """Create the prompt collection on *client* if it does not exist yet.

    The collection holds one-dimensional placeholder vectors (points are only
    looked up by id), so the vector size is 1. The DOT distance metric is
    irrelevant for that usage.
    """
    known_names = {col.name for col in client.get_collections().collections}
    if _COLLECTION in known_names:
        return

    client.create_collection(
        collection_name=_COLLECTION,
        vectors_config=VectorParams(size=1, distance=Distance.DOT),
    )
    logger.info("Qdrant: collection '%s' created.", _COLLECTION)
def _point_id(user_id: str) -> str:
    """Map *user_id* to a stable UUID string usable as a Qdrant point id.

    ``uuid5`` is name-based and deterministic, so a given user always maps to
    the same point id. Saving a prompt again therefore overwrites the
    previous one instead of adding a new point.
    """
    derived = uuid.uuid5(namespace=uuid.NAMESPACE_DNS, name=user_id)
    return str(derived)
def save_custom_prompt(user_id: str, prompt: str) -> bool:
    """Store *prompt* as the custom prompt for *user_id*, replacing any old one.

    Each user maps to a single point (see ``_point_id``), so the upsert
    overwrites a previously saved prompt. The payload keeps both the user id
    and the prompt text.

    This is best-effort: errors are logged with their traceback rather than
    raised.

    Returns:
        True if the prompt was written; False if ``QDRANT_URL`` is not
        configured or the Qdrant call failed.
    """
    if not QDRANT_URL:
        logger.warning("QDRANT_URL chưa được cấu hình.")
        return False
    try:
        _get_client().upsert(
            collection_name=_COLLECTION,
            points=[
                PointStruct(
                    id=_point_id(user_id),
                    vector=_DUMMY_VECTOR,
                    payload={"user_id": user_id, "prompt": prompt},
                )
            ],
        )
    except Exception as e:
        # Deliberate catch-all at the storage boundary: callers expect a bool.
        # logger.exception keeps the ERROR level but also records the stack.
        logger.exception("Lỗi lưu custom prompt cho '%s': %s", user_id, e)
        return False
    logger.info("Đã lưu custom prompt cho user '%s'.", user_id)
    return True
def get_custom_prompt(user_id: str) -> Optional[str]:
    """Return the custom prompt saved for *user_id*, or None if there is none.

    Looks up the user's single point by its deterministic id (see
    ``_point_id``) and reads the ``"prompt"`` payload field.

    This is best-effort: Qdrant errors are logged with their traceback and
    reported as None.

    Returns:
        The stored prompt. Returns None if ``QDRANT_URL`` is not configured,
        no point exists for the user, the point has no payload or no
        ``"prompt"`` field, or the Qdrant call failed.
    """
    if not QDRANT_URL:
        return None
    try:
        records = _get_client().retrieve(
            collection_name=_COLLECTION,
            ids=[_point_id(user_id)],
            with_payload=True,
        )
    except Exception as e:
        # Deliberate catch-all at the storage boundary: callers expect None on
        # failure. logger.exception also records the stack trace.
        logger.exception("Lỗi lấy custom prompt cho '%s': %s", user_id, e)
        return None
    if not records:
        return None
    # Record.payload may be None; treat that as "no prompt stored" rather
    # than letting an AttributeError be logged as a storage error.
    payload = records[0].payload or {}
    return payload.get("prompt")