Add semantic-index service, deployment assets, and tests
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable, List, Optional, Protocol, Sequence
|
||||
|
||||
from .models import IndexDocument
|
||||
|
||||
|
||||
class EmbeddingClient(Protocol):
|
||||
def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]:
|
||||
...
|
||||
|
||||
|
||||
class OpenAIEmbeddingClient:
|
||||
def __init__(self, api_key: Optional[str] = None) -> None:
|
||||
try:
|
||||
from openai import OpenAI
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("Install openai to use live embeddings") from exc
|
||||
self.client = OpenAI(api_key=api_key)
|
||||
|
||||
def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]:
|
||||
kwargs = {"model": model, "input": list(inputs)}
|
||||
if dimensions is not None:
|
||||
kwargs["dimensions"] = dimensions
|
||||
response = self.client.embeddings.create(**kwargs)
|
||||
return [item.embedding for item in response.data]
|
||||
|
||||
|
||||
class OpenAIEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
client: EmbeddingClient,
|
||||
model: str = "text-embedding-3-small",
|
||||
dimensions: int = 1536,
|
||||
batch_size: int = 100,
|
||||
max_chars: int = 12000,
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.model = model
|
||||
self.dimensions = dimensions
|
||||
self.batch_size = batch_size
|
||||
self.max_chars = max_chars
|
||||
|
||||
def embed_documents(self, documents: Sequence[IndexDocument]) -> List[List[float]]:
|
||||
return self.embed_texts([document.text for document in documents])
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self.embed_texts([text])[0]
|
||||
|
||||
def embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
|
||||
values = list(texts)
|
||||
self._validate(values)
|
||||
vectors: List[List[float]] = []
|
||||
for start in range(0, len(values), self.batch_size):
|
||||
batch = values[start : start + self.batch_size]
|
||||
vectors.extend(self.client.create_embeddings(self.model, batch, dimensions=self.dimensions))
|
||||
return vectors
|
||||
|
||||
def _validate(self, texts: Sequence[str]) -> None:
|
||||
for text in texts:
|
||||
if not text.strip():
|
||||
raise ValueError("embedding text cannot be empty")
|
||||
if len(text) > self.max_chars:
|
||||
raise ValueError(f"embedding text exceeds {self.max_chars} characters")
|
||||
Reference in New Issue
Block a user