Skip to content

Commit 84740e9

Browse files
committed
Merge branch '30-more-distance-funcs-with-c-0.21.1-alpha1' into 'dev'
More distance functions with c-0.21.1-alpha1 See merge request objectbox/objectbox-python!24
2 parents 15b3266 + 14fbc80 commit 84740e9

File tree

9 files changed

+119
-57
lines changed

9 files changed

+119
-57
lines changed

download-c-lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Script used to download objectbox-c shared libraries for all supported platforms. Execute by running `make get-lib`
77
# on first checkout of this repo and any time after changing the objectbox-c lib version.
88

9-
version = "v0.21.1-alpha0" # see objectbox/c.py required_version
9+
version = "v0.21.1-alpha1" # see objectbox/c.py required_version
1010
variant = 'objectbox' # or 'objectbox-sync'
1111

1212
base_url = "https://github.com/objectbox/objectbox-c/releases/download/"

objectbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
]
3232

3333
# Python binding version
34-
version = Version(0, 7, 0, alpha=6)
34+
version = Version(0, 7, 0, alpha=7)
3535

3636

3737
def version_info():

objectbox/c.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,3 +877,6 @@ def c_array_pointer(py_list: Union[List[Any], np.ndarray], c_type):
877877

878878
OBXHnswDistanceType_UNKNOWN = 0
879879
OBXHnswDistanceType_EUCLIDEAN = 1
880+
OBXHnswDistanceType_COSINE = 2
881+
OBXHnswDistanceType_DOT_PRODUCT = 3
882+
OBXHnswDistanceType_DOT_PRODUCT_NON_NORMALIZED = 10

objectbox/model/properties.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import numpy as np
2121
from dataclasses import dataclass
2222

23-
2423
class PropertyType(IntEnum):
2524
bool = OBXPropertyType_Bool
2625
byte = OBXPropertyType_Byte
@@ -95,9 +94,32 @@ class HnswFlags(IntEnum):
9594

9695

9796
class HnswDistanceType(IntEnum):
98-
UNKNOWN = OBXHnswDistanceType_UNKNOWN,
97+
UNKNOWN = OBXHnswDistanceType_UNKNOWN
9998
EUCLIDEAN = OBXHnswDistanceType_EUCLIDEAN
100-
99+
COSINE = OBXHnswDistanceType_COSINE
100+
DOT_PRODUCT = OBXHnswDistanceType_DOT_PRODUCT
101+
DOT_PRODUCT_NON_NORMALIZED = OBXHnswDistanceType_DOT_PRODUCT_NON_NORMALIZED
102+
103+
HnswDistanceType.UNKNOWN.__doc__ = "Not a real type, just best practice (e.g. forward compatibility)"
104+
HnswDistanceType.EUCLIDEAN.__doc__ = "The default; typically 'euclidean squared' internally."
105+
HnswDistanceType.COSINE.__doc__ = """
106+
Cosine similarity compares two vectors irrespective of their magnitude (compares the angle of two vectors).
107+
Often used for document or semantic similarity.
108+
Value range: 0.0 - 2.0 (0.0: same direction, 1.0: orthogonal, 2.0: opposite direction)
109+
"""
110+
HnswDistanceType.DOT_PRODUCT.__doc__ = """
111+
For normalized vectors (vector length == 1.0), the dot product is equivalent to the cosine similarity.
112+
Because of this, the dot product is often preferred as it performs better.
113+
Value range (normalized vectors): 0.0 - 2.0 (0.0: same direction, 1.0: orthogonal, 2.0: opposite direction)
114+
"""
115+
HnswDistanceType.DOT_PRODUCT_NON_NORMALIZED.__doc__ = """
116+
A custom dot product similarity measure that does not require the vectors to be normalized.
117+
Note: this is no replacement for cosine similarity (like DotProduct for normalized vectors is).
118+
The non-linear conversion provides a high precision over the entire float range (for the raw dot product).
119+
The higher the dot product, the lower the distance is (the nearer the vectors are).
120+
The more negative the dot product, the higher the distance is (the farther the vectors are).
121+
Value range: 0.0 - 2.0 (nonlinear; 0.0: nearest, 1.0: orthogonal, 2.0: farthest)
122+
"""
101123

102124
@dataclass
103125
class HnswIndex:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
'numpy'
5454
],
5555

56-
packages=setuptools.find_packages(),
56+
packages=setuptools.find_packages(exclude=['exampl*']),
5757
package_data={
5858
'objectbox': [
5959
# Linux, macOS

tests/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ def create_test_objectbox(db_name: Optional[str] = None, clear_db: bool = True)
5252
model.entity(TestEntity, last_property_id=IdUid(27, 1027))
5353
model.entity(TestEntityDatetime, last_property_id=IdUid(4, 2004))
5454
model.entity(TestEntityFlex, last_property_id=IdUid(2, 3002))
55-
model.entity(VectorEntity, last_property_id=IdUid(3, 4003))
55+
model.entity(VectorEntity, last_property_id=IdUid(5, 4005))
5656
model.last_entity_id = IdUid(4, 4)
57-
model.last_index_id = IdUid(3, 40001)
57+
model.last_index_id = IdUid(5, 40003)
5858

5959
return objectbox.Builder().model(model).directory(db_path).build()
6060

tests/model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,24 @@ class TestEntityFlex:
5757
class VectorEntity:
5858
id = Id(id=1, uid=4001)
5959
name = Property(str, type=PropertyType.string, id=2, uid=4002)
60-
vector = Property(np.ndarray, type=PropertyType.floatVector, id=3, uid=4003,
60+
vector_euclidean = Property(np.ndarray, type=PropertyType.floatVector, id=3, uid=4003,
6161
index=HnswIndex(
6262
id=3, uid=40001,
6363
dimensions=2, distance_type=HnswDistanceType.EUCLIDEAN)
6464
)
65+
vector_cosine = Property(np.ndarray, type=PropertyType.floatVector, id=4, uid=4004,
66+
index=HnswIndex(
67+
id=4, uid=40002,
68+
dimensions=2, distance_type=HnswDistanceType.COSINE)
69+
)
70+
vector_dot_product = Property(np.ndarray, type=PropertyType.floatVector, id=5, uid=4005,
71+
index=HnswIndex(
72+
id=5, uid=40003,
73+
dimensions=2, distance_type=HnswDistanceType.DOT_PRODUCT)
74+
)
75+
#vector_dot_product_non_normalized = Property(np.ndarray, type=PropertyType.floatVector, id=6, uid=4006,
76+
# index=HnswIndex(
77+
# id=6, uid=40004,
78+
# dimensions=2, distance_type=HnswDistanceType.DOT_PRODUCT_NON_NORMALIZED)
79+
# )
80+

tests/test_hnsw.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ def _find_expected_nn(points: np.ndarray, query: np.ndarray, n: int):
1515
return np.argsort(d)[:n]
1616

1717

18-
def _test_random_points(num_points: int, num_query_points: int, seed: Optional[int] = None):
18+
def _test_random_points(num_points: int, num_query_points: int, seed: Optional[int] = None, distance_type: HnswDistanceType = HnswDistanceType.EUCLIDEAN, min_score: float = 0.5):
1919
""" Generates random points in a 2d plane; checks the queried NN against the expected. """
2020

21+
vector_field_name = "vector_"+distance_type.name.lower()
22+
2123
print(f"Test random points; Points: {num_points}, Query points: {num_query_points}, Seed: {seed}")
2224

2325
k = 10
@@ -37,7 +39,7 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
3739
for i in range(points.shape[0]):
3840
object_ = VectorEntity()
3941
object_.name = f"point_{i}"
40-
object_.vector = points[i]
42+
setattr(object_, vector_field_name, points[i])
4143
objects.append(object_)
4244
box.put(*objects)
4345
print(f"DB seeded with {box.count()} random points!")
@@ -58,50 +60,62 @@ def _test_random_points(num_points: int, num_query_points: int, seed: Optional[i
5860

5961
# Run ANN with OBX
6062
qb = box.query()
61-
qb.nearest_neighbors_f32("vector", query_point, k)
63+
qb.nearest_neighbors_f32(vector_field_name, query_point, k)
6264
query = qb.build()
6365
obx_result = [id_ for id_, score in query.find_ids_with_scores()] # Ignore score
6466
assert len(obx_result) == k
6567

6668
# We would like at least half of the expected results, to be returned by the search (in any order)
6769
# Remember: it's an approximate search!
6870
search_score = len(np.intersect1d(expected_result, obx_result)) / k
69-
assert search_score >= 0.5 # TODO likely could be increased
71+
assert search_score >= min_score # TODO likely could be increased
7072

7173
print(f"Done!")
7274

7375

7476
def test_random_points():
75-
_test_random_points(num_points=100, num_query_points=10, seed=10)
76-
_test_random_points(num_points=100, num_query_points=10, seed=11)
77-
_test_random_points(num_points=100, num_query_points=10, seed=12)
78-
_test_random_points(num_points=100, num_query_points=10, seed=13)
79-
_test_random_points(num_points=100, num_query_points=10, seed=14)
80-
_test_random_points(num_points=100, num_query_points=10, seed=15)
81-
82-
83-
def test_combined_nn_search():
84-
""" Tests NN search combined with regular query conditions, offset and limit. """
85-
77+
78+
min_score = 0.5
79+
distance_type = HnswDistanceType.EUCLIDEAN
80+
_test_random_points(num_points=100, num_query_points=10, seed=10, distance_type=distance_type, min_score=min_score)
81+
_test_random_points(num_points=100, num_query_points=10, seed=11, distance_type=distance_type, min_score=min_score)
82+
_test_random_points(num_points=100, num_query_points=10, seed=12, distance_type=distance_type, min_score=min_score)
83+
_test_random_points(num_points=100, num_query_points=10, seed=13, distance_type=distance_type, min_score=min_score)
84+
_test_random_points(num_points=100, num_query_points=10, seed=14, distance_type=distance_type, min_score=min_score)
85+
_test_random_points(num_points=100, num_query_points=10, seed=15, distance_type=distance_type, min_score=min_score)
86+
87+
# TODO: Cosine and Dot Product may result in 0 score
88+
89+
def _test_combined_nn_search(distance_type: HnswDistanceType = HnswDistanceType.EUCLIDEAN):
90+
8691
db = create_test_objectbox()
8792

8893
box = objectbox.Box(db, VectorEntity)
8994

90-
box.put(VectorEntity(name="Power of red", vector=[1, 1]))
91-
box.put(VectorEntity(name="Blueberry", vector=[2, 2]))
92-
box.put(VectorEntity(name="Red", vector=[3, 3]))
93-
box.put(VectorEntity(name="Blue sea", vector=[4, 4]))
94-
box.put(VectorEntity(name="Lightblue", vector=[5, 5]))
95-
box.put(VectorEntity(name="Red apple", vector=[6, 6]))
96-
box.put(VectorEntity(name="Hundred", vector=[7, 7]))
97-
box.put(VectorEntity(name="Tired", vector=[8, 8]))
98-
box.put(VectorEntity(name="Power of blue", vector=[9, 9]))
99-
95+
vector_field_name = "vector_"+distance_type.name.lower()
96+
97+
values = [
98+
("Power of red", [1, 1]),
99+
("Blueberry", [2, 2]),
100+
("Red", [3, 3]),
101+
("Blue sea", [4, 4]),
102+
("Lightblue", [5, 5]),
103+
("Red apple", [6, 6]),
104+
("Hundred", [7, 7]),
105+
("Tired", [8, 8]),
106+
("Power of blue", [9, 9])
107+
]
108+
for value in values:
109+
entity = VectorEntity()
110+
setattr(entity, "name", value[0])
111+
setattr(entity, vector_field_name, value[1])
112+
box.put(entity)
113+
100114
assert box.count() == 9
101115

102116
# Test condition + NN search
103117
qb = box.query()
104-
qb.nearest_neighbors_f32("vector", [4.1, 4.2], 6)
118+
qb.nearest_neighbors_f32(vector_field_name, [4.1, 4.2], 6)
105119
qb.contains_string("name", "red", case_sensitive=False)
106120
query = qb.build()
107121
# 4, 5, 3, 6, 2, 7
@@ -121,7 +135,7 @@ def test_combined_nn_search():
121135

122136
# Regular condition + NN search
123137
qb = box.query()
124-
qb.nearest_neighbors_f32("vector", [9.2, 8.9], 7)
138+
qb.nearest_neighbors_f32(vector_field_name, [9.2, 8.9], 7)
125139
qb.starts_with_string("name", "Blue", case_sensitive=True)
126140
query = qb.build()
127141

@@ -131,7 +145,7 @@ def test_combined_nn_search():
131145

132146
# Regular condition + NN search
133147
qb = box.query()
134-
qb.nearest_neighbors_f32("vector", [7.7, 7.7], 8)
148+
qb.nearest_neighbors_f32(vector_field_name, [7.7, 7.7], 8)
135149
qb.contains_string("name", "blue", case_sensitive=False)
136150
query = qb.build()
137151
# 8, 7, 9, 6, 5, 4, 3, 2
@@ -157,3 +171,10 @@ def test_combined_nn_search():
157171
assert len(search_results) == 2
158172
assert search_results[0] == 4
159173
assert search_results[1] == 5
174+
175+
176+
def test_combined_nn_search():
177+
""" Tests NN search combined with regular query conditions, offset and limit. """
178+
distance_type = HnswDistanceType.EUCLIDEAN
179+
_test_combined_nn_search(distance_type)
180+
# TODO: Cosine, DotProduct diverges see below

tests/test_query.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def test_basics():
1616
box_test_entity.put(TestEntity(str="bar", int64=456))
1717

1818
box_vector_entity = objectbox.Box(ob, VectorEntity)
19-
box_vector_entity.put(VectorEntity(name="Object 1", vector=[1, 1]))
20-
box_vector_entity.put(VectorEntity(name="Object 2", vector=[2, 2]))
21-
box_vector_entity.put(VectorEntity(name="Object 3", vector=[3, 3]))
19+
box_vector_entity.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
20+
box_vector_entity.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
21+
box_vector_entity.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
2222

2323
# String query
2424
str_prop: Property = TestEntity.get_property("str")
@@ -98,7 +98,7 @@ def test_basics():
9898
assert query.remove() == 1
9999

100100
# NN query
101-
vector_prop: Property = VectorEntity.get_property("vector")
101+
vector_prop: Property = VectorEntity.get_property("vector_euclidean")
102102

103103
query = box_vector_entity.query(vector_prop.nearest_neighbor([2.1, 2.1], 2)).build()
104104
assert query.count() == 2
@@ -258,11 +258,11 @@ def test_set_parameter():
258258
box_test_entity.put(TestEntity(str="Barrakuda", int64=4, int32=386, int8=60))
259259

260260
box_vector_entity = objectbox.Box(db, VectorEntity)
261-
box_vector_entity.put(VectorEntity(name="Object 1", vector=[1, 1]))
262-
box_vector_entity.put(VectorEntity(name="Object 2", vector=[2, 2]))
263-
box_vector_entity.put(VectorEntity(name="Object 3", vector=[3, 3]))
264-
box_vector_entity.put(VectorEntity(name="Object 4", vector=[4, 4]))
265-
box_vector_entity.put(VectorEntity(name="Object 5", vector=[5, 5]))
261+
box_vector_entity.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
262+
box_vector_entity.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
263+
box_vector_entity.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
264+
box_vector_entity.put(VectorEntity(name="Object 4", vector_euclidean=[4, 4]))
265+
box_vector_entity.put(VectorEntity(name="Object 5", vector_euclidean=[5, 5]))
266266

267267
qb = box_test_entity.query()
268268
qb.starts_with_string("str", "fo", case_sensitive=False)
@@ -280,22 +280,22 @@ def test_set_parameter():
280280
assert query.find_ids() == [3]
281281

282282
qb = box_vector_entity.query()
283-
qb.nearest_neighbors_f32("vector", [3.4, 3.4], 3)
283+
qb.nearest_neighbors_f32("vector_euclidean", [3.4, 3.4], 3)
284284
query = qb.build()
285285
assert query.find_ids() == sorted([3, 4, 2])
286286

287287
# set_parameter_vector_f32
288288
# set_parameter_int (NN count)
289-
query.set_parameter_vector_f32("vector", [4.9, 4.9])
289+
query.set_parameter_vector_f32("vector_euclidean", [4.9, 4.9])
290290
assert query.find_ids() == sorted([5, 4, 3])
291291

292-
query.set_parameter_vector_f32("vector", [0, 0])
292+
query.set_parameter_vector_f32("vector_euclidean", [0, 0])
293293
assert query.find_ids() == sorted([1, 2, 3])
294294

295-
query.set_parameter_vector_f32("vector", [2.5, 2.1])
295+
query.set_parameter_vector_f32("vector_euclidean", [2.5, 2.1])
296296
assert query.find_ids() == sorted([2, 3, 1])
297297

298-
query.set_parameter_int("vector", 2)
298+
query.set_parameter_int("vector_euclidean", 2)
299299
assert query.find_ids() == sorted([2, 3])
300300

301301

@@ -307,11 +307,11 @@ def test_set_parameter_alias():
307307
box.put(TestEntity(str="FooBar", int64=10, int32=49, int8=45))
308308

309309
box_vector = objectbox.Box(db, VectorEntity)
310-
box_vector.put(VectorEntity(name="Object 1", vector=[1, 1]))
311-
box_vector.put(VectorEntity(name="Object 2", vector=[2, 2]))
312-
box_vector.put(VectorEntity(name="Object 3", vector=[3, 3]))
313-
box_vector.put(VectorEntity(name="Object 4", vector=[4, 4]))
314-
box_vector.put(VectorEntity(name="Object 5", vector=[5, 5]))
310+
box_vector.put(VectorEntity(name="Object 1", vector_euclidean=[1, 1]))
311+
box_vector.put(VectorEntity(name="Object 2", vector_euclidean=[2, 2]))
312+
box_vector.put(VectorEntity(name="Object 3", vector_euclidean=[3, 3]))
313+
box_vector.put(VectorEntity(name="Object 4", vector_euclidean=[4, 4]))
314+
box_vector.put(VectorEntity(name="Object 5", vector_euclidean=[5, 5]))
315315

316316
str_prop: Property = TestEntity.get_property("str")
317317
int32_prop: Property = TestEntity.get_property("int32")
@@ -354,7 +354,7 @@ def test_set_parameter_alias():
354354
assert query.find()[0].str == "FooBar"
355355

356356
# Test set parameter alias on vector
357-
vector_prop: Property = VectorEntity.get_property("vector")
357+
vector_prop: Property = VectorEntity.get_property("vector_euclidean")
358358

359359
query = box_vector.query(vector_prop.nearest_neighbor([3.4, 3.4], 3).alias("nearest_neighbour_filter")).build()
360360
assert query.count() == 3

0 commit comments

Comments
 (0)