Files
2026-05-04 09:50:03 -04:00

220 lines
8.9 KiB
Python

from __future__ import annotations
import uuid
from typing import Any, Dict, List, Optional, Sequence
from collections import Counter
from .models import IndexDocument, SearchQuery, SearchResult
def point_id_for_document(document_id: str) -> str:
    """Map a document id to a deterministic UUID string used as its point id.

    A name-based UUIDv5 in the URL namespace is used, so the same
    ``document_id`` always yields the same point id.
    """
    point_uuid = uuid.uuid5(uuid.NAMESPACE_URL, document_id)
    return str(point_uuid)
def build_filter(query: SearchQuery) -> Dict[str, List[Dict[str, Any]]]:
    """Translate a search query into a backend-neutral ``{"must": [...]}`` filter.

    Each equality attribute of ``query`` that is not ``None`` becomes a
    ``{"key": ..., "match": {"value": ...}}`` condition, in a fixed order.
    A truthy ``date_from`` / ``date_to`` becomes ``gte`` / ``lte`` bounds of a
    single ``range`` condition on ``created_on``.
    """
    equality_keys = (
        "source",
        "project_id",
        "project_identifier",
        "doc_type",
        "issue_id",
        "contact_id",
        "contact_email",
    )
    # Read each attribute exactly once, in the same order as the key list.
    pairs = [(field, getattr(query, field)) for field in equality_keys]
    conditions: List[Dict[str, Any]] = [
        {"key": field, "match": {"value": value}}
        for field, value in pairs
        if value is not None
    ]

    bounds: Dict[str, str] = {}
    if query.date_from:
        bounds["gte"] = query.date_from
    if query.date_to:
        bounds["lte"] = query.date_to
    # bounds is non-empty exactly when at least one date bound was truthy.
    if bounds:
        conditions.append({"key": "created_on", "range": bounds})

    return {"must": conditions}
class QdrantStore:
    """Document store backed by a single Qdrant collection.

    Each ``IndexDocument`` is stored as one point whose id is
    ``point_id_for_document(document.id)``. The point payload is the
    document's own payload plus ``document_id`` and ``text`` keys, which the
    read methods strip back out.
    """

    def __init__(
        self,
        url: str,
        api_key: Optional[str],
        collection: str,
        vector_size: int = 1536,
        upsert_batch_size: int = 64,
    ) -> None:
        """Create a Qdrant client bound to ``collection``.

        Args:
            url: Qdrant server URL.
            api_key: Optional API key passed to the client.
            collection: Name of the collection to read from and write to.
            vector_size: Vector dimension used if the collection must be created.
            upsert_batch_size: Maximum number of points sent per upsert call.

        Raises:
            ValueError: If ``upsert_batch_size`` is less than 1.
            RuntimeError: If the ``qdrant-client`` package is not installed.
        """
        if upsert_batch_size < 1:
            # A zero step would crash range() in upsert and a negative step
            # would silently upsert nothing, so reject it up front.
            raise ValueError("upsert_batch_size must be a positive integer")
        try:
            from qdrant_client import QdrantClient
            from qdrant_client.http import models as qmodels
        except ImportError as exc:
            raise RuntimeError("Install qdrant-client to use live Qdrant storage") from exc
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection = collection
        self.vector_size = vector_size
        self.upsert_batch_size = upsert_batch_size
        # Keep the models module around so other methods can build filters and
        # point structs without re-importing qdrant_client.
        self.qmodels = qmodels

    def ensure_collection(self) -> None:
        """Create the collection (cosine distance, ``vector_size`` dims) if it is missing."""
        collections = self.client.get_collections().collections
        if any(collection.name == self.collection for collection in collections):
            return
        self.client.create_collection(
            collection_name=self.collection,
            vectors_config=self.qmodels.VectorParams(size=self.vector_size, distance=self.qmodels.Distance.COSINE),
        )

    def upsert(self, documents: Sequence[IndexDocument], vectors: Sequence[Sequence[float]]) -> None:
        """Write one point per document, sending at most ``upsert_batch_size`` points per call.

        ``vectors[i]`` is the vector stored for ``documents[i]``. Point ids are
        derived deterministically from document ids, so writing a document
        again targets the same point id.

        Raises:
            ValueError: If ``documents`` and ``vectors`` differ in length.
        """
        if len(documents) != len(vectors):
            raise ValueError("documents and vectors length mismatch")
        self.ensure_collection()
        points = [
            self.qmodels.PointStruct(
                id=point_id_for_document(document.id),
                vector=list(vector),
                payload={**document.payload, "document_id": document.id, "text": document.text},
            )
            for document, vector in zip(documents, vectors)
        ]
        for start in range(0, len(points), self.upsert_batch_size):
            # Slices taken at in-range starts are never empty.
            self.client.upsert(
                collection_name=self.collection,
                points=points[start : start + self.upsert_batch_size],
            )

    def delete_by_source(self, source: str, project_identifier: Optional[str] = None) -> None:
        """Delete every point whose payload matches ``source`` (and ``project_identifier`` if given)."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier)
        self.client.delete(
            collection_name=self.collection,
            points_selector=self.qmodels.FilterSelector(
                filter=self._to_qdrant_filter(build_filter(query))
            ),
        )

    def delete_documents(self, document_ids: Sequence[str]) -> None:
        """Delete the points for the given document ids; an empty sequence is a no-op delete."""
        self.ensure_collection()
        if not document_ids:
            return
        self.client.delete(
            collection_name=self.collection,
            points_selector=self.qmodels.PointIdsList(
                points=[point_id_for_document(document_id) for document_id in document_ids]
            ),
        )

    def rebuild_source(
        self,
        source: str,
        documents: Sequence[IndexDocument],
        vectors: Sequence[Sequence[float]],
        project_identifier: Optional[str] = None,
    ) -> None:
        """Replace all points of ``source`` (optionally one project) with ``documents``.

        Raises:
            ValueError: If ``documents`` and ``vectors`` differ in length. This
                is checked before anything is deleted.
        """
        # Validate before deleting. Otherwise a length mismatch would first wipe
        # the existing points and only then fail inside upsert, losing the data.
        if len(documents) != len(vectors):
            raise ValueError("documents and vectors length mismatch")
        self.delete_by_source(source, project_identifier=project_identifier)
        self.upsert(documents, vectors)

    def search(self, vector: Sequence[float], query: SearchQuery, limit: int) -> List[SearchResult]:
        """Return up to ``limit`` nearest points to ``vector`` that match ``query``'s filters."""
        self.ensure_collection()
        qfilter = self._to_qdrant_filter(build_filter(query))
        # Prefer query_points when the installed client has it; otherwise fall
        # back to the older search method.
        if hasattr(self.client, "query_points"):
            response = self.client.query_points(
                collection_name=self.collection,
                query=list(vector),
                query_filter=qfilter,
                limit=limit,
                with_payload=True,
            )
            results = response.points
        else:
            results = self.client.search(
                collection_name=self.collection,
                query_vector=list(vector),
                query_filter=qfilter,
                limit=limit,
                with_payload=True,
            )
        return [self._point_to_result(point) for point in results]

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Fetch one document as ``{"id", "text", "payload"}``, or ``None`` if absent."""
        self.ensure_collection()
        points = self.client.retrieve(collection_name=self.collection, ids=[point_id_for_document(document_id)], with_payload=True)
        if not points:
            return None
        _, text, payload = self._split_payload(points[0].payload, document_id)
        return {"id": document_id, "text": text, "payload": payload}

    def count_documents(
        self,
        source: Optional[str] = None,
        project_identifier: Optional[str] = None,
        doc_type: Optional[str] = None,
    ) -> int:
        """Return the exact number of points matching the given (optional) filters."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type)
        result = self.client.count(
            collection_name=self.collection,
            count_filter=self._to_qdrant_filter(build_filter(query)),
            exact=True,
        )
        return int(result.count)

    def list_documents(
        self,
        limit: int = 10,
        source: Optional[str] = None,
        project_identifier: Optional[str] = None,
        doc_type: Optional[str] = None,
        issue_id: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` matching documents as ``{"id", "text", "payload"}`` dicts."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type, issue_id=issue_id)
        qfilter = self._to_qdrant_filter(build_filter(query))
        records = self._scroll(qfilter, limit)
        return [self._record_to_document(record) for record in records]

    def list_projects(self, source: Optional[str] = None, limit: int = 5000) -> List[Dict[str, Any]]:
        """Count documents per ``project_identifier`` among the first ``limit`` matching points.

        Returns a list of ``{"project_identifier", "document_count"}`` dicts
        sorted by project identifier. Points without a truthy
        ``project_identifier`` are ignored.
        """
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=None, doc_type=None, issue_id=None)
        qfilter = self._to_qdrant_filter(build_filter(query))
        # Only the project identifier is needed, so skip transferring document text.
        records = self._scroll(qfilter, limit, with_payload=["project_identifier"])
        identifiers = ((record.payload or {}).get("project_identifier") for record in records)
        counts = Counter(str(project) for project in identifiers if project)
        return [
            {"project_identifier": project, "document_count": count}
            for project, count in sorted(counts.items())
        ]

    def _scroll(self, qfilter: Any, limit: int, with_payload: Any = True) -> List[Any]:
        """Collect up to ``limit`` records matching ``qfilter`` by paging through ``client.scroll``."""
        records: List[Any] = []
        offset = None
        while len(records) < limit:
            remaining = limit - len(records)
            batch, offset = self.client.scroll(
                collection_name=self.collection,
                scroll_filter=qfilter,
                limit=remaining,
                with_payload=with_payload,
                with_vectors=False,
                offset=offset,
            )
            records.extend(batch[:remaining])
            # The returned offset is the next page's point id; None means no
            # more pages. Compare against None: an integer id 0 is falsy but valid.
            if not batch or offset is None:
                break
        return records

    def _to_qdrant_filter(self, raw_filter: Dict[str, List[Dict[str, Any]]]) -> Any:
        """Convert a ``build_filter`` dict into a Qdrant ``Filter``, or ``None`` if it has no conditions."""
        conditions = []
        for condition in raw_filter.get("must", []):
            if "match" in condition:
                conditions.append(
                    self.qmodels.FieldCondition(
                        key=condition["key"],
                        match=self.qmodels.MatchValue(value=condition["match"]["value"]),
                    )
                )
            elif "range" in condition:
                conditions.append(self.qmodels.FieldCondition(key=condition["key"], range=self.qmodels.DatetimeRange(**condition["range"])))
        return self.qmodels.Filter(must=conditions) if conditions else None

    @staticmethod
    def _split_payload(raw_payload: Any, fallback_id: str) -> tuple[str, Any, Dict[str, Any]]:
        """Split a stored payload into ``(document_id, text, remaining_payload)``.

        ``fallback_id`` is used when the payload carries no ``document_id``;
        missing text becomes ``""``. The input payload is copied, not mutated.
        """
        payload = dict(raw_payload or {})
        text = payload.pop("text", "")
        document_id = payload.pop("document_id", fallback_id)
        return document_id, text, payload

    def _point_to_result(self, point: Any) -> SearchResult:
        """Convert a scored Qdrant point into a ``SearchResult``."""
        document_id, text, payload = self._split_payload(point.payload, str(point.id))
        return SearchResult(id=document_id, score=float(point.score), text=text, payload=payload)

    def _record_to_document(self, record: Any) -> Dict[str, Any]:
        """Convert a Qdrant record into an ``{"id", "text", "payload"}`` dict."""
        document_id, text, payload = self._split_payload(record.payload, str(record.id))
        return {"id": document_id, "text": text, "payload": payload}