Skip to content

Commit 9c53c92

Browse files
committed
query: add nearest_neighbors_f32, find with scores functions objectbox#24
1 parent 874520b commit 9c53c92

File tree

2 files changed

+59
-16
lines changed

2 files changed

+59
-16
lines changed

objectbox/query.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def __init__(self, c_query, box: 'Box'):
2222
self._ob = box._ob
2323

2424
def find(self) -> list:
25+
""" Finds a list of objects matching query. """
2526
with self._ob.read_tx():
2627
# OBX_bytes_array*
2728
c_bytes_array_p = obx_query_find(self._c_query)
28-
2929
try:
3030
# OBX_bytes_array
3131
c_bytes_array = c_bytes_array_p.contents
@@ -36,11 +36,48 @@ def find(self) -> list:
3636
c_bytes = c_bytes_array.data[i]
3737
data = c_voidp_as_bytes(c_bytes.data, c_bytes.size)
3838
result.append(self._box._entity.unmarshal(data))
39-
4039
return result
4140
finally:
4241
obx_bytes_array_free(c_bytes_array_p)
4342

43+
def find_ids(self) -> List[int]:
44+
""" Finds a list of object IDs matching query. The result is sorted by ID (ascending order). """
45+
c_id_array_p = obx_query_find_ids(self._c_query)
46+
try:
47+
return list(c_id_array_p.contents)
48+
finally:
49+
obx_id_array_free(c_id_array_p)
50+
51+
def find_with_scores(self):
52+
""" Finds objects matching the query associated to their query score (e.g. distance in NN search).
53+
The result is sorted by score in ascending order. """
54+
c_bytes_score_array_p = obx_query_find_with_scores(self._c_query)
55+
try:
56+
# OBX_bytes_score_array
57+
c_bytes_score_array: OBX_bytes_score_array = c_bytes_score_array_p.contents
58+
result = []
59+
for i in range(c_bytes_score_array.count):
60+
# TODO implement
61+
pass
62+
return result
63+
finally:
64+
obx_bytes_score_array_free(c_bytes_score_array_p)
65+
66+
def find_ids_with_scores(self) -> List[Tuple[int, float]]:
67+
""" Finds object IDs matching the query associated to their query score (e.g. distance in NN search).
68+
The resulting list is sorted by score in ascending order. """
69+
c_id_score_array_p = obx_query_find_ids_with_scores(self._c_query)
70+
try:
71+
# OBX_id_score_array
72+
c_id_score_array: OBX_bytes_score_array = c_id_score_array_p.contents
73+
result = []
74+
for i in range(c_id_score_array.count):
75+
c_id_score: OBX_id_score = c_id_score_array.ids_scores[i]
76+
result.append((c_id_score.id, c_id_score.score))
77+
return result
78+
finally:
79+
obx_id_score_array_free(c_id_score_array_p)
80+
4481
def count(self) -> int:
4582
count = ctypes.c_uint64()
4683
obx_query_count(self._c_query, ctypes.byref(count))
@@ -55,4 +92,4 @@ def offset(self, offset: int):
5592
return obx_query_offset(self._c_query, offset)
5693

5794
def limit(self, limit: int):
58-
return obx_query_limit(self._c_query, limit)
95+
return obx_query_limit(self._c_query, limit)

objectbox/query_builder.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from objectbox.model.entity import _Entity
1+
import ctypes
2+
import numpy as np
3+
from typing import *
4+
25
from objectbox.objectbox import ObjectBox
36
from objectbox.query import Query
47
from objectbox.c import *
58

69

710
class QueryBuilder:
8-
def __init__(self, ob: ObjectBox, box: 'Box', entity: '_Entity', condition: 'QueryCondition'):
9-
if not isinstance(entity, _Entity):
10-
raise Exception("Given type is not an Entity")
11+
def __init__(self, ob: ObjectBox, box: 'Box'):
1112
self._box = box
12-
self._entity = entity
13-
self._condition = condition
14-
self._c_builder = obx_query_builder(ob._c_store, entity.id)
13+
self._entity = box._entity
14+
self._c_builder = obx_query_builder(ob._c_store, box._entity.id)
1515

1616
def close(self) -> int:
1717
return obx_qb_close(self)
@@ -85,11 +85,17 @@ def less_or_equal_int(self, property_id: int, value: int):
8585
def between_2ints(self, property_id: int, value_a: int, value_b: int):
8686
obx_qb_between_2ints(self._c_builder, property_id, value_a, value_b)
8787
return self
88-
89-
def apply_condition(self):
90-
self._condition.apply(self)
91-
88+
89+
def nearest_neighbors_f32(self, vector_property_id: int, query_vector: Union[np.ndarray, List[float]], element_count: int):
90+
if isinstance(query_vector, np.ndarray):
91+
if query_vector.dtype != np.float32:
92+
raise Exception(f"query_vector dtype must be float32")
93+
query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
94+
else: # List[float]
95+
query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector)
96+
obx_qb_nearest_neighbors_f32(self._c_builder, vector_property_id, query_vector_data, element_count)
97+
return self
98+
9299
def build(self) -> Query:
93-
self.apply_condition()
94100
c_query = obx_query(self._c_builder)
95-
return Query(c_query, self._box)
101+
return Query(c_query, self._box)

0 commit comments

Comments
 (0)