47 lines
1.6 KiB
Python
47 lines
1.6 KiB
Python
import unittest
|
|
|
|
from semantic_index.embeddings import OpenAIEmbedder
|
|
from semantic_index.models import IndexDocument
|
|
|
|
|
|
class FakeOpenAIClient:
    """In-memory stand-in for the client passed to ``OpenAIEmbedder``.

    Every ``create_embeddings`` call is recorded in ``self.calls`` and answered
    with deterministic vectors, so tests can inspect the requests without any
    network access.
    """

    def __init__(self):
        # One dict per create_embeddings call, in call order.
        self.calls = []

    def create_embeddings(self, model, inputs, dimensions=None):
        """Record the request and return one fake vector per input text.

        Args:
            model: Model name sent by the embedder.
            inputs: Iterable of texts for this request. It is materialized
                once, so one-shot iterators such as generators work correctly.
            dimensions: Requested output dimensionality. It is only recorded;
                the returned vectors always have length 3.

        Returns:
            A list with one vector per input. The i-th input (counting from 1)
            gets ``[float(i)] * 3``, and numbering restarts at 1 on every call.
        """
        # Materialize once so a generator is not used up by the recording
        # step before the returned vectors are built from it.
        texts = list(inputs)
        self.calls.append({"model": model, "inputs": texts, "dimensions": dimensions})
        return [[float(i)] * 3 for i in range(1, len(texts) + 1)]
|
|
|
|
|
|
class OpenAIEmbedderTest(unittest.TestCase):
    """Unit tests for ``OpenAIEmbedder`` driven through ``FakeOpenAIClient``."""

    def test_batches_embedding_requests(self):
        """Documents go out in batches of ``batch_size``; vectors come back in order."""
        client = FakeOpenAIClient()
        embedder = OpenAIEmbedder(client=client, batch_size=2, dimensions=1536)
        docs = [
            IndexDocument(id="a", text="alpha", payload={}),
            IndexDocument(id="b", text="bravo", payload={}),
            IndexDocument(id="c", text="charlie", payload={}),
        ]

        vectors = embedder.embed_documents(docs)

        # The fake numbers vectors from 1 within each call, so the third
        # vector restarting at 1.0 shows it came from a second request.
        self.assertEqual([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]], vectors)
        self.assertEqual(2, len(client.calls))
        self.assertEqual(["alpha", "bravo"], client.calls[0]["inputs"])
        # The remainder batch must carry exactly the leftover document.
        self.assertEqual(["charlie"], client.calls[1]["inputs"])
        # Model and dimensions must be sent on every request, not only the first.
        for index, call in enumerate(client.calls):
            with self.subTest(call=index):
                self.assertEqual("text-embedding-3-small", call["model"])
                self.assertEqual(1536, call["dimensions"])

    def test_rejects_empty_or_oversized_chunks_before_api_call(self):
        """Invalid texts raise ValueError and nothing is sent to the client."""
        client = FakeOpenAIClient()
        embedder = OpenAIEmbedder(client=client, max_chars=5)

        # A whitespace-only entry causes a ValueError even though "ok" is valid;
        # the empty-calls check below shows "ok" was not sent either.
        with self.assertRaises(ValueError):
            embedder.embed_texts(["ok", " "])
        # "toolong" is 7 characters, which exceeds max_chars=5.
        with self.assertRaises(ValueError):
            embedder.embed_texts(["toolong"])

        self.assertEqual([], client.calls)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution of this file: discover and run the TestCases defined here.
    unittest.main(module="__main__")
|