diff --git a/CHANGES.md b/CHANGES.md index ca911e52a7ad..fa4f6491866d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -82,6 +82,7 @@ * Added plugin mechanism to support different Lineage implementations (Java) ([#36790](https://github.com/apache/beam/issues/36790)). * (Python) Supported Python user type in Beam SQL. For example, SQL statements like `SELECT some_field from PCOLLECTION` can now operate a PCollection of Beam Row containing pickable Python user type ([#20738](https://github.com/apache/beam/issues/20738)). * (Python) Introduced `beam.coders.registry.register_row` as preferred API to register a named tuple or dataclass with a Beam Row. At pipelne runtime, the original type associated with the registered row are preserved across the serialization boundary ([#38108](https://github.com/apache/beam/issues/38108)). +* (Python) Added [Qdrant](https://qdrant.tech/) VectorDatabaseWriteConfig implementation ([#38141](https://github.com/apache/beam/issues/38141)). ## Breaking Changes diff --git a/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py b/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py new file mode 100644 index 000000000000..1f8fdb31c983 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/qdrant.py @@ -0,0 +1,326 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from collections.abc import Callable +from dataclasses import dataclass +from dataclasses import field +from typing import Any +from typing import Optional + +import grpc +from objsize import get_deep_size + +try: + from qdrant_client import QdrantClient + from qdrant_client import models + from qdrant_client.common.client_exceptions import ResourceExhaustedResponse + from qdrant_client.http.exceptions import ResponseHandlingException + from qdrant_client.http.exceptions import UnexpectedResponse +except ImportError: + logging.warning("Qdrant client library is not installed.") + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig +from apache_beam.ml.rag.types import EmbeddableItem + +DEFAULT_WRITE_BATCH_SIZE = 1000 +DEFAULT_MAX_BATCH_BYTE_SIZE = 4 << 20 + + +@dataclass +class QdrantConnectionParameters: + """Configuration parameters for connecting to Qdrant service. + + Either `location`, `url`, `host`, or `path` must be provided to establish + a connection. + + Args: + location: + If `str` - use it as a `url` parameter. + If `None` - use default values for `host` and `port`. + url: either host or str of "//:/". + Default: `None` + port: Port of the REST API interface. Default: 6333 + grpc_port: Port of the gRPC interface. Default: 6334 + prefer_grpc: If `true` - use gPRC interface whenever possible. + https: If `true` - use HTTPS(SSL) protocol. Default: `None` + api_key: API key for authentication in Qdrant Cloud. Default: `None` + prefix: + If not `None` - add `prefix` to the REST URL path. + Example: `service/v1` will result in + `http://localhost:6333/service/v1/{qdrant-endpoint}` for REST API. + Default: `None` + timeout: + Timeout for REST and gRPC API requests. + Default: 5 seconds for REST and unlimited for gRPC + host: + Host name of Qdrant service. + If url and host are None, set to 'localhost'. + Default: `None` + path: Persistence path for QdrantLocal. Default: `None` + **kwargs: Additional arguments passed directly into client initialization + """ + + location: Optional[str] = None + url: Optional[str] = None + port: Optional[int] = 6333 + grpc_port: int = 6334 + prefer_grpc: bool = False + https: Optional[bool] = None + api_key: Optional[str] = None + prefix: Optional[str] = None + timeout: Optional[int] = None + host: Optional[str] = None + path: Optional[str] = None + kwargs: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self): + if not (self.location or self.url or self.host or self.path): + raise ValueError( + "One of location, url, host, or path must be provided for Qdrant") + + @classmethod + def for_cloud( + cls, + url: str, + api_key: str, + *, + prefer_grpc: bool = False, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> "QdrantConnectionParameters": + """Connect to Qdrant Cloud. Requires the cluster URL and an API key.""" + return cls( + url=url, + api_key=api_key, + https=True, + prefer_grpc=prefer_grpc, + timeout=timeout, + kwargs=kwargs, + ) + + @classmethod + def for_host( + cls, + host: str, + port: int = 6333, + *, + grpc_port: int = 6334, + prefer_grpc: bool = False, + https: bool = False, + api_key: Optional[str] = None, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> "QdrantConnectionParameters": + """Connect to a self-hosted Qdrant instance by host and port.""" + return cls( + host=host, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + timeout=timeout, + kwargs=kwargs, + ) + + @classmethod + def for_url( + cls, + url: str, + *, + api_key: Optional[str] = None, + prefer_grpc: bool = False, + timeout: Optional[int] = None, + **kwargs: Any, + ) -> "QdrantConnectionParameters": + """Connect using a full URL like 'https://my-qdrant.example.com:6333'.""" + return cls( + url=url, + api_key=api_key, + prefer_grpc=prefer_grpc, + timeout=timeout, + kwargs=kwargs) + + @classmethod + def local(cls, path: str) -> "QdrantConnectionParameters": + """Use an embedded Qdrant instance persisted to the given path.""" + return cls(path=path) + + @classmethod + def in_memory(cls) -> "QdrantConnectionParameters": + """Use an embedded in-memory Qdrant instance. Useful for tests.""" + return cls(location=":memory:") + + +@dataclass +class QdrantWriteConfig(VectorDatabaseWriteConfig): + """Configuration for writing to Qdrant vector database. + + This class defines the parameters needed to write data to a qdrant collection, + including collection targeting, batching behavior, and operation timeouts. + + Args: + connection_params: QdrantConnectionParameters with connection settings. + collection_name: Name of the Qdrant collection to write to. + timeout: Optional timeout for write operations in seconds. Default is None. + batch_size: Number of points to write in each batch. Default is 1000. + kwargs: Additional keyword arguments to pass to the client's upsert method. + dense_embedding_key: name for the dense vector in the qdrant collection. + sparse_embedding_key: name for the sparse vector in the qdrant collection. + """ + + connection_params: QdrantConnectionParameters + collection_name: str + timeout: Optional[int] = None + batch_size: int = DEFAULT_WRITE_BATCH_SIZE + max_batch_byte_size: int = DEFAULT_MAX_BATCH_BYTE_SIZE + kwargs: dict[str, Any] = field(default_factory=dict) + dense_embedding_key: str = "dense" + sparse_embedding_key: str = "sparse" + + def __post_init__(self): + if not self.collection_name: + raise ValueError("Collection name must be provided") + if self.batch_size <= 0: + raise ValueError("Batch size must be a positive integer") + + def create_write_transform(self) -> beam.PTransform[EmbeddableItem, Any]: + return _QdrantWriteTransform(self) + + def create_converter( + self, + ) -> Callable[[EmbeddableItem], "models.PointStruct"]: + def convert(item: EmbeddableItem) -> "models.PointStruct": + if item.dense_embedding is None and item.sparse_embedding is None: + raise ValueError( + "EmbeddableItem must have at least one embedding (dense or sparse)") + vector = {} + if item.dense_embedding is not None: + vector[self.dense_embedding_key] = item.dense_embedding + if item.sparse_embedding is not None: + sparse_indices, sparse_values = item.sparse_embedding + vector[self.sparse_embedding_key] = models.SparseVector( + indices=sparse_indices, + values=sparse_values, + ) + id = ( + int(item.id) + if isinstance(item.id, str) and item.id.isdigit() else item.id) + return models.PointStruct( + id=id, + vector=vector, + payload=item.metadata if item.metadata else None, + ) + + return convert + + +class _QdrantWriteTransform(beam.PTransform): + def __init__(self, config: QdrantWriteConfig): + self.config = config + + def expand(self, input_or_inputs: beam.PCollection[EmbeddableItem]): + return ( + input_or_inputs + | "Convert to Records" >> beam.Map(self.config.create_converter()) + | beam.ParDo(_QdrantWriteFn(self.config))) + + +class _QdrantWriteFn(beam.DoFn): + def __init__(self, config: QdrantWriteConfig): + self.config = config + self._client: "Optional[QdrantClient]" = None + + def start_bundle(self): + self._batch = [] + self._batch_byte_size = 0 + + def process(self, element, *args, **kwargs): + element_byte_size = get_deep_size(element) + new_batch_byte_size = self._batch_byte_size + element_byte_size + + is_batch_full = len(self._batch) >= self.config.batch_size + is_batch_too_large = new_batch_byte_size > self.config.max_batch_byte_size + if (is_batch_full or is_batch_too_large): + self._flush() + self._batch.append(element) + self._batch_byte_size += element_byte_size + + def setup(self): + params = self.config.connection_params + self._client = QdrantClient( + location=params.location, + url=params.url, + port=params.port, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + https=params.https, + api_key=params.api_key, + prefix=params.prefix, + timeout=params.timeout, + host=params.host, + path=params.path, + check_compatibility=False, + **params.kwargs, + ) + + def teardown(self): + if self._client: + self._client.close() + self._client = None + + def finish_bundle(self): + self._flush() + + def _flush(self): + if not self._batch: + return + if not self._client: + raise RuntimeError("Qdrant client is not initialized") + + max_retries = 3 + attempt = 1 + while True: + try: + self._client.upsert( + collection_name=self.config.collection_name, + points=self._batch, + timeout=self.config.timeout, + **self.config.kwargs, + ) + break + except ResourceExhaustedResponse as e: + time.sleep(e.retry_after_s) + # don't count rate-limit against max_retries + continue + except (UnexpectedResponse, ResponseHandlingException, + grpc.RpcError) as e: + if attempt > max_retries: + raise + time.sleep(2**attempt) + attempt += 1 + self._batch = [] + self._batch_byte_size = 0 + + def display_data(self): + res = super().display_data() + res["collection"] = self.config.collection_name + res["batch_size"] = self.config.batch_size + res["max_batch_byte_size"] = self.config.max_batch_byte_size + return res diff --git a/sdks/python/apache_beam/ml/rag/ingestion/qdrant_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/qdrant_it_test.py new file mode 100644 index 000000000000..ea97ec6638e8 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/qdrant_it_test.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import tempfile +import unittest + +import apache_beam as beam +from apache_beam.ml.rag.ingestion.qdrant import QdrantConnectionParameters +from apache_beam.ml.rag.ingestion.qdrant import QdrantWriteConfig +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import EmbeddableItem +from apache_beam.ml.rag.types import Embedding +from apache_beam.testing.test_pipeline import TestPipeline + +# pylint: disable=ungrouped-imports +try: + from qdrant_client import QdrantClient + from qdrant_client import models + QDRANT_AVAILABLE = True +except ImportError: + QDRANT_AVAILABLE = False +# pylint: enable=ungrouped-imports + +TEST_CORPUS = [ + EmbeddableItem( + id="1", + content=Content(text="Test document one"), + metadata={"source": "test1"}, + embedding=Embedding(dense_embedding=[1.0, 0.0]), + ), + EmbeddableItem( + id="2", + content=Content(text="Test document two"), + metadata={"source": "test2"}, + embedding=Embedding(dense_embedding=[0.0, 1.0]), + ), + EmbeddableItem( + id="3", + content=Content(text="Test document three"), + metadata={"source": "test3"}, + embedding=Embedding(dense_embedding=[-1.0, 0.0]), + ), +] + + +@unittest.skipIf(not QDRANT_AVAILABLE, "qdrant dependencies not installed.") +class TestQdrantIngestion(unittest.TestCase): + @contextlib.contextmanager + def qdrant_client(self) -> 'QdrantClient': + client = QdrantClient(path=self._temp_dir.name) + try: + yield client + finally: + client.close() + + def setUp(self): + self._temp_dir = tempfile.TemporaryDirectory() + self._collection_name = f"test_collection_{self._testMethodName}" + + with self.qdrant_client() as client: + client.create_collection( + collection_name=self._collection_name, + vectors_config={ + "dense": models.VectorParams( + size=2, distance=models.Distance.COSINE) + }, + sparse_vectors_config={"sparse": models.SparseVectorParams()}, + ) + assert client.collection_exists(collection_name=self._collection_name) + + self._connection_params = QdrantConnectionParameters( + path=self._temp_dir.name) + + def tearDown(self): + self._temp_dir.cleanup() + + def test_write_on_non_existent_collection(self): + non_existent = "nonexistent_collection" + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=non_existent, + batch_size=1, + ) + + with self.assertRaises(Exception): + with TestPipeline() as p: + _ = p | beam.Create(TEST_CORPUS) | write_config.create_write_transform() + + def test_write_dense_embeddings_only(self): + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=self._collection_name, + batch_size=len(TEST_CORPUS), + ) + + with TestPipeline() as p: + _ = p | beam.Create(TEST_CORPUS) | write_config.create_write_transform() + + with self.qdrant_client() as client: + count_result = client.count(collection_name=self._collection_name) + self.assertEqual(count_result.count, len(TEST_CORPUS)) + + points, _ = client.scroll( + collection_name=self._collection_name, + limit=100, + with_payload=True, + with_vectors=True, + ) + points_by_id = {p.id: p for p in points} + + for item in TEST_CORPUS: + expected_record = models.Record( + id=int(item.id), + vector={"dense": item.dense_embedding}, + payload=item.metadata, + ) + self.assertEqual(expected_record, points_by_id[int(item.id)]) + + def test_write_sparse_embeddings_only(self): + sparse_corpus = [ + EmbeddableItem( + id="1", + content=Content(text="Sparse doc one"), + metadata={"source": "sparse1"}, + embedding=Embedding(sparse_embedding=([0, 1, 2], [0.1, 0.2, 0.3])), + ), + EmbeddableItem( + id="2", + content=Content(text="Sparse doc two"), + metadata={"source": "sparse2"}, + embedding=Embedding(sparse_embedding=([1, 3, 5], [0.4, 0.5, 0.6])), + ), + ] + + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=self._collection_name, + batch_size=len(sparse_corpus), + ) + + with TestPipeline() as p: + _ = p | beam.Create(sparse_corpus) | write_config.create_write_transform() + + with self.qdrant_client() as client: + count_result = client.count(collection_name=self._collection_name) + self.assertEqual(count_result.count, len(sparse_corpus)) + + points, _ = client.scroll( + collection_name=self._collection_name, + limit=100, + with_payload=True, + with_vectors=True, + ) + points_by_id = {p.id: p for p in points} + + for item in sparse_corpus: + expected_record = models.Record( + id=int(item.id), + vector={ + "sparse": models.SparseVector( + indices=item.sparse_embedding[0], + values=item.sparse_embedding[1], + ) + }, + payload=item.metadata, + ) + self.assertEqual(expected_record, points_by_id[int(item.id)]) + + def test_write_both_dense_and_sparse(self): + hybrid_corpus = [ + EmbeddableItem( + id="1", + content=Content(text="Hybrid doc one"), + metadata={"source": "hybrid1"}, + embedding=Embedding( + dense_embedding=[1.0, 0.0], + sparse_embedding=([0, 1], [0.1, 0.2])), + ), + EmbeddableItem( + id="2", + content=Content(text="Hybrid doc two"), + metadata={"source": "hybrid2"}, + embedding=Embedding( + dense_embedding=[0.0, 1.0], + sparse_embedding=([2, 3], [0.3, 0.4])), + ), + ] + + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=self._collection_name, + batch_size=len(hybrid_corpus), + ) + + with TestPipeline() as p: + _ = p | beam.Create(hybrid_corpus) | write_config.create_write_transform() + + with self.qdrant_client() as client: + count_result = client.count(collection_name=self._collection_name) + self.assertEqual(count_result.count, len(hybrid_corpus)) + + points, _ = client.scroll( + collection_name=self._collection_name, + limit=100, + with_payload=True, + with_vectors=True, + ) + points_by_id = {p.id: p for p in points} + + for item in hybrid_corpus: + expected_record = models.Record( + id=int(item.id), + vector={ + "dense": item.dense_embedding, + "sparse": models.SparseVector( + indices=item.sparse_embedding[0], + values=item.sparse_embedding[1]), + }, + payload=item.metadata, + ) + self.assertEqual(expected_record, points_by_id[int(item.id)]) + + def test_write_with_batching(self): + batch_corpus = [ + EmbeddableItem( + id=str(i), + content=Content(text=f"Batch doc {i}"), + metadata={"batch_id": i}, + embedding=Embedding(dense_embedding=[1.0, 0.0]), + ) for i in range(1, 8) + ] + + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=self._collection_name, + batch_size=3, + ) + + with TestPipeline() as p: + _ = p | beam.Create(batch_corpus) | write_config.create_write_transform() + + with self.qdrant_client() as client: + count_result = client.count(collection_name=self._collection_name) + self.assertEqual(count_result.count, len(batch_corpus)) + + points, _ = client.scroll( + collection_name=self._collection_name, + limit=100, + with_payload=True, + with_vectors=True, + ) + points_by_id = {p.id: p for p in points} + + for item in batch_corpus: + expected_record = models.Record( + id=int(item.id), + vector={ + "dense": item.dense_embedding, + }, + payload=item.metadata, + ) + self.assertEqual(expected_record, points_by_id[int(item.id)]) + + def test_write_with_byte_size_limit(self): + byte_size_corpus = [ + EmbeddableItem( + id=str(i), + content=Content(text=f"Byte size doc {i}"), + metadata={"data": "x" * 9000}, + embedding=Embedding(dense_embedding=[1.0, 0.0]), + ) for i in range(5) + ] + + write_config = QdrantWriteConfig( + connection_params=self._connection_params, + collection_name=self._collection_name, + batch_size=100, + max_batch_byte_size=15_000, + ) + + with TestPipeline() as p: + _ = ( + p + | beam.Create(byte_size_corpus) + | write_config.create_write_transform()) + + with self.qdrant_client() as client: + count_result = client.count(collection_name=self._collection_name) + self.assertEqual(count_result.count, len(byte_size_corpus)) + + points, _ = client.scroll( + collection_name=self._collection_name, + limit=100, + with_payload=True, + with_vectors=True, + ) + points_by_id = {p.id: p for p in points} + + for item in byte_size_corpus: + expected_record = models.Record( + id=int(item.id), + vector={ + "dense": item.dense_embedding, + }, + payload=item.metadata, + ) + self.assertEqual(expected_record, points_by_id[int(item.id)]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/qdrant_test.py b/sdks/python/apache_beam/ml/rag/ingestion/qdrant_test.py new file mode 100644 index 000000000000..ff4ee14e97a0 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/qdrant_test.py @@ -0,0 +1,480 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest import mock + +try: + from qdrant_client import models + from qdrant_client.common.client_exceptions import ResourceExhaustedResponse + from qdrant_client.http.exceptions import ResponseHandlingException + from qdrant_client.http.exceptions import UnexpectedResponse + + QDRANT_AVAILABLE = True +except ImportError: + QDRANT_AVAILABLE = False + +import grpc +from objsize import get_deep_size + +from apache_beam.ml.rag.ingestion.qdrant import QdrantConnectionParameters +from apache_beam.ml.rag.ingestion.qdrant import QdrantWriteConfig +from apache_beam.ml.rag.ingestion.qdrant import _QdrantWriteFn +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import EmbeddableItem +from apache_beam.ml.rag.types import Embedding + + +class TestQdrantConnectionParameters(unittest.TestCase): + def test_no_params_raises_value_error(self): + with self.assertRaises(ValueError): + QdrantConnectionParameters() + + def test_location_is_sufficient(self): + QdrantConnectionParameters(location=":memory:") + + def test_url_is_sufficient(self): + QdrantConnectionParameters(url="http://localhost:6333") + + def test_host_is_sufficient(self): + QdrantConnectionParameters(host="localhost") + + def test_path_is_sufficient(self): + QdrantConnectionParameters(path="/tmp/qdrant") + + +class TestQdrantWriteConfig(unittest.TestCase): + def test_empty_collection_name_raises_value_error(self): + with self.assertRaises(ValueError): + QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="", + ) + + def test_none_collection_name_raises_value_error(self): + with self.assertRaises(ValueError): + QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name=None, + ) + + def test_batch_size_zero_raises_value_error(self): + with self.assertRaises(ValueError): + QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + batch_size=0, + ) + + def test_batch_size_negative_raises_value_error(self): + with self.assertRaises(ValueError): + QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + batch_size=-1, + ) + + def test_display_data(self): + config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + batch_size=100, + max_batch_byte_size=5000, + ) + fn = _QdrantWriteFn(config) + data = fn.display_data() + self.assertEqual(data["collection"], "test") + self.assertEqual(data["batch_size"], 100) + self.assertEqual(data["max_batch_byte_size"], 5000) + + def test_for_cloud_creates_connection(self): + params = QdrantConnectionParameters.for_cloud( + url="https://test.cloud.qdrant.io", + api_key="my-key", + ) + self.assertEqual(params.url, "https://test.cloud.qdrant.io") + self.assertEqual(params.api_key, "my-key") + self.assertTrue(params.https) + + def test_for_host_creates_connection(self): + params = QdrantConnectionParameters.for_host(host="localhost", port=6333) + self.assertEqual(params.host, "localhost") + self.assertEqual(params.port, 6333) + + def test_in_memory_creates_connection(self): + params = QdrantConnectionParameters.in_memory() + self.assertEqual(params.location, ":memory:") + + def test_for_url_creates_connection(self): + params = QdrantConnectionParameters.for_url(url="http://localhost:6333") + self.assertEqual(params.url, "http://localhost:6333") + + def test_kwargs_passthrough(self): + config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + kwargs={"parallel": 4}, + ) + self.assertEqual(config.kwargs, {"parallel": 4}) + + +@unittest.skipIf(not QDRANT_AVAILABLE, "qdrant dependencies not installed.") +class TestQdrantCreateConverter(unittest.TestCase): + def setUp(self): + self.config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + ) + self.convert = self.config.create_converter() + + def test_dense_embedding_only(self): + item = EmbeddableItem( + id="1", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[1.0, 2.0]), + ) + result = self.convert(item) + self.assertIsInstance(result, models.PointStruct) + self.assertEqual(result.id, 1) + self.assertEqual(result.vector, {"dense": [1.0, 2.0]}) + self.assertIsNone(result.payload) + + def test_sparse_embedding_only(self): + item = EmbeddableItem( + id="2", + content=Content(text="test"), + embedding=Embedding(sparse_embedding=([0, 1], [0.5, 0.3])), + ) + result = self.convert(item) + self.assertIsInstance(result, models.PointStruct) + self.assertIn("sparse", result.vector) + sparse_vec = result.vector["sparse"] + self.assertIsInstance(sparse_vec, models.SparseVector) + self.assertEqual(sparse_vec.indices, [0, 1]) + self.assertEqual(sparse_vec.values, [0.5, 0.3]) + + def test_both_dense_and_sparse(self): + item = EmbeddableItem( + id="3", + content=Content(text="test"), + embedding=Embedding( + dense_embedding=[1.0, 2.0], + sparse_embedding=([0], [0.9]), + ), + ) + result = self.convert(item) + self.assertEqual(set(result.vector.keys()), {"dense", "sparse"}) + self.assertEqual(result.vector["dense"], [1.0, 2.0]) + self.assertEqual(result.id, 3) + + def test_raises_when_no_embedding(self): + item = EmbeddableItem( + id="4", + content=Content(text="test"), + ) + with self.assertRaises(ValueError): + self.convert(item) + + def test_string_digit_id_converted_to_int(self): + item = EmbeddableItem( + id="42", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[0.1, 0.2]), + ) + result = self.convert(item) + self.assertEqual(result.id, 42) + self.assertIsInstance(result.id, int) + + def test_non_digit_string_id_preserved(self): + item = EmbeddableItem( + id="abc-123", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[0.1, 0.2]), + ) + result = self.convert(item) + self.assertEqual(result.id, "abc-123") + self.assertIsInstance(result.id, str) + + def test_integer_id_preserved(self): + item = EmbeddableItem( + id="99", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[0.1, 0.2]), + ) + result = self.convert(item) + self.assertEqual(result.id, 99) + self.assertIsInstance(result.id, int) + + def test_none_metadata_becomes_none_payload(self): + item = EmbeddableItem( + id="1", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[0.1, 0.2]), + metadata={}, + ) + result = self.convert(item) + self.assertIsNone(result.payload) + + def test_custom_vector_keys(self): + config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + dense_embedding_key="my_dense", + sparse_embedding_key="my_sparse", + ) + convert = config.create_converter() + item = EmbeddableItem( + id="1", + content=Content(text="test"), + embedding=Embedding( + dense_embedding=[1.0], + sparse_embedding=([0], [0.5]), + ), + ) + result = convert(item) + self.assertIn("my_dense", result.vector) + self.assertIn("my_sparse", result.vector) + self.assertNotIn("dense", result.vector) + self.assertNotIn("sparse", result.vector) + + def test_payload_includes_metadata(self): + item = EmbeddableItem( + id="1", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[1.0]), + metadata={ + "source": "test", "score": 0.95 + }, + ) + result = self.convert(item) + self.assertEqual(result.payload, {"source": "test", "score": 0.95}) + + def test_convert_from_text_factory(self): + item = EmbeddableItem.from_text("hello", metadata={"source": "test"}) + item.embedding = Embedding(dense_embedding=[0.5, 0.5]) + result = self.convert(item) + self.assertIsInstance(result, models.PointStruct) + self.assertIn("dense", result.vector) + + +@unittest.skipIf(not QDRANT_AVAILABLE, "qdrant dependencies not installed.") +class TestQdrantWriteFnBatching(unittest.TestCase): + def setUp(self): + self.config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + batch_size=3, + ) + self.fn = _QdrantWriteFn(self.config) + self.fn._client = mock.MagicMock() + self.fn.start_bundle() + + def test_batch_size_triggers_flush_correctly(self): + client = self.fn._client + for i in range(5): + self.fn.process( + EmbeddableItem( + id=str(i), + content=Content(text="test"), + embedding=Embedding(dense_embedding=[float(i)]), + )) + self.fn.finish_bundle() + + self.assertEqual(client.upsert.call_count, 2) + first = client.upsert.call_args_list[0][1]["points"] + second = client.upsert.call_args_list[1][1]["points"] + self.assertEqual(len(first), 3) + self.assertEqual(len(second), 2) + self.assertEqual(first[0].id, "0") + self.assertEqual(first[1].id, "1") + self.assertEqual(first[2].id, "2") + self.assertEqual(second[0].id, "3") + self.assertEqual(second[1].id, "4") + + def test_partial_batch_flushed_on_finish_bundle(self): + for i in range(2): + self.fn.process( + EmbeddableItem( + id=str(i), + content=Content(text="test"), + embedding=Embedding(dense_embedding=[float(i)]), + )) + self.fn.finish_bundle() + + points = self.fn._client.upsert.call_args[1]["points"] + self.assertEqual(len(points), 2) + + def test_byte_size_exceeded_triggers_flush(self): + item = EmbeddableItem( + id="1", + content=Content( + text="a" * 256, + image=b"x" * 1024, + ), + ) + item_size = get_deep_size(item) + + config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + batch_size=10, + max_batch_byte_size=item_size * 2, + ) + fn = _QdrantWriteFn(config) + fn._client = mock.MagicMock() + fn.start_bundle() + client = fn._client + + for i in range(3): + fn.process( + EmbeddableItem( + id=str(i), + content=Content( + text="a" * 256, + image=b"x" * 1024, + ), + )) + fn.finish_bundle() + + self.assertEqual(client.upsert.call_count, 2) + first = client.upsert.call_args_list[0][1]["points"] + second = client.upsert.call_args_list[1][1]["points"] + self.assertEqual(len(first), 2) + self.assertEqual(len(second), 1) + + +@unittest.skipIf(not QDRANT_AVAILABLE, "qdrant dependencies not installed.") +class TestQdrantWriteFnRetries(unittest.TestCase): + def setUp(self): + self.config = QdrantWriteConfig( + connection_params=QdrantConnectionParameters(location=":memory:"), + collection_name="test", + ) + self.fn = _QdrantWriteFn(self.config) + self.fn._client = mock.MagicMock() + self.fn._batch = [ + EmbeddableItem( + id="1", + content=Content(text="test"), + embedding=Embedding(dense_embedding=[1.0]), + ) + ] + self.fn._batch_byte_size = 100 + + def test_retry_on_unexpected_response(self): + self.fn._client.upsert.side_effect = [ + UnexpectedResponse(429, "error", b"", None), + None, + ] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 2) + mock_sleep.assert_called_once_with(2) + + def test_retry_on_response_handling_exception(self): + self.fn._client.upsert.side_effect = [ + ResponseHandlingException(Exception("error")), + None, + ] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 2) + mock_sleep.assert_called_once_with(2) + + def test_retry_on_grpc_error(self): + self.fn._client.upsert.side_effect = [ + grpc.RpcError("error"), + None, + ] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 2) + mock_sleep.assert_called_once_with(2) + + def test_rate_limit_does_not_increment_attempt(self): + exc = ResourceExhaustedResponse("rate limited", 0) + exc.retry_after_s = 0.01 + self.fn._client.upsert.side_effect = [exc, None] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 2) + mock_sleep.assert_called_once_with(0.01) + + def test_multiple_rate_limits_dont_exhaust_retries(self): + exc = ResourceExhaustedResponse("rate limited", 0) + exc.retry_after_s = 0.01 + self.fn._client.upsert.side_effect = [exc, exc, exc, None] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 4) + self.assertEqual(mock_sleep.call_count, 3) + + def test_rate_limit_then_error_then_success(self): + exc_rate = ResourceExhaustedResponse("rate limited", 0) + exc_rate.retry_after_s = 0.01 + exc_error = UnexpectedResponse(429, "error", b"", None) + self.fn._client.upsert.side_effect = [exc_error, exc_rate, None] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 3) + self.assertEqual(mock_sleep.call_args_list[0], mock.call(2)) + self.assertEqual(mock_sleep.call_args_list[1], mock.call(0.01)) + + def test_exponential_backoff_values(self): + self.fn._client.upsert.side_effect = [ + UnexpectedResponse(429, "e1", b"", None), + UnexpectedResponse(429, "e2", b"", None), + UnexpectedResponse(429, "e3", b"", None), + None, + ] + with mock.patch("time.sleep") as mock_sleep: + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 4) + self.assertEqual(mock_sleep.call_args_list[0], mock.call(2)) + self.assertEqual(mock_sleep.call_args_list[1], mock.call(4)) + self.assertEqual(mock_sleep.call_args_list[2], mock.call(8)) + + def test_raises_after_max_retries(self): + self.fn._client.upsert.side_effect = [ + UnexpectedResponse(429, "e1", b"", None), + UnexpectedResponse(429, "e2", b"", None), + UnexpectedResponse(429, "e3", b"", None), + UnexpectedResponse(429, "e4", b"", None), + ] + with mock.patch("time.sleep") as mock_sleep: + with self.assertRaises(UnexpectedResponse): + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 4) + self.assertEqual(mock_sleep.call_count, 3) + + def test_raises_on_last_non_rate_limit_attempt(self): + exc_rate = ResourceExhaustedResponse("rate limited", 0) + exc_rate.retry_after_s = 0.01 + self.fn._client.upsert.side_effect = [ + exc_rate, + UnexpectedResponse(429, "e1", b"", None), + UnexpectedResponse(429, "e2", b"", None), + UnexpectedResponse(429, "e3", b"", None), + UnexpectedResponse(429, "e4", b"", None), + ] + with mock.patch("time.sleep") as mock_sleep: + with self.assertRaises(UnexpectedResponse): + self.fn._flush() + self.assertEqual(self.fn._client.upsert.call_count, 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 45781a44c4b1..d094dbf2c82b 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -166,6 +166,7 @@ def cythonize(*args, **kwargs): ] milvus_dependency = ['pymilvus>=2.5.10,<3.0.0'] +qdrant_dependency = ['qdrant-client>=1.15.0'] # google-adk / OpenTelemetry require protobuf>=5; tensorflow-transform in # ml_test is pinned to versions that require protobuf<5 on Python 3.10. Those @@ -606,14 +607,14 @@ def get_portability_package_data(): 'tf2onnx>=1.16.1,<1.17', ] + ml_base_core, 'p310_ml_test': [ - 'datatable', - ] + ml_base, + 'datatable', + ] + ml_base + qdrant_dependency, 'p312_ml_test': [ 'datatable', - ] + ml_base, + ] + ml_base + qdrant_dependency, # maintainer: milvus tests only run with this extension. Make sure it # is covered by docker-in-docker test when changing py version - 'p313_ml_test': ml_base + milvus_dependency, + 'p313_ml_test': ml_base + milvus_dependency + qdrant_dependency, 'aws': ['boto3>=1.9,<2'], 'azure': [ 'azure-storage-blob>=12.3.2,<13', @@ -684,6 +685,7 @@ def get_portability_package_data(): 'xgboost': ['xgboost>=1.6.0,<2.1.3', 'datatable==1.0.0'], 'tensorflow-hub': ['tensorflow-hub>=0.14.0,<0.16.0'], 'milvus': milvus_dependency, + 'qdrant': qdrant_dependency, 'vllm': ['openai==1.107.1', 'vllm==0.10.1.1', 'triton==3.3.1'] }, zip_safe=False,