Add semantic-index service, deployment assets, and tests
This commit is contained in:
@@ -0,0 +1,219 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from collections import Counter
|
||||
|
||||
from .models import IndexDocument, SearchQuery, SearchResult
|
||||
|
||||
|
||||
def point_id_for_document(document_id: str) -> str:
    """Return the deterministic point id used to store ``document_id``.

    The id is the string form of a version-5 UUID computed from
    ``document_id`` in the ``uuid.NAMESPACE_URL`` namespace, so the same
    document id always yields the same point id.
    """
    derived = uuid.uuid5(uuid.NAMESPACE_URL, document_id)
    return str(derived)
|
||||
|
||||
|
||||
def build_filter(query: SearchQuery) -> Dict[str, List[Dict[str, Any]]]:
    """Translate ``query`` into a backend-neutral ``{"must": [...]}`` filter.

    Each equality field that is not ``None`` becomes a
    ``{"key": ..., "match": {"value": ...}}`` clause. A truthy ``date_from``
    and/or ``date_to`` becomes a single ``created_on`` range clause using the
    ``gte`` / ``lte`` bounds. The ``must`` list is empty when nothing is set.
    """
    exact_match_fields = (
        "source",
        "project_id",
        "project_identifier",
        "doc_type",
        "issue_id",
        "contact_id",
        "contact_email",
    )
    # Read every attribute once, up front, in a fixed order.
    candidates = [(name, getattr(query, name)) for name in exact_match_fields]
    clauses: List[Dict[str, Any]] = [
        {"key": name, "match": {"value": wanted}}
        for name, wanted in candidates
        if wanted is not None
    ]

    bounds: Dict[str, str] = {}
    if query.date_from:
        bounds["gte"] = query.date_from
    if query.date_to:
        bounds["lte"] = query.date_to
    # A range clause is only emitted when at least one bound is present.
    if bounds:
        clauses.append({"key": "created_on", "range": bounds})

    return {"must": clauses}
|
||||
|
||||
|
||||
class QdrantStore:
    """Adapter that stores and queries index documents in one Qdrant collection.

    Points are keyed by ``point_id_for_document(document.id)``. The original
    document id and text are kept in the payload under ``document_id`` and
    ``text`` and are split back out when documents or results are returned.
    """

    def __init__(self, url: str, api_key: Optional[str], collection: str, vector_size: int = 1536, upsert_batch_size: int = 64) -> None:
        """Create a client for ``url`` bound to ``collection``.

        Args:
            url: Address passed to ``QdrantClient``.
            api_key: API key passed to ``QdrantClient``; may be ``None``.
            collection: Name of the collection every method operates on.
            vector_size: Vector size used when the collection is created.
            upsert_batch_size: Maximum number of points sent per upsert call.

        Raises:
            ValueError: If ``upsert_batch_size`` is smaller than 1.
            RuntimeError: If ``qdrant-client`` is not installed.
        """
        # A negative step would make ``upsert`` silently send nothing and a
        # zero step would fail with an unrelated ``range()`` error, so reject
        # both up front.
        if upsert_batch_size < 1:
            raise ValueError("upsert_batch_size must be at least 1")
        try:
            # Imported lazily so the module can be loaded without qdrant-client.
            from qdrant_client import QdrantClient
            from qdrant_client.http import models as qmodels
        except ImportError as exc:
            raise RuntimeError("Install qdrant-client to use live Qdrant storage") from exc
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection = collection
        self.vector_size = vector_size
        self.upsert_batch_size = upsert_batch_size
        self.qmodels = qmodels

    def ensure_collection(self) -> None:
        """Create the collection (cosine distance) unless it already exists."""
        existing = self.client.get_collections().collections
        if any(item.name == self.collection for item in existing):
            return
        self.client.create_collection(
            collection_name=self.collection,
            vectors_config=self.qmodels.VectorParams(size=self.vector_size, distance=self.qmodels.Distance.COSINE),
        )

    def upsert(self, documents: Sequence[IndexDocument], vectors: Sequence[Sequence[float]]) -> None:
        """Write ``documents`` with their ``vectors`` in batches.

        Raises:
            ValueError: If ``documents`` and ``vectors`` differ in length.
        """
        if len(documents) != len(vectors):
            raise ValueError("documents and vectors length mismatch")
        self.ensure_collection()
        points = [
            self.qmodels.PointStruct(
                id=point_id_for_document(document.id),
                vector=list(vector),
                # ``document_id`` and ``text`` override same-named payload keys.
                payload={**document.payload, "document_id": document.id, "text": document.text},
            )
            for document, vector in zip(documents, vectors)
        ]
        step = self.upsert_batch_size
        for start in range(0, len(points), step):
            self.client.upsert(collection_name=self.collection, points=points[start : start + step])

    def delete_by_source(self, source: str, project_identifier: Optional[str] = None) -> None:
        """Delete every point matching ``source`` and, if given, ``project_identifier``."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier)
        self.client.delete(
            collection_name=self.collection,
            points_selector=self.qmodels.FilterSelector(
                filter=self._to_qdrant_filter(build_filter(query))
            ),
        )

    def delete_documents(self, document_ids: Sequence[str]) -> None:
        """Delete the points belonging to ``document_ids``; no-op when empty."""
        # Return before touching the server: an empty delete needs no round trip
        # and should not create the collection as a side effect.
        if not document_ids:
            return
        self.ensure_collection()
        self.client.delete(
            collection_name=self.collection,
            points_selector=self.qmodels.PointIdsList(
                points=[point_id_for_document(document_id) for document_id in document_ids]
            ),
        )

    def rebuild_source(
        self,
        source: str,
        documents: Sequence[IndexDocument],
        vectors: Sequence[Sequence[float]],
        project_identifier: Optional[str] = None,
    ) -> None:
        """Replace the points of ``source`` with ``documents``.

        Existing points matching ``source`` / ``project_identifier`` are
        deleted, then ``documents`` are upserted.

        Raises:
            ValueError: If ``documents`` and ``vectors`` differ in length. This
                is checked before anything is deleted.
        """
        # Validate before the destructive step; otherwise a mismatch would wipe
        # the source and then fail without writing the replacement.
        if len(documents) != len(vectors):
            raise ValueError("documents and vectors length mismatch")
        self.delete_by_source(source, project_identifier=project_identifier)
        self.upsert(documents, vectors)

    def search(self, vector: Sequence[float], query: SearchQuery, limit: int) -> List[SearchResult]:
        """Return up to ``limit`` results for ``vector`` filtered by ``query``."""
        self.ensure_collection()
        qfilter = self._to_qdrant_filter(build_filter(query))
        # Use ``query_points`` when the client offers it, else ``search``.
        if hasattr(self.client, "query_points"):
            response = self.client.query_points(
                collection_name=self.collection,
                query=list(vector),
                query_filter=qfilter,
                limit=limit,
                with_payload=True,
            )
            hits = response.points
        else:
            hits = self.client.search(
                collection_name=self.collection,
                query_vector=list(vector),
                query_filter=qfilter,
                limit=limit,
                with_payload=True,
            )
        return [self._point_to_result(point) for point in hits]

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Return ``{"id", "text", "payload"}`` for ``document_id`` or ``None``."""
        self.ensure_collection()
        points = self.client.retrieve(collection_name=self.collection, ids=[point_id_for_document(document_id)], with_payload=True)
        if not points:
            return None
        document = self._payload_to_document(points[0].payload, document_id)
        # The caller's id is reported even if the payload carries another one.
        document["id"] = document_id
        return document

    def count_documents(
        self,
        source: Optional[str] = None,
        project_identifier: Optional[str] = None,
        doc_type: Optional[str] = None,
    ) -> int:
        """Return the exact number of points matching the given filters."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type)
        result = self.client.count(
            collection_name=self.collection,
            count_filter=self._to_qdrant_filter(build_filter(query)),
            exact=True,
        )
        return int(result.count)

    def list_documents(
        self,
        limit: int = 10,
        source: Optional[str] = None,
        project_identifier: Optional[str] = None,
        doc_type: Optional[str] = None,
        issue_id: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return up to ``limit`` matching documents as ``{"id", "text", "payload"}`` dicts."""
        self.ensure_collection()
        query = SearchQuery(text="*", source=source, project_identifier=project_identifier, doc_type=doc_type, issue_id=issue_id)
        qfilter = self._to_qdrant_filter(build_filter(query))
        records: List[Any] = []
        offset = None
        while len(records) < limit:
            remaining = limit - len(records)
            batch, offset = self.client.scroll(
                collection_name=self.collection,
                scroll_filter=qfilter,
                limit=remaining,
                with_payload=True,
                with_vectors=False,
                offset=offset,
            )
            records.extend(batch[:remaining])
            # Compare against None explicitly: a falsy next offset such as the
            # integer point id 0 is still a valid continuation token.
            if offset is None or not batch:
                break
        return [self._record_to_document(record) for record in records]

    def list_projects(self, source: Optional[str] = None, limit: int = 5000) -> List[Dict[str, Any]]:
        """Return per-project document counts, sorted by project identifier.

        Counts are computed from at most ``limit`` documents returned by
        ``list_documents``; documents without a truthy ``project_identifier``
        in their payload are ignored.
        """
        counts: Counter = Counter()
        for document in self.list_documents(limit=limit, source=source):
            project = (document.get("payload") or {}).get("project_identifier")
            if project:
                counts[str(project)] += 1
        return [
            {"project_identifier": project, "document_count": count}
            for project, count in sorted(counts.items())
        ]

    def _to_qdrant_filter(self, raw_filter: Dict[str, List[Dict[str, Any]]]) -> Any:
        """Convert a ``build_filter`` dict into a Qdrant ``Filter``.

        Returns ``None`` when there are no conditions. Clauses that carry
        neither ``match`` nor ``range`` are skipped.
        """
        conditions = []
        for condition in raw_filter.get("must", []):
            if "match" in condition:
                conditions.append(
                    self.qmodels.FieldCondition(
                        key=condition["key"],
                        match=self.qmodels.MatchValue(value=condition["match"]["value"]),
                    )
                )
            elif "range" in condition:
                conditions.append(
                    self.qmodels.FieldCondition(
                        key=condition["key"],
                        range=self.qmodels.DatetimeRange(**condition["range"]),
                    )
                )
        return self.qmodels.Filter(must=conditions) if conditions else None

    def _point_to_result(self, point: Any) -> SearchResult:
        """Build a ``SearchResult`` from a scored point."""
        document = self._payload_to_document(point.payload, str(point.id))
        return SearchResult(id=document["id"], score=float(point.score), text=document["text"], payload=document["payload"])

    def _record_to_document(self, record: Any) -> Dict[str, Any]:
        """Build an ``{"id", "text", "payload"}`` dict from a scrolled record."""
        return self._payload_to_document(record.payload, str(record.id))

    @staticmethod
    def _payload_to_document(raw_payload: Optional[Dict[str, Any]], fallback_id: Any) -> Dict[str, Any]:
        """Split a stored payload into id, text and the remaining payload.

        ``text`` defaults to ``""`` and the id falls back to ``fallback_id``
        when the payload has no ``document_id``. The input is not mutated.
        """
        payload = dict(raw_payload or {})
        text = payload.pop("text", "")
        document_id = payload.pop("document_id", fallback_id)
        return {"id": document_id, "text": text, "payload": payload}
|
||||
Reference in New Issue
Block a user