from __future__ import annotations import uuid from typing import Any, Dict, List, Optional, Sequence from collections import Counter from .models import IndexDocument, SearchQuery, SearchResult def point_id_for_document(document_id: str) -> str: return str(uuid.uuid5(uuid.NAMESPACE_URL, document_id)) def build_filter(query: SearchQuery) -> Dict[str, List[Dict[str, Any]]]: must: List[Dict[str, Any]] = [] equality_fields = { "source": query.source, "project_id": query.project_id, "project_identifier": query.project_identifier, "doc_type": query.doc_type, "issue_id": query.issue_id, "contact_id": query.contact_id, "contact_email": query.contact_email, } for key, value in equality_fields.items(): if value is not None: must.append({"key": key, "match": {"value": value}}) if query.date_from or query.date_to: range_filter: Dict[str, str] = {} if query.date_from: range_filter["gte"] = query.date_from if query.date_to: range_filter["lte"] = query.date_to must.append({"key": "created_on", "range": range_filter}) return {"must": must} class QdrantStore: def __init__(self, url: str, api_key: Optional[str], collection: str, vector_size: int = 1536, upsert_batch_size: int = 64) -> None: try: from qdrant_client import QdrantClient from qdrant_client.http import models as qmodels except ImportError as exc: raise RuntimeError("Install qdrant-client to use live Qdrant storage") from exc self.client = QdrantClient(url=url, api_key=api_key) self.collection = collection self.vector_size = vector_size self.upsert_batch_size = upsert_batch_size self.qmodels = qmodels def ensure_collection(self) -> None: collections = self.client.get_collections().collections if any(collection.name == self.collection for collection in collections): return self.client.create_collection( collection_name=self.collection, vectors_config=self.qmodels.VectorParams(size=self.vector_size, distance=self.qmodels.Distance.COSINE), ) def upsert(self, documents: Sequence[IndexDocument], vectors: Sequence[Sequence[float]]) -> None: if len(documents) != len(vectors): raise ValueError("documents and vectors length mismatch") self.ensure_collection() points = [ self.qmodels.PointStruct( id=point_id_for_document(document.id), vector=list(vector), payload={**document.payload, "document_id": document.id, "text": document.text}, ) for document, vector in zip(documents, vectors) ] for start in range(0, len(points), self.upsert_batch_size): batch = points[start : start + self.upsert_batch_size] if batch: self.client.upsert(collection_name=self.collection, points=batch) def delete_by_source(self, source: str, project_identifier: Optional[str] = None) -> None: self.ensure_collection() query = SearchQuery(text="*", source=source, project_identifier=project_identifier) self.client.delete( collection_name=self.collection, points_selector=self.qmodels.FilterSelector( filter=self._to_qdrant_filter(build_filter(query)) ), ) def delete_documents(self, document_ids: Sequence[str]) -> None: self.ensure_collection() if not document_ids: return self.client.delete( collection_name=self.collection, points_selector=self.qmodels.PointIdsList( points=[point_id_for_document(document_id) for document_id in document_ids] ), ) def rebuild_source( self, source: str, documents: Sequence[IndexDocument], vectors: Sequence[Sequence[float]], project_identifier: Optional[str] = None, ) -> None: self.delete_by_source(source, project_identifier=project_identifier) self.upsert(documents, vectors) def search(self, vector: Sequence[float], query: SearchQuery, limit: int) -> List[SearchResult]: self.ensure_collection() qfilter = self._to_qdrant_filter(build_filter(query)) if hasattr(self.client, "query_points"): response = self.client.query_points( collection_name=self.collection, query=list(vector), query_filter=qfilter, limit=limit, with_payload=True, ) results = response.points else: results = self.client.search( collection_name=self.collection, query_vector=list(vector), query_filter=qfilter, limit=limit, with_payload=True, ) return [self._point_to_result(point) for point in results] def get_document(self, document_id: str) -> Optional[Dict[str, Any]]: self.ensure_collection() points = self.client.retrieve(collection_name=self.collection, ids=[point_id_for_document(document_id)], with_payload=True) if not points: return None payload = dict(points[0].payload or {}) text = payload.pop("text", "") payload.pop("document_id", None) return {"id": document_id, "text": text, "payload": payload} def count_documents( self, source: Optional[str] = None, project_identifier: Optional[str] = None, doc_type: Optional[str] = None, ) -> int: self.ensure_collection() query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type) result = self.client.count( collection_name=self.collection, count_filter=self._to_qdrant_filter(build_filter(query)), exact=True, ) return int(result.count) def list_documents( self, limit: int = 10, source: Optional[str] = None, project_identifier: Optional[str] = None, doc_type: Optional[str] = None, issue_id: Optional[int] = None, ) -> List[Dict[str, Any]]: self.ensure_collection() query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type, issue_id=issue_id) qfilter = self._to_qdrant_filter(build_filter(query)) records = [] offset = None while len(records) < limit: batch_limit = limit - len(records) batch, offset = self.client.scroll( collection_name=self.collection, scroll_filter=qfilter, limit=batch_limit, with_payload=True, with_vectors=False, offset=offset, ) records.extend(batch[:batch_limit]) if not offset or not batch: break return [self._record_to_document(record) for record in records] def list_projects(self, source: Optional[str] = None, limit: int = 5000) -> List[Dict[str, Any]]: documents = self.list_documents(limit=limit, source=source) counts = Counter( str((document.get("payload") or {}).get("project_identifier")) for document in documents if (document.get("payload") or {}).get("project_identifier") ) return [ {"project_identifier": project, "document_count": count} for project, count in sorted(counts.items()) ] def _to_qdrant_filter(self, raw_filter: Dict[str, List[Dict[str, Any]]]) -> Any: conditions = [] for condition in raw_filter.get("must", []): if "match" in condition: conditions.append( self.qmodels.FieldCondition( key=condition["key"], match=self.qmodels.MatchValue(value=condition["match"]["value"]), ) ) elif "range" in condition: conditions.append(self.qmodels.FieldCondition(key=condition["key"], range=self.qmodels.DatetimeRange(**condition["range"]))) return self.qmodels.Filter(must=conditions) if conditions else None def _point_to_result(self, point: Any) -> SearchResult: payload = dict(point.payload or {}) text = payload.pop("text", "") document_id = payload.pop("document_id", str(point.id)) return SearchResult(id=document_id, score=float(point.score), text=text, payload=payload) def _record_to_document(self, record: Any) -> Dict[str, Any]: payload = dict(record.payload or {}) text = payload.pop("text", "") document_id = payload.pop("document_id", str(record.id)) return {"id": document_id, "text": text, "payload": payload}