|  | 
| 34 | 34 | 
 | 
| 35 | 35 | from opensearchpy import InnerDoc, Range, ValidationException | 
| 36 | 36 | from opensearchpy.helpers import field | 
|  | 37 | +from opensearchpy.helpers.index import Index | 
|  | 38 | +from opensearchpy.helpers.mapping import Mapping | 
|  | 39 | +from opensearchpy.helpers.test import OpenSearchTestCase | 
| 37 | 40 | 
 | 
| 38 | 41 | 
 | 
| 39 | 42 | def test_date_range_deserialization() -> None: | 
| @@ -221,3 +224,114 @@ class Inner(InnerDoc): | 
| 221 | 224 | 
 | 
| 222 | 225 |     with pytest.raises(ValidationException): | 
| 223 | 226 |         field.Object(doc_class=Inner, dynamic=False) | 
|  | 227 | + | 
|  | 228 | + | 
def test_knn_vector() -> None:
    """Check the basic ``KnnVector`` mapping, its required argument and ``_multi``."""
    knn = field.KnnVector(dimension=128)
    mapping = knn.to_dict()
    assert mapping == {"type": "knn_vector", "dimension": 128}

    # Omitting ``dimension`` is expected to be rejected with a TypeError.
    with pytest.raises(TypeError):
        field.KnnVector()  # type: ignore

    assert knn._multi is True
|  | 238 | + | 
|  | 239 | + | 
def test_knn_vector_with_additional_params() -> None:
    """Extra keyword arguments such as ``method`` must show up in ``to_dict()``."""
    want = {
        "type": "knn_vector",
        "dimension": 256,
        "method": {"name": "hnsw", "space_type": "l2", "engine": "faiss"},
    }

    knn = field.KnnVector(
        dimension=256,
        method={"name": "hnsw", "space_type": "l2", "engine": "faiss"},
    )

    assert knn.to_dict() == want
|  | 250 | + | 
|  | 251 | + | 
def test_knn_vector_serialization() -> None:
    """``serialize`` returns an equal vector and maps ``None`` to ``None``."""
    knn = field.KnnVector(dimension=3)
    vector = [1.0, 2.0, 3.0]

    assert knn.serialize(vector) == vector
    assert knn.serialize(None) is None
|  | 260 | + | 
|  | 261 | + | 
def test_knn_vector_deserialization() -> None:
    """``deserialize`` returns an equal vector and maps ``None`` to ``None``."""
    knn = field.KnnVector(dimension=3)
    vector = [1.0, 2.0, 3.0]

    assert knn.deserialize(vector) == vector
    assert knn.deserialize(None) is None
|  | 270 | + | 
|  | 271 | + | 
def test_knn_vector_construct_from_dict() -> None:
    """``construct_field`` builds a ``KnnVector`` from a plain mapping dict."""
    built = field.construct_field({"type": "knn_vector", "dimension": 128})

    assert isinstance(built, field.KnnVector)
    # The field must serialize back to the definition it was built from.
    round_tripped = built.to_dict()
    assert round_tripped == {"type": "knn_vector", "dimension": 128}
|  | 277 | + | 
|  | 278 | + | 
def test_knn_vector_construct_from_dict_with_method() -> None:
    """``construct_field`` keeps the ``method`` entry when building a ``KnnVector``."""
    want = {
        "type": "knn_vector",
        "dimension": 256,
        "method": {"name": "hnsw", "space_type": "cosinesimil", "engine": "lucene"},
    }

    definition = {
        "type": "knn_vector",
        "dimension": 256,
        "method": {"name": "hnsw", "space_type": "cosinesimil", "engine": "lucene"},
    }
    built = field.construct_field(definition)

    assert isinstance(built, field.KnnVector)
    assert built.to_dict() == want
|  | 295 | + | 
|  | 296 | + | 
class TestKnnVectorIntegration(OpenSearchTestCase):
    """Round-trip a ``KnnVector`` field through the client on ``self.client``.

    NOTE(review): presumably needs a reachable cluster that accepts the
    ``index.knn`` setting and ``knn`` queries -- confirm against the test setup.
    """

    def test_index_and_retrieve_knn_vector(self) -> None:
        """Create a k-NN index, index one document, then fetch and search it."""
        index_name = "itest-knn-vector"
        # ensure clean state
        self.client.indices.delete(index=index_name, ignore=404)

        try:
            # Create index using DSL abstractions
            idx = Index(index_name, using=self.client)
            idx.settings(**{"index.knn": True})

            mapping = Mapping()
            mapping.field("vec", field.KnnVector(dimension=3))
            idx.mapping(mapping)

            result = idx.create()
            assert result["acknowledged"] is True

            field_mapping = idx.get_field_mapping(fields="vec")
            assert field_mapping[index_name]["mappings"]["vec"]["mapping"]["vec"] == {
                "type": "knn_vector",
                "dimension": 3,
            }

            # search tests
            doc = {"vec": [1.0, 2.0, 3.0]}
            result = self.client.index(index=index_name, id=1, body=doc, refresh=True)
            assert result["_shards"]["successful"] == 1
            get_resp = self.client.get(index=index_name, id=1)
            assert get_resp["_source"]["vec"] == doc["vec"]

            search_body = {
                "size": 1,
                "query": {"knn": {"vec": {"vector": [1.0, 2.0, 3.0], "k": 1}}},
            }
            search_resp = self.client.search(index=index_name, body=search_body)
            hits = search_resp["hits"]["hits"]
            assert len(hits) == 1
            assert hits[0]["_id"] == "1"
        finally:
            # Always drop the index, even when an assertion above failed, so a
            # failing run does not leak state. 404 is ignored in case the index
            # was never created, so cleanup cannot mask the original failure.
            self.client.indices.delete(index=index_name, ignore=404)
0 commit comments