anhkhoiphan commited on
Commit
dd53ab9
·
1 Parent(s): a5c91ab

Thay đổi logic tool RAG

Browse files
Files changed (3) hide show
  1. pdf_rag.py +172 -64
  2. prompts.py +9 -0
  3. tools/chat_tools.py +3 -0
pdf_rag.py CHANGED
@@ -1,26 +1,31 @@
1
  """
2
- PDF RAG — chunk, embed (OpenAI), index to Qdrant, hybrid search.
3
 
4
- Hybrid search = dense (semantic, OpenAI embeddings)
5
- + sparse/keyword (Qdrant full-text index)
6
- merged via Reciprocal Rank Fusion (RRF).
 
 
 
 
7
  """
8
 
9
  import logging
10
  import uuid
11
  from typing import Optional
12
 
 
13
  from openai import OpenAI
14
  from qdrant_client import QdrantClient
15
  from qdrant_client.models import (
16
  Distance,
17
  FieldCondition,
18
  Filter,
19
- MatchText,
20
  MatchValue,
21
  PointStruct,
22
- TextIndexParams,
23
- TokenizerType,
24
  VectorParams,
25
  )
26
 
@@ -29,21 +34,25 @@ from src.pdf_processing import pdf_to_markdown
29
 
30
  logger = logging.getLogger(__name__)
31
 
32
- _PDF_COLLECTION = "pdf_chunks"
33
- _EMBED_MODEL = "text-embedding-3-small"
34
- _EMBED_DIMS = 1536
 
 
35
 
36
- _CHUNK_SIZE = 800 # ký tự / chunk
37
- _CHUNK_OVERLAP = 150 # tự overlap giữa các chunk
38
- _EMBED_BATCH = 32 # số chunk embed song song mỗi lần
39
 
40
- _RRF_K = 60 # hằng số RRF (60 là giá trị chuẩn trong tài liệu)
 
41
 
42
- _qdrant: Optional[QdrantClient] = None
43
- _openai: Optional[OpenAI] = None
 
44
 
45
 
46
- # ── Client helpers ────────────────────────────────────────────────────────────
47
 
48
  def _get_qdrant() -> QdrantClient:
49
  global _qdrant
@@ -64,31 +73,44 @@ def _get_openai() -> OpenAI:
64
  return _openai
65
 
66
 
 
 
 
 
 
 
 
67
  def _ensure_collection(client: QdrantClient) -> None:
68
  existing = {c.name for c in client.get_collections().collections}
69
  if _PDF_COLLECTION not in existing:
70
  client.create_collection(
71
  collection_name=_PDF_COLLECTION,
72
- vectors_config=VectorParams(size=_EMBED_DIMS, distance=Distance.COSINE),
 
 
 
 
 
73
  )
74
- # Full-text index cho keyword search
 
 
 
75
  client.create_payload_index(
76
  collection_name=_PDF_COLLECTION,
77
- field_name="chunk_text",
78
- field_schema=TextIndexParams(
79
- type="text",
80
- tokenizer=TokenizerType.MULTILINGUAL,
81
- ),
82
  )
83
- logger.info("Qdrant: collection '%s' created.", _PDF_COLLECTION)
 
 
 
 
84
 
85
 
86
  # ── Chunking ──────────────────────────────────────────────────────────────────
87
 
88
  def _chunk_text(text: str) -> list[str]:
89
- """
90
- Chia text thành các chunk có overlap, ưu tiên cắt tại ranh giới câu.
91
- """
92
  if len(text) <= _CHUNK_SIZE:
93
  return [text.strip()] if text.strip() else []
94
 
@@ -98,7 +120,6 @@ def _chunk_text(text: str) -> list[str]:
98
  end = min(start + _CHUNK_SIZE, len(text))
99
 
100
  if end < len(text):
101
- # Tìm ranh giới câu gần nhất để cắt gọn
102
  for boundary in ('\n\n', '\n', '.', '!', '?'):
103
  pos = text.rfind(boundary, start + _CHUNK_SIZE // 2, end)
104
  if pos != -1:
@@ -109,9 +130,10 @@ def _chunk_text(text: str) -> list[str]:
109
  if chunk:
110
  chunks.append(chunk)
111
 
 
112
  next_start = end - _CHUNK_OVERLAP
113
  if next_start <= start:
114
- next_start = end # tránh vòng lặp vô tận
115
  start = next_start
116
 
117
  return chunks
@@ -128,12 +150,91 @@ def _embed_one(text: str) -> list[float]:
128
  return _embed_batch([text])[0]
129
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  # ── Public API ────────────────────────────────────────────────────────────────
132
 
133
  def index_pdf(pdf_path: str, pdf_name: str, conversation_id: str) -> int:
134
  """
135
- Đọc PDF, chunk, embed và index lên Qdrant.
136
- Dùng UUID v5 làm point ID để upsert idempotent (re-send cùng file không tạo duplicate).
137
 
138
  Returns:
139
  Số chunk đã index.
@@ -148,8 +249,9 @@ def index_pdf(pdf_path: str, pdf_name: str, conversation_id: str) -> int:
148
  indexed = 0
149
 
150
  for batch_start in range(0, len(chunks), _EMBED_BATCH):
151
- batch = chunks[batch_start : batch_start + _EMBED_BATCH]
152
- vectors = _embed_batch(batch)
 
153
 
154
  points = [
155
  PointStruct(
@@ -157,7 +259,10 @@ def index_pdf(pdf_path: str, pdf_name: str, conversation_id: str) -> int:
157
  uuid.NAMESPACE_DNS,
158
  f"{conversation_id}::{pdf_name}::{batch_start + i}",
159
  )),
160
- vector=vectors[i],
 
 
 
161
  payload={
162
  "conversation_id": conversation_id,
163
  "pdf_name": pdf_name,
@@ -180,15 +285,13 @@ def index_pdf(pdf_path: str, pdf_name: str, conversation_id: str) -> int:
180
 
181
  def hybrid_search(query: str, conversation_id: str, top_k: int = 5) -> list[str]:
182
  """
183
- Hybrid search kết hợp:
184
- - Dense: semantic search bằng OpenAI embedding
185
- - Sparse: full-text keyword search (Qdrant TextIndex)
186
- Merge bằng Reciprocal Rank Fusion (RRF).
187
-
188
- Tìm trên TOÀN BỘ PDF đã index cho conversation_id này.
189
 
190
- Returns:
191
- Danh sách chunk text liên quan nhất, sắp xếp theo RRF score.
192
  """
193
  client = _get_qdrant()
194
  conv_filter = Filter(must=[
@@ -196,42 +299,47 @@ def hybrid_search(query: str, conversation_id: str, top_k: int = 5) -> list[str]
196
  ])
197
 
198
  # ── Dense search ──────────────────────────────────────────────────────────
199
- query_vec = _embed_one(query)
200
- dense_hits = client.search(
201
  collection_name=_PDF_COLLECTION,
202
- query_vector=query_vec,
 
203
  query_filter=conv_filter,
204
  limit=top_k * 3,
205
  with_payload=True,
206
- )
207
 
208
- # ── Keyword / full-text search ────────────────────────────────────────────
209
- kw_filter = Filter(must=[
210
- FieldCondition(key="conversation_id", match=MatchValue(value=conversation_id)),
211
- FieldCondition(key="chunk_text", match=MatchText(text=query)),
212
- ])
213
- kw_hits, _ = client.scroll(
214
  collection_name=_PDF_COLLECTION,
215
- scroll_filter=kw_filter,
 
 
216
  limit=top_k * 3,
217
  with_payload=True,
218
- with_vectors=False,
219
- )
220
 
221
  # ── RRF merge ─────────────────────────────────────────────────────────────
222
- scores: dict[str, float] = {}
223
- texts: dict[str, str] = {}
224
 
225
  for rank, hit in enumerate(dense_hits):
226
  sid = str(hit.id)
227
- scores[sid] = scores.get(sid, 0.0) + 1.0 / (rank + _RRF_K)
228
- texts[sid] = hit.payload.get("chunk_text", "")
229
 
230
- for rank, hit in enumerate(kw_hits):
231
  sid = str(hit.id)
232
  scores[sid] = scores.get(sid, 0.0) + 1.0 / (rank + _RRF_K)
233
- if sid not in texts:
234
- texts[sid] = hit.payload.get("chunk_text", "")
235
 
236
  top_ids = sorted(scores, key=scores.__getitem__, reverse=True)[:top_k]
237
- return [texts[sid] for sid in top_ids if sid in texts]
 
 
 
 
 
 
 
 
1
  """
2
+ PDF RAG — chunk, embed, index to Qdrant, hybrid search.
3
 
4
+ Hybrid search = dense (OpenAI text-embedding-3-small, cosine similarity)
5
+ + sparse (BM25 via fastembed Qdrant/bm25, dot product)
6
+ merged via Reciprocal Rank Fusion (RRF, k=60).
7
+
8
+ Sau RRF, mỗi chunk được mở rộng sang các chunk lân cận (N-3 đến N+3) trong cùng
9
+ PDF để đưa vào context đầy đủ hơn. Không dùng overlap vì neighbor expansion
10
+ đã đảm bảo không mất context tại ranh giới chunk.
11
  """
12
 
13
  import logging
14
  import uuid
15
  from typing import Optional
16
 
17
+ from fastembed import SparseTextEmbedding
18
  from openai import OpenAI
19
  from qdrant_client import QdrantClient
20
  from qdrant_client.models import (
21
  Distance,
22
  FieldCondition,
23
  Filter,
24
+ MatchAny,
25
  MatchValue,
26
  PointStruct,
27
+ SparseVector,
28
+ SparseVectorParams,
29
  VectorParams,
30
  )
31
 
 
34
 
35
  logger = logging.getLogger(__name__)
36
 
37
+ # Collection v2: named vectors (dense + sparse). Xóa collection cũ "pdf_chunks" nếu còn.
38
+ _PDF_COLLECTION = "pdf_chunks_v2"
39
+ _EMBED_MODEL = "text-embedding-3-small"
40
+ _EMBED_DIMS = 1536
41
+ _BM25_MODEL = "Qdrant/bm25"
42
 
43
+ _CHUNK_SIZE = 1000 # ký tự / chunk
44
+ _CHUNK_OVERLAP = 0 # không cần overlap neighbor expansion xử lý ranh giới
45
+ _EMBED_BATCH = 32 # số chunk embed song song mỗi lần
46
 
47
+ _RRF_K = 60 # hằng số RRF chuẩn
48
+ _NEIGHBOR_WINDOW = 3 # fetch N-3 đến N+3 quanh mỗi chunk được retrieve
49
 
50
+ _qdrant: Optional[QdrantClient] = None
51
+ _openai: Optional[OpenAI] = None
52
+ _bm25: Optional[SparseTextEmbedding] = None
53
 
54
 
55
+ # ── Client / model helpers ────────────────────────────────────────────────────
56
 
57
  def _get_qdrant() -> QdrantClient:
58
  global _qdrant
 
73
  return _openai
74
 
75
 
76
+ def _get_bm25() -> SparseTextEmbedding:
77
+ global _bm25
78
+ if _bm25 is None:
79
+ _bm25 = SparseTextEmbedding(model_name=_BM25_MODEL)
80
+ return _bm25
81
+
82
+
83
  def _ensure_collection(client: QdrantClient) -> None:
84
  existing = {c.name for c in client.get_collections().collections}
85
  if _PDF_COLLECTION not in existing:
86
  client.create_collection(
87
  collection_name=_PDF_COLLECTION,
88
+ vectors_config={
89
+ "dense": VectorParams(size=_EMBED_DIMS, distance=Distance.COSINE),
90
+ },
91
+ sparse_vectors_config={
92
+ "sparse": SparseVectorParams(),
93
+ },
94
  )
95
+ logger.info("Qdrant: collection '%s' created.", _PDF_COLLECTION)
96
+
97
+ # Payload indexes — idempotent, an toàn gọi mỗi lần khởi động.
98
+ for field in ("conversation_id", "pdf_name"):
99
  client.create_payload_index(
100
  collection_name=_PDF_COLLECTION,
101
+ field_name=field,
102
+ field_schema="keyword",
 
 
 
103
  )
104
+ client.create_payload_index(
105
+ collection_name=_PDF_COLLECTION,
106
+ field_name="chunk_index",
107
+ field_schema="integer",
108
+ )
109
 
110
 
111
  # ── Chunking ──────────────────────────────────────────────────────────────────
112
 
113
  def _chunk_text(text: str) -> list[str]:
 
 
 
114
  if len(text) <= _CHUNK_SIZE:
115
  return [text.strip()] if text.strip() else []
116
 
 
120
  end = min(start + _CHUNK_SIZE, len(text))
121
 
122
  if end < len(text):
 
123
  for boundary in ('\n\n', '\n', '.', '!', '?'):
124
  pos = text.rfind(boundary, start + _CHUNK_SIZE // 2, end)
125
  if pos != -1:
 
130
  if chunk:
131
  chunks.append(chunk)
132
 
133
+ # _CHUNK_OVERLAP = 0, nhưng giữ công thức chung để dễ điều chỉnh sau
134
  next_start = end - _CHUNK_OVERLAP
135
  if next_start <= start:
136
+ next_start = end
137
  start = next_start
138
 
139
  return chunks
 
150
  return _embed_batch([text])[0]
151
 
152
 
153
+ def _bm25_batch(texts: list[str]) -> list[SparseVector]:
154
+ embeddings = list(_get_bm25().embed(texts))
155
+ return [
156
+ SparseVector(indices=e.indices.tolist(), values=e.values.tolist())
157
+ for e in embeddings
158
+ ]
159
+
160
+
161
+ def _bm25_one(text: str) -> SparseVector:
162
+ return _bm25_batch([text])[0]
163
+
164
+
165
+ # ── Neighbor expansion ────────────────────────────────────────────────────────
166
+
167
+ def _expand_chunks(
168
+ client: QdrantClient,
169
+ conversation_id: str,
170
+ hits: list[tuple[str, int]], # (pdf_name, chunk_index)
171
+ window: int = _NEIGHBOR_WINDOW,
172
+ ) -> list[str]:
173
+ """
174
+ Với mỗi (pdf_name, chunk_index) được retrieve, fetch thêm chunk N-window đến N+window
175
+ từ cùng PDF. Các cửa sổ chồng lấp được merge thành một đoạn liên tục để
176
+ tránh đưa nội dung trùng lặp vào context.
177
+
178
+ Returns:
179
+ Danh sách đoạn văn bản, mỗi đoạn là một cửa sổ liên tục (đã merge nếu chồng lấp).
180
+ """
181
+ # Gom tất cả chunk_index cần fetch theo từng pdf_name
182
+ pdf_needed: dict[str, set[int]] = {}
183
+ for pdf_name, chunk_index in hits:
184
+ indices = set(range(max(0, chunk_index - window), chunk_index + window + 1))
185
+ pdf_needed.setdefault(pdf_name, set()).update(indices)
186
+
187
+ results: list[str] = []
188
+
189
+ for pdf_name, needed in pdf_needed.items():
190
+ fetch_filter = Filter(must=[
191
+ FieldCondition(key="conversation_id", match=MatchValue(value=conversation_id)),
192
+ FieldCondition(key="pdf_name", match=MatchValue(value=pdf_name)),
193
+ FieldCondition(key="chunk_index", match=MatchAny(any=sorted(needed))),
194
+ ])
195
+ fetched, _ = client.scroll(
196
+ collection_name=_PDF_COLLECTION,
197
+ scroll_filter=fetch_filter,
198
+ limit=len(needed) + 5,
199
+ with_payload=True,
200
+ with_vectors=False,
201
+ )
202
+
203
+ # Map chunk_index → text, rồi sort
204
+ chunk_map = {
205
+ p.payload["chunk_index"]: p.payload.get("chunk_text", "")
206
+ for p in fetched
207
+ if "chunk_index" in p.payload
208
+ }
209
+ sorted_indices = sorted(chunk_map)
210
+ if not sorted_indices:
211
+ continue
212
+
213
+ # Gom các chunk_index liên tiếp thành từng run (merge overlapping windows)
214
+ runs: list[list[int]] = []
215
+ current: list[int] = [sorted_indices[0]]
216
+ for idx in sorted_indices[1:]:
217
+ if idx == current[-1] + 1:
218
+ current.append(idx)
219
+ else:
220
+ runs.append(current)
221
+ current = [idx]
222
+ runs.append(current)
223
+
224
+ for run in runs:
225
+ text = "\n\n".join(chunk_map[i] for i in run)
226
+ if text.strip():
227
+ results.append(text)
228
+
229
+ return results
230
+
231
+
232
  # ── Public API ────────────────────────────────────────────────────────────────
233
 
234
  def index_pdf(pdf_path: str, pdf_name: str, conversation_id: str) -> int:
235
  """
236
+ Đọc PDF, chunk, embed (dense + sparse) upsert vào Qdrant.
237
+ UUID v5 làm point ID đảm bảo idempotent gửi lại cùng file không tạo duplicate.
238
 
239
  Returns:
240
  Số chunk đã index.
 
249
  indexed = 0
250
 
251
  for batch_start in range(0, len(chunks), _EMBED_BATCH):
252
+ batch = chunks[batch_start : batch_start + _EMBED_BATCH]
253
+ dense_vecs = _embed_batch(batch)
254
+ sparse_vecs = _bm25_batch(batch)
255
 
256
  points = [
257
  PointStruct(
 
259
  uuid.NAMESPACE_DNS,
260
  f"{conversation_id}::{pdf_name}::{batch_start + i}",
261
  )),
262
+ vector={
263
+ "dense": dense_vecs[i],
264
+ "sparse": sparse_vecs[i],
265
+ },
266
  payload={
267
  "conversation_id": conversation_id,
268
  "pdf_name": pdf_name,
 
285
 
286
  def hybrid_search(query: str, conversation_id: str, top_k: int = 5) -> list[str]:
287
  """
288
+ Hybrid search:
289
+ Dense OpenAI cosine similarity, Qdrant trả cosine score.
290
+ Sparse BM25 dot product, Qdrant trả BM25 score.
291
+ Merge RRF: score = 1/(k + rank_dense) + 1/(k + rank_sparse).
 
 
292
 
293
+ Sau RRF, mỗi chunk được mở rộng sang N-3 đến N+3 trong cùng PDF.
294
+ Các cửa sổ chồng lấp tự động được merge thành đoạn liên tục.
295
  """
296
  client = _get_qdrant()
297
  conv_filter = Filter(must=[
 
299
  ])
300
 
301
  # ── Dense search ──────────────────────────────────────────────────────────
302
+ dense_hits = client.query_points(
 
303
  collection_name=_PDF_COLLECTION,
304
+ query=_embed_one(query),
305
+ using="dense",
306
  query_filter=conv_filter,
307
  limit=top_k * 3,
308
  with_payload=True,
309
+ ).points
310
 
311
+ # ── Sparse (BM25) search ──────────────────────────────────────────────────
312
+ bm25_vec = _bm25_one(query)
313
+ sparse_hits = client.query_points(
 
 
 
314
  collection_name=_PDF_COLLECTION,
315
+ query=SparseVector(indices=bm25_vec.indices, values=bm25_vec.values),
316
+ using="sparse",
317
+ query_filter=conv_filter,
318
  limit=top_k * 3,
319
  with_payload=True,
320
+ ).points
 
321
 
322
  # ── RRF merge ─────────────────────────────────────────────────────────────
323
+ scores: dict[str, float] = {}
324
+ payloads: dict[str, dict] = {}
325
 
326
  for rank, hit in enumerate(dense_hits):
327
  sid = str(hit.id)
328
+ scores[sid] = scores.get(sid, 0.0) + 1.0 / (rank + _RRF_K)
329
+ payloads[sid] = hit.payload
330
 
331
+ for rank, hit in enumerate(sparse_hits):
332
  sid = str(hit.id)
333
  scores[sid] = scores.get(sid, 0.0) + 1.0 / (rank + _RRF_K)
334
+ if sid not in payloads:
335
+ payloads[sid] = hit.payload
336
 
337
  top_ids = sorted(scores, key=scores.__getitem__, reverse=True)[:top_k]
338
+
339
+ # ── Neighbor expansion ────────────────────────────────────────────────────
340
+ hits_meta = [
341
+ (payloads[sid].get("pdf_name", ""), payloads[sid].get("chunk_index", 0))
342
+ for sid in top_ids
343
+ if sid in payloads
344
+ ]
345
+ return _expand_chunks(client, conversation_id, hits_meta)
prompts.py CHANGED
@@ -179,6 +179,14 @@ Nhiệm vụ: phân tích yêu cầu và gọi đúng công cụ để xử lý.
179
  read_link(url)
180
  → Đọc và trích xuất nội dung từ đường link URL.
181
 
 
 
 
 
 
 
 
 
182
  ═══ CHIẾN LƯỢC GỌI TOOL ═══
183
 
184
  Hỏi về nội dung đã thảo luận / tóm tắt → summarize_chat
@@ -189,6 +197,7 @@ Nhiệm vụ: phân tích yêu cầu và gọi đúng công cụ để xử lý.
189
  Muốn tra cứu thông tin đã lưu → get_memories
190
  Cần đọc nội dung từ link → read_link
191
  Muốn đặt nhắc nhở → add_reminder
 
192
 
193
  ═══ QUY TẮC BẮT BUỘC ═══
194
 
 
179
  read_link(url)
180
  → Đọc và trích xuất nội dung từ đường link URL.
181
 
182
+ 📚 KNOWLEDGE BASE (Tài liệu lớp học):
183
+ rag_search(query, conversation_id)
184
+ → Tìm kiếm trong tài liệu PDF đã được index cho conversation/room này
185
+ (lecture notes, slides, handouts, đề cương...).
186
+ Dùng khi người dùng hỏi về nội dung bài học, tài liệu đã chia sẻ,
187
+ hoặc kiến thức chuyên ngành liên quan đến lớp học.
188
+ Luôn truyền conversation_id từ thông tin được cung cấp.
189
+
190
  ═══ CHIẾN LƯỢC GỌI TOOL ═══
191
 
192
  Hỏi về nội dung đã thảo luận / tóm tắt → summarize_chat
 
197
  Muốn tra cứu thông tin đã lưu → get_memories
198
  Cần đọc nội dung từ link → read_link
199
  Muốn đặt nhắc nhở → add_reminder
200
+ Hỏi về nội dung bài học / tài liệu lớp → rag_search
201
 
202
  ═══ QUY TẮC BẮT BUỘC ═══
203
 
tools/chat_tools.py CHANGED
@@ -8,6 +8,7 @@ from . import memory as _memory_mod # noqa: F401
8
  from . import scheduler as _scheduler_mod # noqa: F401
9
  from . import summarizer as _summarizer_mod # noqa: F401
10
  from . import chart as _chart_mod # noqa: F401
 
11
 
12
  from .base import TOOLS as _REGISTRY, get_langchain_tools
13
 
@@ -22,6 +23,8 @@ _ALLOWED = {
22
  "save_memory", "get_memories",
23
  # Web
24
  "read_link",
 
 
25
  }
26
 
27
  TOOLS = [t for t in get_langchain_tools() if t.name in _ALLOWED]
 
8
  from . import scheduler as _scheduler_mod # noqa: F401
9
  from . import summarizer as _summarizer_mod # noqa: F401
10
  from . import chart as _chart_mod # noqa: F401
11
+ from . import rag as _rag_mod # noqa: F401
12
 
13
  from .base import TOOLS as _REGISTRY, get_langchain_tools
14
 
 
23
  "save_memory", "get_memories",
24
  # Web
25
  "read_link",
26
+ # Knowledge base (PDF RAG)
27
+ "rag_search",
28
  }
29
 
30
  TOOLS = [t for t in get_langchain_tools() if t.name in _ALLOWED]