Skip to content

Commit cb713ea

Browse files
authored
Merge pull request #66 from CyberCRI/Fix/qdrant-collection-selection
Fix/qdrant collection selection
2 parents 4b848cd + 8389968 commit cb713ea

File tree

3 files changed

+46
-24
lines changed

3 files changed

+46
-24
lines changed

tests/qdrant_syncronizer/test_qdrant_handler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,28 @@ def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_co
134134
"collection_welearn_mul_mulembmodel": {doc_id1},
135135
}
136136
self.assertDictEqual(dict(collections_names), expected)
137+
138+
def test_should_handle_multiple_slices_for_same_collection_with_multi_lingual_collection_and_gibberish(
139+
self,
140+
):
141+
self.client.create_collection(
142+
collection_name="collection_welearn_mul_mulembmodel_og",
143+
vectors_config=models.VectorParams(
144+
size=50, distance=models.Distance.COSINE
145+
),
146+
)
147+
148+
doc_id0 = uuid.uuid4()
149+
doc_id1 = uuid.uuid4()
150+
qdrant_connector = self.client
151+
fake_slice0 = FakeSlice(doc_id0, embedding_model_name="english-embmodel")
152+
fake_slice1 = FakeSlice(doc_id0, embedding_model_name="english-embmodel")
153+
154+
fake_slice1.order_sequence = 1
155+
156+
fake_slice2 = FakeSlice(doc_id1, embedding_model_name="mulembmodel")
157+
fake_slice2.document.lang = "pt"
158+
159+
slices = [fake_slice0, fake_slice1, fake_slice2]
160+
collections_names = classify_documents_per_collection(qdrant_connector, slices)
161+
self.assertNotIn("collection_welearn_mul_mulembmodel_og", collections_names)

welearn_datastack/modules/qdrant_handler.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from collections import defaultdict
32
from typing import Collection, Dict, List, Set, Type
43
from uuid import UUID
54

@@ -9,9 +8,7 @@
98
from qdrant_client.http.models import models
109

1110
from welearn_datastack.data.db_models import DocumentSlice
12-
from welearn_datastack.exceptions import (
13-
ErrorWhileDeletingChunks,
14-
)
11+
from welearn_datastack.exceptions import ErrorWhileDeletingChunks
1512

1613
logger = logging.getLogger(__name__)
1714

@@ -31,30 +28,30 @@ def classify_documents_per_collection(
3128
"""
3229
tmp_collections_names_in_qdrant = qdrant_connector.get_collections().collections
3330
collections_names_in_qdrant = [c.name for c in tmp_collections_names_in_qdrant]
34-
model_name_collection_name = {}
35-
for x in collections_names_in_qdrant:
36-
parts = x.split("_")
37-
if len(parts) >= 4:
38-
model_name_collection_name[parts[3]] = x
39-
else:
40-
logger.warning(
41-
"Collection name '%s' does not follow the expected format", x
42-
)
4331

44-
ret: Dict[str, Set[UUID]] = defaultdict(set)
32+
ret: Dict[str, Set[UUID]] = {}
4533
for dslice in slices:
46-
model_name = dslice.embedding_model.title
47-
try:
48-
collection_name = model_name_collection_name[model_name]
49-
ret[collection_name].add(dslice.document_id) # type: ignore
50-
except KeyError:
51-
logger.warning(
52-
"No collection found for model %s, document %s",
53-
model_name,
54-
dslice.document_id,
34+
lang = dslice.document.lang
35+
model = dslice.embedding_model.title
36+
collection_name = None
37+
multilingual_collection = f"collection_welearn_mul_{model}"
38+
mono_collection = f"collection_welearn_{lang}_{model}"
39+
40+
# Check multilingual or mono lingual
41+
if multilingual_collection in collections_names_in_qdrant:
42+
collection_name = multilingual_collection
43+
elif mono_collection in collections_names_in_qdrant:
44+
collection_name = mono_collection
45+
else:
46+
logger.error(
47+
f"Collection {collection_name} not found in Qdrant, slice {dslice.id} ignored",
5548
)
5649
continue
5750

51+
if collection_name not in ret:
52+
ret[collection_name] = set()
53+
ret[collection_name].add(dslice.document_id) # type: ignore
54+
5855
return ret
5956

6057

@@ -73,7 +70,6 @@ def delete_points_related_to_document(
7370
"""
7471
logger.info("Deletion started")
7572
logger.debug(f"Deleting points related to {documents_ids} in {collection_name}")
76-
op_res = None
7773

7874
try:
7975
op_res = qdrant_connector.delete(

welearn_datastack/nodes_workflow/QdrantSyncronizer/qdrant_syncronizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def main() -> None:
114114

115115
# Iterate on each collection
116116
for collection_name in documents_per_collection:
117+
logger.info(f"We are working on collection : {collection_name}")
117118
# We need to delete all points related to the documents in the collection for avoiding duplicates
118119
del_res = delete_points_related_to_document(
119120
collection_name=collection_name,

0 commit comments

Comments
 (0)