65 lines
2.3 KiB
Python
65 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Iterable, List, Optional, Protocol, Sequence
|
|
|
|
from .models import IndexDocument
|
|
|
|
|
|
class EmbeddingClient(Protocol):
    """Structural interface for backends that turn texts into embedding vectors.

    Any object exposing a compatible ``create_embeddings`` method satisfies
    this protocol; subclassing is not required.
    """

    def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]:
        """Embed ``inputs`` with ``model`` and return the resulting vectors.

        Callers in this module expect one vector per input, in input order.
        ``dimensions`` is an optional requested vector size that is forwarded
        to the backend; ``None`` means "do not request a specific size".
        """
|
|
|
|
|
|
class OpenAIEmbeddingClient:
    """``EmbeddingClient`` implementation backed by the ``openai`` SDK.

    The SDK is imported inside ``__init__`` so that the module can be imported
    (and other clients used) without ``openai`` installed.
    """

    def __init__(self, api_key: Optional[str] = None) -> None:
        """Construct the underlying ``OpenAI`` SDK client.

        Args:
            api_key: Key passed straight to ``OpenAI(api_key=...)``. With
                ``None`` the SDK applies its own default lookup (presumably the
                ``OPENAI_API_KEY`` environment variable; confirm against the
                SDK version in use).

        Raises:
            RuntimeError: if the ``openai`` package cannot be imported.
        """
        try:
            from openai import OpenAI as _OpenAI
        except ImportError as err:
            raise RuntimeError("Install openai to use live embeddings") from err
        self.client = _OpenAI(api_key=api_key)

    def create_embeddings(self, model: str, inputs: Sequence[str], dimensions: Optional[int] = None) -> List[List[float]]:
        """Call ``client.embeddings.create`` and return the ``embedding`` of each result item.

        Args:
            model: Model name sent as the ``model`` request field.
            inputs: Texts sent as the ``input`` request field (copied to a list).
            dimensions: Sent as ``dimensions`` only when not ``None``.

        Returns:
            The ``embedding`` attribute of every item in ``response.data``, in
            the order the response lists them.
        """
        request = {"model": model, "input": list(inputs)}
        if dimensions is not None:
            # Only include the field when requested, so the API default applies otherwise.
            request.update(dimensions=dimensions)
        result = self.client.embeddings.create(**request)
        vectors = [entry.embedding for entry in result.data]
        return vectors
|
|
|
|
|
|
class OpenAIEmbedder:
    """Turn documents and queries into embedding vectors through an ``EmbeddingClient``.

    Every text is validated up front (non-blank, at most ``max_chars``
    characters). The texts are then sent to the client in batches of at most
    ``batch_size`` items, and the returned vectors are concatenated in input
    order.
    """

    def __init__(
        self,
        client: EmbeddingClient,
        model: str = "text-embedding-3-small",
        dimensions: Optional[int] = 1536,
        batch_size: int = 100,
        max_chars: int = 12000,
    ) -> None:
        """Configure the embedder.

        Args:
            client: Backend that performs the embedding calls.
            model: Model name forwarded to ``client.create_embeddings``.
            dimensions: Requested vector size forwarded to the client. Pass
                ``None`` to leave the parameter out of the request entirely.
            batch_size: Maximum number of texts per client call; must be >= 1.
            max_chars: Maximum allowed length of a single text, in characters.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            # A zero step makes range() raise on every call, and a negative step
            # silently yields no batches (returning no vectors). Reject both here.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.client = client
        self.model = model
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.max_chars = max_chars

    def embed_documents(self, documents: Sequence[IndexDocument]) -> List[List[float]]:
        """Embed the ``text`` attribute of each document; one vector per document, in order."""
        return self.embed_texts([document.text for document in documents])

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector.

        Raises:
            ValueError: if ``text`` is blank or longer than ``max_chars``.
        """
        return self.embed_texts([text])[0]

    def embed_texts(self, texts: Iterable[str]) -> List[List[float]]:
        """Validate ``texts`` and embed them in batches of at most ``batch_size``.

        Returns:
            One vector per input text, in input order. An empty input returns
            ``[]`` without calling the client.

        Raises:
            ValueError: if any text is blank or longer than ``max_chars``.
                Validation runs before any client call.
            RuntimeError: if the client returns a different number of vectors
                than texts it was given for a batch.
        """
        values = list(texts)
        self._validate(values)
        vectors: List[List[float]] = []
        for start in range(0, len(values), self.batch_size):
            batch = values[start : start + self.batch_size]
            batch_vectors = list(
                self.client.create_embeddings(self.model, batch, dimensions=self.dimensions)
            )
            # Callers pair vectors with inputs positionally, so a short or long
            # response would silently misalign them. Fail loudly instead.
            if len(batch_vectors) != len(batch):
                raise RuntimeError(
                    f"embedding client returned {len(batch_vectors)} vectors for {len(batch)} inputs"
                )
            vectors.extend(batch_vectors)
        return vectors

    def _validate(self, texts: Sequence[str]) -> None:
        """Raise ``ValueError`` for the first blank (whitespace-only) or over-long text."""
        for text in texts:
            if not text.strip():
                raise ValueError("embedding text cannot be empty")
            if len(text) > self.max_chars:
                raise ValueError(f"embedding text exceeds {self.max_chars} characters")
|