# Source listing metadata: 188 lines, 6.3 KiB, Python
import unittest
|
|
|
|
from semantic_index.models import IndexDocument
|
|
from semantic_index.qdrant_store import QdrantStore
|
|
|
|
|
|
class FakeMatchValue:
    """Test double exposed as ``FakeQModels.MatchValue``.

    Stores the value it was constructed with so tests can inspect it via
    ``condition.match.value``.
    """

    def __init__(self, value):
        self.value = value
|
|
|
|
|
|
class FakeFieldCondition:
    """Test double exposed as ``FakeQModels.FieldCondition``.

    Records the payload key together with the optional ``match`` and
    ``range`` objects so tests can assert on the conditions that were built.
    """

    # NOTE(review): ``range`` shadows the builtin; presumably it mirrors the
    # keyword used by the real model, so renaming it would break keyword
    # callers. Confirm against QdrantStore before changing.
    def __init__(self, key, match=None, range=None):
        self.key = key
        self.match = match
        self.range = range
|
|
|
|
|
|
class FakeFilter:
    """Test double exposed as ``FakeQModels.Filter``.

    Keeps the ``must`` collection exactly as passed so tests can read the
    conditions back in order.
    """

    def __init__(self, must):
        self.must = must
|
|
|
|
|
|
class FakeFilterSelector:
    """Test double exposed as ``FakeQModels.FilterSelector``.

    Wraps a filter object; ``FakeClient.delete`` reads it back through the
    ``filter`` attribute.
    """

    # The parameter name ``filter`` is part of the keyword interface and is
    # kept even though it shadows the builtin.
    def __init__(self, filter):
        self.filter = filter
|
|
|
|
|
|
class FakePointIdsList:
    """Test double exposed as ``FakeQModels.PointIdsList``.

    Holds the point ids passed in so tests can inspect them via ``points``.
    """

    def __init__(self, points):
        self.points = points
|
|
|
|
|
|
class FakeQModels:
    """Namespace of fake model classes, assigned to ``store.qmodels`` in the tests.

    Each attribute maps a model name to the matching fake defined in this
    module, so objects built through ``store.qmodels`` can be inspected by
    the test assertions.
    """

    MatchValue = FakeMatchValue
    FieldCondition = FakeFieldCondition
    Filter = FakeFilter
    FilterSelector = FakeFilterSelector
    PointIdsList = FakePointIdsList

    class PointStruct:
        """Fake point carrying the ``id``, ``vector`` and ``payload`` it was given."""

        # ``id`` shadows the builtin but is part of the keyword interface.
        def __init__(self, id, vector, payload):
            self.id = id
            self.vector = vector
            self.payload = payload
|
|
|
|
|
|
class FakeCountResult:
    """Fixed result returned by ``FakeClient.count``; always reports 7."""

    # Class-level constant: every instance exposes the same count.
    count = 7
|
|
|
|
|
|
class FakeRecord:
    """Fake stored point returned by the fake client's ``scroll``.

    Every instance gets the id ``"point-id"`` and its own payload dict, so a
    test may replace or mutate one record's payload without affecting others.
    """

    def __init__(self):
        self.id = "point-id"
        self.payload = {
            "document_id": "redmine:issue:1:chunk:0",
            "text": "Indexed text",
            "source": "redmine",
            "project_identifier": "customer-service",
        }
|
|
|
|
|
|
class FakeClient:
    """In-memory client double that records what it was called with.

    The filter or selector from the most recent ``count``, ``scroll`` and
    ``delete`` call, and every batch passed to ``upsert``, are kept on the
    instance for the tests to assert on.
    """

    def __init__(self):
        self.count_filter = None
        self.scroll_filter = None
        self.delete_filter = None
        self.delete_selector = None
        self.upsert_batches = []

    def get_collections(self):
        """Return an object whose ``collections`` holds one entry named "semantic"."""
        collection = type("Collection", (), {"name": "semantic"})()
        return type("Collections", (), {"collections": [collection]})()

    def count(self, collection_name, count_filter, exact):
        """Remember ``count_filter`` and return a ``FakeCountResult``."""
        self.count_filter = count_filter
        return FakeCountResult()

    def scroll(self, collection_name, scroll_filter, limit, with_payload, with_vectors, offset=None):
        """Remember ``scroll_filter`` and return one ``FakeRecord`` plus ``None``.

        ``limit`` and ``offset`` are accepted but ignored by this fake.
        """
        self.scroll_filter = scroll_filter
        return [FakeRecord()], None

    def delete(self, collection_name, points_selector):
        """Remember the selector and, when it has one, its ``filter`` attribute."""
        self.delete_selector = points_selector
        # Selectors without a ``filter`` attribute leave delete_filter as None.
        self.delete_filter = getattr(points_selector, "filter", None)

    def upsert(self, collection_name, points):
        """Append ``points`` as one batch to ``upsert_batches``."""
        self.upsert_batches.append(points)
|
|
|
|
|
|
class QdrantStoreReadTest(unittest.TestCase):
    """Exercise QdrantStore count, list, delete and upsert against FakeClient."""

    def make_store(self):
        """Return a QdrantStore wired to the fakes without running ``__init__``."""
        # object.__new__ skips QdrantStore.__init__; the attributes the tests
        # rely on are assigned by hand instead.
        store = object.__new__(QdrantStore)
        store.client = FakeClient()
        store.collection = "semantic"
        store.vector_size = 1536
        store.qmodels = FakeQModels
        # A small batch size lets the upsert test observe several batches.
        store.upsert_batch_size = 2
        return store

    def test_count_documents_builds_metadata_filter(self):
        """count_documents passes source/project/doc_type conditions to the client."""
        store = self.make_store()

        count = store.count_documents(
            source="redmine",
            project_identifier="customer-service",
            doc_type="issue",
        )

        self.assertEqual(7, count)
        conditions = store.client.count_filter.must
        self.assertEqual(
            ["source", "project_identifier", "doc_type"],
            [condition.key for condition in conditions],
        )
        self.assertEqual("customer-service", conditions[1].match.value)

    def test_list_documents_strips_internal_payload_fields(self):
        """list_documents lifts document_id/text out of the returned payload."""
        store = self.make_store()

        documents = store.list_documents(
            limit=5,
            source="redmine",
            project_identifier="customer-service",
        )

        self.assertEqual("redmine:issue:1:chunk:0", documents[0]["id"])
        self.assertEqual("Indexed text", documents[0]["text"])
        self.assertNotIn("document_id", documents[0]["payload"])
        self.assertNotIn("text", documents[0]["payload"])

    def test_delete_by_source_can_be_limited_to_project_scope(self):
        """delete_by_source adds a project_identifier condition when one is given."""
        store = self.make_store()

        store.delete_by_source("redmine", project_identifier="customer-service")

        conditions = store.client.delete_filter.must
        self.assertEqual(
            ["source", "project_identifier"],
            [condition.key for condition in conditions],
        )
        self.assertEqual("redmine", conditions[0].match.value)
        self.assertEqual("customer-service", conditions[1].match.value)

    def test_list_documents_can_be_limited_to_issue_scope(self):
        """list_documents adds an issue_id condition when one is given."""
        store = self.make_store()

        store.list_documents(
            limit=5,
            source="redmine",
            project_identifier="customer-service",
            issue_id=39779,
        )

        conditions = store.client.scroll_filter.must
        self.assertEqual(
            ["source", "project_identifier", "issue_id"],
            [condition.key for condition in conditions],
        )
        self.assertEqual(39779, conditions[2].match.value)

    def test_delete_documents_deletes_stable_document_point_ids(self):
        """delete_documents sends one point id that differs from the raw document id."""
        store = self.make_store()

        store.delete_documents(["redmine:issue:39779:chunk:0"])

        self.assertEqual(1, len(store.client.delete_selector.points))
        self.assertNotEqual(
            "redmine:issue:39779:chunk:0",
            store.client.delete_selector.points[0],
        )

    def test_upsert_sends_points_in_batches(self):
        """upsert splits five documents into batches of upsert_batch_size (2, 2, 1)."""
        store = self.make_store()
        documents = [
            IndexDocument(
                id=f"redmine:issue:{issue_id}:chunk:0",
                text=f"Issue {issue_id}",
                payload={"source": "redmine"},
            )
            for issue_id in range(5)
        ]
        vectors = [[0.1, 0.2, 0.3] for _ in documents]

        store.upsert(documents, vectors)

        self.assertEqual([2, 2, 1], [len(batch) for batch in store.client.upsert_batches])
        self.assertEqual("Issue 0", store.client.upsert_batches[0][0].payload["text"])

    def test_list_documents_paginates_qdrant_scroll_until_requested_limit(self):
        """list_documents follows the scroll offset and stops at ``limit`` results."""

        class PagedClient(FakeClient):
            """FakeClient whose scroll serves two pages of two records each."""

            def __init__(self):
                super().__init__()
                # Offsets received by scroll, in call order.
                self.offsets = []

            def scroll(self, collection_name, scroll_filter, limit, with_payload, with_vectors, offset=None):
                self.offsets.append(offset)
                # Document ids encode the call number, e.g. "doc:1a" / "doc:1b".
                first = FakeRecord()
                first.payload = {**first.payload, "document_id": f"doc:{len(self.offsets)}a"}
                second = FakeRecord()
                second.payload = {**second.payload, "document_id": f"doc:{len(self.offsets)}b"}
                if offset is None:
                    # First page: hand back a continuation offset.
                    return [first, second], "next"
                return [first, second], None

        store = self.make_store()
        store.client = PagedClient()

        documents = store.list_documents(limit=3, source="redmine")

        self.assertEqual(
            ["doc:1a", "doc:1b", "doc:2a"],
            [document["id"] for document in documents],
        )
        self.assertEqual([None, "next"], store.client.offsets)
|
|
|
|
|
|
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
|