Files
2026-05-04 09:50:03 -04:00

65 lines
2.3 KiB
Python

from __future__ import annotations
from typing import Iterable, List, Optional, Protocol, Sequence
from .models import IndexDocument
class EmbeddingClient(Protocol):
    """Structural interface for objects that turn texts into embedding vectors.

    Any object with a matching ``create_embeddings`` method satisfies this
    protocol; no inheritance is required.
    """

    def create_embeddings(
        self,
        model: str,
        inputs: Sequence[str],
        dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        """Return embedding vectors for ``inputs``.

        Args:
            model: Name of the embedding model to use.
            inputs: Texts to embed.
            dimensions: Optional requested vector length.

        Returns:
            One vector per input. Callers in this module expect them to be
            positionally aligned with ``inputs``.
        """
        ...
class OpenAIEmbeddingClient:
    """``EmbeddingClient`` implementation backed by the ``openai`` SDK."""

    def __init__(self, api_key: Optional[str] = None) -> None:
        """Create the underlying ``openai.OpenAI`` client.

        Args:
            api_key: Passed straight to ``OpenAI(api_key=...)``. When ``None``,
                the SDK resolves the key itself (presumably from the
                environment; see the SDK documentation).

        Raises:
            RuntimeError: If the ``openai`` package is not installed.
        """
        try:
            # Imported lazily so this module can be loaded without openai installed.
            from openai import OpenAI
        except ImportError as exc:
            raise RuntimeError("Install openai to use live embeddings") from exc
        self.client = OpenAI(api_key=api_key)

    def create_embeddings(
        self,
        model: str,
        inputs: Sequence[str],
        dimensions: Optional[int] = None,
    ) -> List[List[float]]:
        """Embed ``inputs`` with one call to the OpenAI embeddings endpoint.

        Args:
            model: Embedding model name.
            inputs: Texts to embed. An empty sequence returns ``[]`` without
                contacting the API.
            dimensions: Requested vector length; omitted from the request when
                ``None``.

        Returns:
            One vector per input, in the same order as ``inputs``.
        """
        payload = list(inputs)
        if not payload:
            # Nothing to embed: skip a network round-trip the API would reject.
            return []
        kwargs = {"model": model, "input": payload}
        if dimensions is not None:
            kwargs["dimensions"] = dimensions
        response = self.client.embeddings.create(**kwargs)
        # Each returned item carries the index of the input it belongs to;
        # order by it explicitly instead of trusting the response order.
        ordered = sorted(response.data, key=lambda item: item.index)
        return [item.embedding for item in ordered]
class OpenAIEmbedder:
    """Turn texts and documents into embedding vectors via an ``EmbeddingClient``.

    Texts are validated up front (non-blank, at most ``max_chars`` characters)
    and then sent to the client in batches of at most ``batch_size`` items.
    Returned vectors are positionally aligned with the input texts.
    """

    def __init__(
        self,
        client: EmbeddingClient,
        model: str = "text-embedding-3-small",
        dimensions: Optional[int] = 1536,
        batch_size: int = 100,
        max_chars: int = 12000,
    ) -> None:
        """Configure the embedder.

        Args:
            client: Object implementing ``create_embeddings``.
            model: Embedding model name forwarded to the client.
            dimensions: Vector length forwarded to the client; pass ``None`` to
                leave it unspecified.
            batch_size: Maximum number of texts per client call; must be >= 1.
            max_chars: Maximum allowed length of a single text; must be >= 1.

        Raises:
            ValueError: If ``batch_size`` or ``max_chars`` is not positive.
        """
        # A non-positive batch size would make range() in embed_texts either
        # fail late (0) or silently produce no batches and no vectors (< 0).
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if max_chars < 1:
            raise ValueError(f"max_chars must be a positive integer, got {max_chars}")
        self.client = client
        self.model = model
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.max_chars = max_chars

    def embed_documents(self, documents: Sequence[IndexDocument]) -> List[List[float]]:
        """Embed the ``text`` of each document, returning vectors in document order."""
        return self.embed_texts([document.text for document in documents])

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector.

        Raises:
            ValueError: If ``text`` is blank or longer than ``max_chars``.
        """
        return self.embed_texts([text])[0]

    def embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
        """Embed ``texts`` in batches and return one vector per text, in order.

        All texts are validated before any client call is made.

        Raises:
            ValueError: If any text is blank or longer than ``max_chars``.
            RuntimeError: If the client returns a different number of vectors
                than texts it was given for a batch.
        """
        values = list(texts)
        self._validate(values)
        vectors: List[List[float]] = []
        for start in range(0, len(values), self.batch_size):
            batch = values[start : start + self.batch_size]
            batch_vectors = self.client.create_embeddings(self.model, batch, dimensions=self.dimensions)
            # Callers rely on positional text->vector alignment; a short or long
            # reply would silently shift every later vector.
            if len(batch_vectors) != len(batch):
                raise RuntimeError(
                    f"embedding client returned {len(batch_vectors)} vectors "
                    f"for {len(batch)} inputs"
                )
            vectors.extend(batch_vectors)
        return vectors

    def _validate(self, texts: Sequence[str]) -> None:
        """Raise ``ValueError`` for any blank text or text over ``max_chars``."""
        for text in texts:
            if not text.strip():
                raise ValueError("embedding text cannot be empty")
            if len(text) > self.max_chars:
                raise ValueError(f"embedding text exceeds {self.max_chars} characters")