@@ -96,13 +96,13 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
SIMSIMD_PUBLIC void simsimd_bilinear_##input_type##_##name( \
    simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
    simsimd_size_t n, simsimd_distance_t *result) { \
-    simsimd_##accumulator_type##_t sum = 0; \
+    simsimd_##accumulator_type##_t sum = 0, a_i, b_j, c_ij; \
    for (simsimd_size_t i = 0; i != n; ++i) { \
        simsimd_##accumulator_type##_t partial = 0; \
-        simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \
+        load_and_convert(a + i, &a_i); \
        for (simsimd_size_t j = 0; j != n; ++j) { \
-            simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \
-            simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
+            load_and_convert(b + j, &b_j); \
+            load_and_convert(c + i * n + j, &c_ij); \
            partial += c_ij * b_j; \
        } \
        sum += a_i * partial; \
@@ -114,40 +114,44 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
SIMSIMD_PUBLIC void simsimd_mahalanobis_##input_type##_##name( \
    simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
    simsimd_size_t n, simsimd_distance_t *result) { \
-    simsimd_##accumulator_type##_t sum = 0; \
+    simsimd_##accumulator_type##_t sum = 0, a_i, a_j, b_i, b_j, c_ij; \
    for (simsimd_size_t i = 0; i != n; ++i) { \
        simsimd_##accumulator_type##_t partial = 0; \
-        simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \
+        load_and_convert(a + i, &a_i); \
+        load_and_convert(b + i, &b_i); \
+        simsimd_##accumulator_type##_t diff_i = a_i - b_i; \
        for (simsimd_size_t j = 0; j != n; ++j) { \
-            simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \
-            simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
+            load_and_convert(a + j, &a_j); \
+            load_and_convert(b + j, &b_j); \
+            simsimd_##accumulator_type##_t diff_j = a_j - b_j; \
+            load_and_convert(c + i * n + j, &c_ij); \
            partial += c_ij * diff_j; \
        } \
        sum += diff_i * partial; \
    } \
    *result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \
}

-SIMSIMD_MAKE_BILINEAR(serial, f64, f64, SIMSIMD_DEREFERENCE)         // simsimd_bilinear_f64_serial
-SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, SIMSIMD_DEREFERENCE)      // simsimd_mahalanobis_f64_serial
+SIMSIMD_MAKE_BILINEAR(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2)      // simsimd_bilinear_f64_serial
+SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2)   // simsimd_mahalanobis_f64_serial

-SIMSIMD_MAKE_BILINEAR(serial, f32, f32, SIMSIMD_DEREFERENCE)         // simsimd_bilinear_f32_serial
-SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, SIMSIMD_DEREFERENCE)      // simsimd_mahalanobis_f32_serial
+SIMSIMD_MAKE_BILINEAR(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2)      // simsimd_bilinear_f32_serial
+SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2)   // simsimd_mahalanobis_f32_serial

-SIMSIMD_MAKE_BILINEAR(serial, f16, f32, SIMSIMD_F16_TO_F32)          // simsimd_bilinear_f16_serial
-SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, SIMSIMD_F16_TO_F32)       // simsimd_mahalanobis_f16_serial
+SIMSIMD_MAKE_BILINEAR(serial, f16, f32, simsimd_f16_to_f32)          // simsimd_bilinear_f16_serial
+SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, simsimd_f16_to_f32)       // simsimd_mahalanobis_f16_serial

-SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, SIMSIMD_BF16_TO_F32)        // simsimd_bilinear_bf16_serial
-SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, SIMSIMD_BF16_TO_F32)     // simsimd_mahalanobis_bf16_serial
+SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, simsimd_bf16_to_f32)        // simsimd_bilinear_bf16_serial
+SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, simsimd_bf16_to_f32)     // simsimd_mahalanobis_bf16_serial

-SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, SIMSIMD_DEREFERENCE)       // simsimd_bilinear_f32_accurate
-SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, SIMSIMD_DEREFERENCE)    // simsimd_mahalanobis_f32_accurate
+SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2)    // simsimd_bilinear_f32_accurate
+SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f32_accurate

-SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, SIMSIMD_F16_TO_F32)        // simsimd_bilinear_f16_accurate
-SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, SIMSIMD_F16_TO_F32)     // simsimd_mahalanobis_f16_accurate
+SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, _simsimd_f16_to_f64)       // simsimd_bilinear_f16_accurate
+SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, _simsimd_f16_to_f64)    // simsimd_mahalanobis_f16_accurate

-SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, SIMSIMD_BF16_TO_F32)      // simsimd_bilinear_bf16_accurate
-SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32)   // simsimd_mahalanobis_bf16_accurate
+SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, _simsimd_bf16_to_f64)     // simsimd_bilinear_bf16_accurate
+SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, _simsimd_bf16_to_f64)  // simsimd_mahalanobis_bf16_accurate

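// A minimal sketch of the converter contract the new instantiations above appear to assume:
// each `load_and_convert` argument now takes a source pointer and a destination pointer
// instead of returning a value. The definitions below are illustrative guesses, not
// verbatim library code:
//
//     // plain copy for types that need no widening (the f32/f64 paths)
//     #define _SIMSIMD_ASSIGN_1_TO_2(source_ptr, target_ptr) (*(target_ptr) = *(source_ptr))
//
//     // the widening converters follow the same two-argument shape
//     simsimd_f32_t x;
//     simsimd_bf16_to_f32(b + j, &x); // read one bf16 scalar, write it back widened to f32
//
// With that contract, `load_and_convert(a + i, &a_i)` inside the macros expands to one of
// these calls in the generated kernels, e.g. simsimd_bilinear_f32_serial.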
#if _SIMSIMD_TARGET_ARM
#if SIMSIMD_TARGET_NEON
@@ -313,7 +317,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
                                               simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) {
    float32x4_t sum_vec = vdupq_n_f32(0);
    for (simsimd_size_t i = 0; i != n; ++i) {
-        float32x4_t a_vec = vdupq_n_f32(simsimd_bf16_to_f32(a + i));
+        simsimd_f32_t a_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        float32x4_t a_vec = vdupq_n_f32(a_i);
        float32x4_t partial_sum_vec = vdupq_n_f32(0);
        for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
            bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j));
@@ -329,7 +335,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
    simsimd_size_t tail_start = n - tail_length;
    if (tail_length) {
        for (simsimd_size_t i = 0; i != n; ++i) {
-            simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
+            simsimd_f32_t a_i;
+            simsimd_bf16_to_f32(a + i, &a_i);
            bfloat16x8_t b_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
            bfloat16x8_t c_vec = _simsimd_partial_load_bf16x8_neon(c + i * n + tail_start, tail_length);
            simsimd_f32_t partial_sum = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), b_vec, c_vec));
@@ -345,8 +352,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
                                                  simsimd_distance_t *result) {
    float32x4_t sum_vec = vdupq_n_f32(0);
    for (simsimd_size_t i = 0; i != n; ++i) {
-        simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
-        simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
+        simsimd_f32_t a_i, b_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        simsimd_bf16_to_f32(b + i, &b_i);
        float32x4_t diff_i_vec = vdupq_n_f32(a_i - b_i);
        float32x4_t partial_sum_vec = vdupq_n_f32(0);
        for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
@@ -376,8 +384,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
    simsimd_size_t tail_start = n - tail_length;
    if (tail_length) {
        for (simsimd_size_t i = 0; i != n; ++i) {
-            simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
-            simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
+            simsimd_f32_t a_i, b_i;
+            simsimd_bf16_to_f32(a + i, &a_i);
+            simsimd_bf16_to_f32(b + i, &b_i);
            simsimd_f32_t diff_i = a_i - b_i;
            bfloat16x8_t a_j_vec = _simsimd_partial_load_bf16x8_neon(a + tail_start, tail_length);
            bfloat16x8_t b_j_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
@@ -489,7 +498,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
    __m256 sum_vec = _mm256_setzero_ps();
    for (simsimd_size_t i = 0; i != n; ++i) {
        // The `simsimd_bf16_to_f32` is cheaper than `_simsimd_bf16x8_to_f32x8_haswell`
-        __m256 a_vec = _mm256_set1_ps(simsimd_bf16_to_f32(a + i));
+        simsimd_f32_t a_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        __m256 a_vec = _mm256_set1_ps(a_i);
        __m256 partial_sum_vec = _mm256_setzero_ps();
        for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
            __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j)));
@@ -505,7 +516,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
    simsimd_size_t tail_start = n - tail_length;
    if (tail_length) {
        for (simsimd_size_t i = 0; i != n; ++i) {
-            simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
+            simsimd_f32_t a_i;
+            simsimd_bf16_to_f32(a + i, &a_i);
            __m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell( //
                _simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length));
            __m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell( //
@@ -523,9 +535,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
                                                     simsimd_distance_t *result) {
    __m256 sum_vec = _mm256_setzero_ps();
    for (simsimd_size_t i = 0; i != n; ++i) {
-        __m256 diff_i_vec = _mm256_sub_ps(              //
-            _mm256_set1_ps(simsimd_bf16_to_f32(a + i)), //
-            _mm256_set1_ps(simsimd_bf16_to_f32(b + i)));
+        simsimd_f32_t a_i, b_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        simsimd_bf16_to_f32(b + i, &b_i);
+        __m256 diff_i_vec = _mm256_set1_ps(a_i - b_i);
        __m256 partial_sum_vec = _mm256_setzero_ps();
        for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
            __m256 diff_j_vec = _mm256_sub_ps( //
@@ -543,7 +556,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
    simsimd_size_t tail_start = n - tail_length;
    if (tail_length) {
        for (simsimd_size_t i = 0; i != n; ++i) {
-            simsimd_f32_t diff_i = simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i);
+            simsimd_f32_t a_i, b_i;
+            simsimd_bf16_to_f32(a + i, &a_i);
+            simsimd_bf16_to_f32(b + i, &b_i);
+            simsimd_f32_t diff_i = a_i - b_i;
            __m256 diff_j_vec = _mm256_sub_ps( //
                _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(a + tail_start, tail_length)),
                _simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length)));
@@ -651,7 +667,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd
    __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);

    for (simsimd_size_t i = 0; i != n; ++i) {
-        __m512 a_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i));
+        simsimd_f32_t a_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        __m512 a_vec = _mm512_set1_ps(a_i);
        __m512 partial_sum_vec = _mm512_setzero_ps();
        __m512i b_vec, c_vec;
        simsimd_size_t j = 0;
@@ -683,7 +701,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims
    __mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);

    for (simsimd_size_t i = 0; i != n; ++i) {
-        __m512 diff_i_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i));
+        simsimd_f32_t a_i, b_i;
+        simsimd_bf16_to_f32(a + i, &a_i);
+        simsimd_bf16_to_f32(b + i, &b_i);
+        __m512 diff_i_vec = _mm512_set1_ps(a_i - b_i);
        __m512 partial_sum_vec = _mm512_setzero_ps();
        __m512i a_j_vec, b_j_vec, diff_j_vec, c_vec;
        simsimd_size_t j = 0;
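// For orientation, a usage sketch of one of the kernels generated above; only the name and
// signature come from the macro instantiations, the sample values are made up:
//
//     simsimd_f32_t a[2] = {1.0f, 2.0f};
//     simsimd_f32_t b[2] = {0.5f, -1.0f};
//     simsimd_f32_t c[4] = {1.0f, 0.0f, 0.0f, 1.0f}; // 2x2 metric tensor, row-major
//     simsimd_distance_t distance;
//     simsimd_bilinear_f32_serial(a, b, c, 2, &distance); // accumulates a^T * C * b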