62 lines
2.2 KiB
Python
62 lines
2.2 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Protocol
|
|
|
|
from .models import SearchQuery, SearchResult
|
|
|
|
|
|
class QueryEmbedder(Protocol):
    """Structural interface for objects that turn query text into a vector."""

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of ``text`` as a list of floats."""
|
|
|
|
|
|
class SearchStore(Protocol):
    """Structural interface for the backend that serves search and lookups."""

    def search(
        self,
        vector: List[float],
        query: SearchQuery,
        limit: int,
    ) -> List[SearchResult]:
        """Return results for ``vector`` and ``query``, given a ``limit``."""

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Return the document stored under ``document_id``, or ``None``."""
|
|
|
|
|
|
class HybridSearchService:
    """Search service that re-ranks store results with a keyword boost.

    The query text is embedded, the store is searched with that vector, and
    every returned candidate has ``keyword_boost`` added to its score before
    the list is re-sorted.
    """

    def __init__(self, embedder: QueryEmbedder, store: SearchStore) -> None:
        """Keep references to the embedder and the store for later calls."""
        self.embedder = embedder
        self.store = store

    def search(self, query: SearchQuery) -> List[SearchResult]:
        """Return rescored results for ``query``, best score first.

        The store is called with ``limit=query.limit``; the rescored list is
        truncated to the same limit after sorting.
        """
        query_vector = self.embedder.embed_query(query.text)
        limit = query.limit
        hits = self.store.search(query_vector, query, limit=limit)

        boosted: List[SearchResult] = []
        for hit in hits:
            # Build a new result rather than mutating the store's object.
            boosted.append(
                SearchResult(
                    id=hit.id,
                    score=hit.score + keyword_boost(query.text, hit),
                    text=hit.text,
                    payload=hit.payload,
                )
            )

        # Stable descending sort: ties keep the order the store returned.
        boosted.sort(key=lambda item: item.score, reverse=True)
        return boosted[:limit]

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """Delegate the lookup of ``document_id`` to the store."""
        return self.store.get_document(document_id)
|
|
|
|
|
|
# Patterns are compiled once at import time because keyword_boost runs once
# per search candidate.
_PHRASE_RE = re.compile(r'"([^"]+)"')
_EMAIL_RE = re.compile(r"[\w.+-]+@[\w.-]+\.[A-Za-z]{2,}")
_IDENTIFIER_RE = re.compile(
    r"\b(?:#?\d{2,}|[A-Z]{2,}[-_]\d{2,}|[A-Z0-9]{4,}-[A-Z0-9-]{2,})\b"
)
_WORD_RE = re.compile(r"\b[A-Za-z][\w.-]{2,}\b")

# (pattern, boost per hit, strip leading '#' before matching).
# Order matters only for float accumulation; it mirrors strongest-first.
_BOOST_RULES = (
    (_PHRASE_RE, 0.35, False),
    (_EMAIL_RE, 0.3, False),
    (_IDENTIFIER_RE, 0.25, True),
    (_WORD_RE, 0.03, False),
)


def keyword_boost(query_text: str, result: SearchResult) -> float:
    """Return the score bonus for literal query terms found in ``result``.

    The haystack is the lowercased ``result.text`` joined with the string
    form of every non-``None`` value in ``result.payload``. Each term
    extracted from ``query_text`` that occurs in the haystack as a
    case-insensitive substring adds to the bonus:

    * double-quoted phrase: 0.35
    * e-mail address: 0.3
    * numeric / ticket-style identifier (``1234``, ``INV-2024``): 0.25
    * plain word of three or more characters: 0.03

    A term that appears several times in the query is counted each time, and
    one span of text may be counted by more than one rule.

    Args:
        query_text: Raw query string as typed by the user.
        result: Candidate exposing ``text`` and ``payload``; either may be
            ``None``, in which case it contributes nothing to the haystack.

    Returns:
        The accumulated boost, ``0.0`` when nothing matches.
    """
    # Tolerate a missing text or payload instead of raising.
    text = result.text or ""
    payload = result.payload or {}
    payload_text = " ".join(
        str(value) for value in payload.values() if value is not None
    )
    haystack = " ".join([text, payload_text]).lower()

    boost = 0.0
    for pattern, weight, strip_hash in _BOOST_RULES:
        for term in pattern.findall(query_text):
            needle = term.lower()
            if strip_hash:
                # The stripped form is a suffix of the raw form, so checking
                # it alone also covers a haystack containing the '#'.
                needle = needle.lstrip("#")
            if needle in haystack:
                boost += weight
    return boost
|