@@ -206,3 +206,215 @@ impl Operator<RankInput, RankOutput> for Rank {
206
206
Ok ( RankOutput { ranks } )
207
207
}
208
208
}
209
+
210
+ #[ cfg( test) ]
211
+ mod tests {
212
+ use super :: * ;
213
+
214
+ #[ tokio:: test]
215
+ async fn test_rank_with_knn_results ( ) {
216
+ // Setup KNN results
217
+ let mut knn_results = HashMap :: new ( ) ;
218
+ let query = KnnQuery {
219
+ embedding : chroma_types:: operator:: QueryVector :: Dense ( vec ! [ 0.1 , 0.2 , 0.3 ] ) ,
220
+ key : String :: new ( ) ,
221
+ limit : 3 ,
222
+ } ;
223
+ knn_results. insert (
224
+ query. clone ( ) ,
225
+ vec ! [
226
+ RecordMeasure {
227
+ offset_id: 1 ,
228
+ measure: 0.9 ,
229
+ } ,
230
+ RecordMeasure {
231
+ offset_id: 2 ,
232
+ measure: 0.7 ,
233
+ } ,
234
+ RecordMeasure {
235
+ offset_id: 3 ,
236
+ measure: 0.5 ,
237
+ } ,
238
+ ] ,
239
+ ) ;
240
+
241
+ // Test simple KNN rank
242
+ let rank = Rank :: Knn {
243
+ embedding : query. embedding . clone ( ) ,
244
+ key : String :: new ( ) ,
245
+ limit : query. limit ,
246
+ default : None ,
247
+ ordinal : false ,
248
+ } ;
249
+ let input = RankInput {
250
+ knn_results,
251
+ blockfile_provider : BlockfileProvider :: new_memory ( ) ,
252
+ } ;
253
+
254
+ let output = rank. run ( & input) . await . expect ( "Rank should succeed" ) ;
255
+ assert_eq ! ( output. ranks. len( ) , 3 ) ;
256
+ assert_eq ! ( output. ranks[ 0 ] . offset_id, 1 ) ;
257
+ assert_eq ! ( output. ranks[ 0 ] . measure, 0.9 ) ;
258
+ }
259
+
260
+ #[ tokio:: test]
261
+ async fn test_rank_arithmetic_operations ( ) {
262
+ // Setup two KNN queries
263
+ let mut knn_results = HashMap :: new ( ) ;
264
+ let query1 = KnnQuery {
265
+ embedding : chroma_types:: operator:: QueryVector :: Dense ( vec ! [ 0.1 ] ) ,
266
+ key : String :: new ( ) ,
267
+ limit : 2 ,
268
+ } ;
269
+ let query2 = KnnQuery {
270
+ embedding : chroma_types:: operator:: QueryVector :: Sparse ( chroma_types:: SparseVector {
271
+ indices : vec ! [ 0 ] ,
272
+ values : vec ! [ 1.0 ] ,
273
+ } ) ,
274
+ key : "sparse" . to_string ( ) ,
275
+ limit : 2 ,
276
+ } ;
277
+
278
+ knn_results. insert (
279
+ query1. clone ( ) ,
280
+ vec ! [
281
+ RecordMeasure {
282
+ offset_id: 1 ,
283
+ measure: 0.8 ,
284
+ } ,
285
+ RecordMeasure {
286
+ offset_id: 2 ,
287
+ measure: 0.6 ,
288
+ } ,
289
+ ] ,
290
+ ) ;
291
+ knn_results. insert (
292
+ query2. clone ( ) ,
293
+ vec ! [
294
+ RecordMeasure {
295
+ offset_id: 1 ,
296
+ measure: 0.4 ,
297
+ } ,
298
+ RecordMeasure {
299
+ offset_id: 3 ,
300
+ measure: 0.2 ,
301
+ } ,
302
+ ] ,
303
+ ) ;
304
+
305
+ // Test summation
306
+ let rank = Rank :: Summation ( vec ! [
307
+ Rank :: Knn {
308
+ embedding: query1. embedding. clone( ) ,
309
+ key: String :: new( ) ,
310
+ limit: query1. limit,
311
+ default : None ,
312
+ ordinal: false ,
313
+ } ,
314
+ Rank :: Knn {
315
+ embedding: query2. embedding. clone( ) ,
316
+ key: "sparse" . to_string( ) ,
317
+ limit: query2. limit,
318
+ default : None ,
319
+ ordinal: false ,
320
+ } ,
321
+ ] ) ;
322
+ let input = RankInput {
323
+ knn_results : knn_results. clone ( ) ,
324
+ blockfile_provider : BlockfileProvider :: new_memory ( ) ,
325
+ } ;
326
+
327
+ let output = rank. run ( & input) . await . expect ( "Rank should succeed" ) ;
328
+ // Record 1 appears in both: 0.8 + 0.4 = 1.2
329
+ assert_eq ! ( output. ranks[ 0 ] . offset_id, 1 ) ;
330
+ assert_eq ! ( output. ranks[ 0 ] . measure, 1.2 ) ;
331
+
332
+ // Test multiplication with constant
333
+ let rank = Rank :: Multiplication ( vec ! [
334
+ Rank :: Knn {
335
+ embedding: query1. embedding. clone( ) ,
336
+ key: String :: new( ) ,
337
+ limit: query1. limit,
338
+ default : None ,
339
+ ordinal: false ,
340
+ } ,
341
+ Rank :: Value ( 0.5 ) ,
342
+ ] ) ;
343
+ let input = RankInput {
344
+ knn_results,
345
+ blockfile_provider : BlockfileProvider :: new_memory ( ) ,
346
+ } ;
347
+
348
+ let output = rank. run ( & input) . await . expect ( "Rank should succeed" ) ;
349
+ assert_eq ! ( output. ranks[ 0 ] . offset_id, 1 ) ;
350
+ assert_eq ! ( output. ranks[ 0 ] . measure, 0.4 ) ; // 0.8 * 0.5
351
+ }
352
+
353
+ #[ tokio:: test]
354
+ async fn test_rank_min_max_functions ( ) {
355
+ let mut knn_results = HashMap :: new ( ) ;
356
+ let query = KnnQuery {
357
+ embedding : chroma_types:: operator:: QueryVector :: Dense ( vec ! [ 0.1 ] ) ,
358
+ key : String :: new ( ) ,
359
+ limit : 2 ,
360
+ } ;
361
+
362
+ knn_results. insert (
363
+ query. clone ( ) ,
364
+ vec ! [
365
+ RecordMeasure {
366
+ offset_id: 1 ,
367
+ measure: 0.8 ,
368
+ } ,
369
+ RecordMeasure {
370
+ offset_id: 2 ,
371
+ measure: 0.3 ,
372
+ } ,
373
+ ] ,
374
+ ) ;
375
+
376
+ // Test max
377
+ let rank = Rank :: Maximum ( vec ! [
378
+ Rank :: Knn {
379
+ embedding: query. embedding. clone( ) ,
380
+ key: String :: new( ) ,
381
+ limit: query. limit,
382
+ default : None ,
383
+ ordinal: false ,
384
+ } ,
385
+ Rank :: Value ( 0.5 ) ,
386
+ ] ) ;
387
+ let input = RankInput {
388
+ knn_results : knn_results. clone ( ) ,
389
+ blockfile_provider : BlockfileProvider :: new_memory ( ) ,
390
+ } ;
391
+
392
+ let output = rank. run ( & input) . await . expect ( "Rank should succeed" ) ;
393
+ assert_eq ! ( output. ranks[ 0 ] . offset_id, 1 ) ;
394
+ assert_eq ! ( output. ranks[ 0 ] . measure, 0.8 ) ; // max(0.8, 0.5) = 0.8
395
+ assert_eq ! ( output. ranks[ 1 ] . offset_id, 2 ) ;
396
+ assert_eq ! ( output. ranks[ 1 ] . measure, 0.5 ) ; // max(0.3, 0.5) = 0.5
397
+
398
+ // Test min
399
+ let rank = Rank :: Minimum ( vec ! [
400
+ Rank :: Knn {
401
+ embedding: query. embedding. clone( ) ,
402
+ key: String :: new( ) ,
403
+ limit: query. limit,
404
+ default : None ,
405
+ ordinal: false ,
406
+ } ,
407
+ Rank :: Value ( 0.5 ) ,
408
+ ] ) ;
409
+ let input = RankInput {
410
+ knn_results,
411
+ blockfile_provider : BlockfileProvider :: new_memory ( ) ,
412
+ } ;
413
+
414
+ let output = rank. run ( & input) . await . expect ( "Rank should succeed" ) ;
415
+ assert_eq ! ( output. ranks[ 0 ] . offset_id, 1 ) ;
416
+ assert_eq ! ( output. ranks[ 0 ] . measure, 0.5 ) ; // min(0.8, 0.5) = 0.5
417
+ assert_eq ! ( output. ranks[ 1 ] . offset_id, 2 ) ;
418
+ assert_eq ! ( output. ranks[ 1 ] . measure, 0.3 ) ; // min(0.3, 0.5) = 0.3
419
+ }
420
+ }
0 commit comments