from __future__ import annotations from typing import Iterable, List, Optional, Protocol, Sequence from .models import IndexDocument class EmbeddingClient(Protocol): def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]: ... class OpenAIEmbeddingClient: def __init__(self, api_key: Optional[str] = None) -> None: try: from openai import OpenAI except ImportError as exc: raise RuntimeError("Install openai to use live embeddings") from exc self.client = OpenAI(api_key=api_key) def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]: kwargs = {"model": model, "input": list(inputs)} if dimensions is not None: kwargs["dimensions"] = dimensions response = self.client.embeddings.create(**kwargs) return [item.embedding for item in response.data] class OpenAIEmbedder: def __init__( self, client: EmbeddingClient, model: str = "text-embedding-3-small", dimensions: int = 1536, batch_size: int = 100, max_chars: int = 12000, ) -> None: self.client = client self.model = model self.dimensions = dimensions self.batch_size = batch_size self.max_chars = max_chars def embed_documents(self, documents: Sequence[IndexDocument]) -> List[List[float]]: return self.embed_texts([document.text for document in documents]) def embed_query(self, text: str) -> List[float]: return self.embed_texts([text])[0] def embed_texts(self, texts: Iterable[str]) -> List[List[float]]: values = list(texts) self._validate(values) vectors: List[List[float]] = [] for start in range(0, len(values), self.batch_size): batch = values[start : start + self.batch_size] vectors.extend(self.client.create_embeddings(self.model, batch, dimensions=self.dimensions)) return vectors def _validate(self, texts: Sequence[str]) -> None: for text in texts: if not text.strip(): raise ValueError("embedding text cannot be empty") if len(text) > self.max_chars: raise ValueError(f"embedding text exceeds {self.max_chars} characters")