Skip to content

Commit e8d0148

Browse files
committed
repurpose elasticsearch dense_vector field for knn_vector
Signed-off-by: Edward Auttonberry <[email protected]>
1 parent 2f290be commit e8d0148

File tree

4 files changed

+121
-6
lines changed

4 files changed

+121
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
66
### Updated APIs
77
- Updated opensearch-py APIs to reflect [opensearch-api-specification@578a78d](https://github.com/opensearch-project/opensearch-api-specification/commit/578a78dcec746e81da88f81ad442ab1836db7694)
88
### Changed
9+
- Rename `DenseVector` field type to `KnnVector` ([924](https://github.com/opensearch-project/opensearch-py/pull/924))
910
### Deprecated
1011
### Removed
1112
### Fixed

opensearchpy/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@
9292
CustomField,
9393
Date,
9494
DateRange,
95-
DenseVector,
9695
Double,
9796
DoubleRange,
9897
Field,
@@ -107,6 +106,7 @@
107106
IpRange,
108107
Join,
109108
Keyword,
109+
KnnVector,
110110
Long,
111111
LongRange,
112112
Murmur3,
@@ -178,7 +178,7 @@
178178
"Date",
179179
"DateHistogramFacet",
180180
"DateRange",
181-
"DenseVector",
181+
"KnnVector",
182182
"Document",
183183
"Double",
184184
"DoubleRange",

opensearchpy/helpers/field.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,12 +354,12 @@ def _deserialize(self, data: Any) -> Any:
354354
return float(data)
355355

356356

357-
class DenseVector(Float):
358-
name: Optional[str] = "dense_vector"
357+
class KnnVector(Float):
358+
name: Optional[str] = "knn_vector"
359359

360-
def __init__(self, dims: Any, **kwargs: Any) -> None:
360+
def __init__(self, dimension: Any, **kwargs: Any) -> None:
361361
kwargs["multi"] = True
362-
super().__init__(dims=dims, **kwargs)
362+
super().__init__(dimension=dimension, **kwargs)
363363

364364

365365
class SparseVector(Field):

test_opensearchpy/test_helpers/test_field.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434

3535
from opensearchpy import InnerDoc, Range, ValidationException
3636
from opensearchpy.helpers import field
37+
from opensearchpy.helpers.index import Index
38+
from opensearchpy.helpers.mapping import Mapping
39+
from opensearchpy.helpers.test import OpenSearchTestCase
3740

3841

3942
def test_date_range_deserialization() -> None:
@@ -221,3 +224,114 @@ class Inner(InnerDoc):
221224

222225
with pytest.raises(ValidationException):
223226
field.Object(doc_class=Inner, dynamic=False)
227+
228+
229+
def test_knn_vector() -> None:
230+
f = field.KnnVector(dimension=128)
231+
assert f.to_dict() == {"type": "knn_vector", "dimension": 128}
232+
233+
# Test that dimension parameter is required
234+
with pytest.raises(TypeError):
235+
field.KnnVector() # type: ignore
236+
237+
assert f._multi is True
238+
239+
240+
def test_knn_vector_with_additional_params() -> None:
241+
f = field.KnnVector(
242+
dimension=256, method={"name": "hnsw", "space_type": "l2", "engine": "faiss"}
243+
)
244+
expected = {
245+
"type": "knn_vector",
246+
"dimension": 256,
247+
"method": {"name": "hnsw", "space_type": "l2", "engine": "faiss"},
248+
}
249+
assert f.to_dict() == expected
250+
251+
252+
def test_knn_vector_serialization() -> None:
253+
f = field.KnnVector(dimension=3)
254+
255+
vector_data = [1.0, 2.0, 3.0]
256+
serialized = f.serialize(vector_data)
257+
assert serialized == vector_data
258+
259+
assert f.serialize(None) is None
260+
261+
262+
def test_knn_vector_deserialization() -> None:
263+
f = field.KnnVector(dimension=3)
264+
265+
vector_data = [1.0, 2.0, 3.0]
266+
deserialized = f.deserialize(vector_data)
267+
assert deserialized == vector_data
268+
269+
assert f.deserialize(None) is None
270+
271+
272+
def test_knn_vector_construct_from_dict() -> None:
273+
f = field.construct_field({"type": "knn_vector", "dimension": 128})
274+
275+
assert isinstance(f, field.KnnVector)
276+
assert f.to_dict() == {"type": "knn_vector", "dimension": 128}
277+
278+
279+
def test_knn_vector_construct_from_dict_with_method() -> None:
280+
f = field.construct_field(
281+
{
282+
"type": "knn_vector",
283+
"dimension": 256,
284+
"method": {"name": "hnsw", "space_type": "cosinesimil", "engine": "lucene"},
285+
}
286+
)
287+
288+
assert isinstance(f, field.KnnVector)
289+
expected = {
290+
"type": "knn_vector",
291+
"dimension": 256,
292+
"method": {"name": "hnsw", "space_type": "cosinesimil", "engine": "lucene"},
293+
}
294+
assert f.to_dict() == expected
295+
296+
297+
class TestKnnVectorIntegration(OpenSearchTestCase):
298+
def test_index_and_retrieve_knn_vector(self):
299+
index_name = "itest-knn-vector"
300+
# ensure clean state
301+
self.client.indices.delete(index=index_name, ignore=404)
302+
303+
# Create index using DSL abstractions
304+
idx = Index(index_name, using=self.client)
305+
idx.settings(**{"index.knn": True})
306+
307+
mapping = Mapping()
308+
mapping.field("vec", field.KnnVector(dimension=3))
309+
idx.mapping(mapping)
310+
311+
result = idx.create()
312+
assert result["acknowledged"] is True
313+
314+
field_mapping = idx.get_field_mapping(fields="vec")
315+
assert field_mapping[index_name]["mappings"]["vec"]["mapping"]["vec"] == {
316+
"type": "knn_vector",
317+
"dimension": 3,
318+
}
319+
320+
# search tests
321+
doc = {"vec": [1.0, 2.0, 3.0]}
322+
result = self.client.index(index=index_name, id=1, body=doc, refresh=True)
323+
assert result["_shards"]["successful"] == 1
324+
get_resp = self.client.get(index=index_name, id=1)
325+
assert get_resp["_source"]["vec"] == doc["vec"]
326+
327+
search_body = {
328+
"size": 1,
329+
"query": {"knn": {"vec": {"vector": [1.0, 2.0, 3.0], "k": 1}}},
330+
}
331+
search_resp = self.client.search(index=index_name, body=search_body)
332+
hits = search_resp["hits"]["hits"]
333+
assert len(hits) == 1
334+
assert hits[0]["_id"] == "1"
335+
336+
# cleanup
337+
self.client.indices.delete(index=index_name)

0 commit comments

Comments
 (0)