Skip to content

Commit 18c41fd

Browse files
committed
Improve: Type-casting logic
1 parent a334e99 commit 18c41fd

File tree

11 files changed

+654
-534
lines changed

11 files changed

+654
-534
lines changed

include/simsimd/curved.h

Lines changed: 57 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -96,13 +96,13 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
9696
SIMSIMD_PUBLIC void simsimd_bilinear_##input_type##_##name( \
9797
simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
9898
simsimd_size_t n, simsimd_distance_t *result) { \
99-
simsimd_##accumulator_type##_t sum = 0; \
99+
simsimd_##accumulator_type##_t sum = 0, a_i, b_j, c_ij; \
100100
for (simsimd_size_t i = 0; i != n; ++i) { \
101101
simsimd_##accumulator_type##_t partial = 0; \
102-
simsimd_##accumulator_type##_t a_i = load_and_convert(a + i); \
102+
load_and_convert(a + i, &a_i); \
103103
for (simsimd_size_t j = 0; j != n; ++j) { \
104-
simsimd_##accumulator_type##_t b_j = load_and_convert(b + j); \
105-
simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
104+
load_and_convert(b + j, &b_j); \
105+
load_and_convert(c + i * n + j, &c_ij); \
106106
partial += c_ij * b_j; \
107107
} \
108108
sum += a_i * partial; \
@@ -114,40 +114,44 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_f16_sapphire(simsimd_f16_t const* a, sim
114114
SIMSIMD_PUBLIC void simsimd_mahalanobis_##input_type##_##name( \
115115
simsimd_##input_type##_t const *a, simsimd_##input_type##_t const *b, simsimd_##input_type##_t const *c, \
116116
simsimd_size_t n, simsimd_distance_t *result) { \
117-
simsimd_##accumulator_type##_t sum = 0; \
117+
simsimd_##accumulator_type##_t sum = 0, a_i, a_j, b_i, b_j, c_ij; \
118118
for (simsimd_size_t i = 0; i != n; ++i) { \
119119
simsimd_##accumulator_type##_t partial = 0; \
120-
simsimd_##accumulator_type##_t diff_i = load_and_convert(a + i) - load_and_convert(b + i); \
120+
load_and_convert(a + i, &a_i); \
121+
load_and_convert(b + i, &b_i); \
122+
simsimd_##accumulator_type##_t diff_i = a_i - b_i; \
121123
for (simsimd_size_t j = 0; j != n; ++j) { \
122-
simsimd_##accumulator_type##_t diff_j = load_and_convert(a + j) - load_and_convert(b + j); \
123-
simsimd_##accumulator_type##_t c_ij = load_and_convert(c + i * n + j); \
124+
load_and_convert(a + j, &a_j); \
125+
load_and_convert(b + j, &b_j); \
126+
simsimd_##accumulator_type##_t diff_j = a_j - b_j; \
127+
load_and_convert(c + i * n + j, &c_ij); \
124128
partial += c_ij * diff_j; \
125129
} \
126130
sum += diff_i * partial; \
127131
} \
128132
*result = (simsimd_distance_t)SIMSIMD_SQRT(sum); \
129133
}
130134

131-
SIMSIMD_MAKE_BILINEAR(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f64_serial
132-
SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f64_serial
135+
SIMSIMD_MAKE_BILINEAR(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f64_serial
136+
SIMSIMD_MAKE_MAHALANOBIS(serial, f64, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f64_serial
133137

134-
SIMSIMD_MAKE_BILINEAR(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f32_serial
135-
SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f32_serial
138+
SIMSIMD_MAKE_BILINEAR(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f32_serial
139+
SIMSIMD_MAKE_MAHALANOBIS(serial, f32, f32, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f32_serial
136140

137-
SIMSIMD_MAKE_BILINEAR(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_bilinear_f16_serial
138-
SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, SIMSIMD_F16_TO_F32) // simsimd_mahalanobis_f16_serial
141+
SIMSIMD_MAKE_BILINEAR(serial, f16, f32, simsimd_f16_to_f32) // simsimd_bilinear_f16_serial
142+
SIMSIMD_MAKE_MAHALANOBIS(serial, f16, f32, simsimd_f16_to_f32) // simsimd_mahalanobis_f16_serial
139143

140-
SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_bilinear_bf16_serial
141-
SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, SIMSIMD_BF16_TO_F32) // simsimd_mahalanobis_bf16_serial
144+
SIMSIMD_MAKE_BILINEAR(serial, bf16, f32, simsimd_bf16_to_f32) // simsimd_bilinear_bf16_serial
145+
SIMSIMD_MAKE_MAHALANOBIS(serial, bf16, f32, simsimd_bf16_to_f32) // simsimd_mahalanobis_bf16_serial
142146

143-
SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_bilinear_f32_accurate
144-
SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, SIMSIMD_DEREFERENCE) // simsimd_mahalanobis_f32_accurate
147+
SIMSIMD_MAKE_BILINEAR(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_bilinear_f32_accurate
148+
SIMSIMD_MAKE_MAHALANOBIS(accurate, f32, f64, _SIMSIMD_ASSIGN_1_TO_2) // simsimd_mahalanobis_f32_accurate
145149

146-
SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_bilinear_f16_accurate
147-
SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, SIMSIMD_F16_TO_F32) // simsimd_mahalanobis_f16_accurate
150+
SIMSIMD_MAKE_BILINEAR(accurate, f16, f64, _simsimd_f16_to_f64) // simsimd_bilinear_f16_accurate
151+
SIMSIMD_MAKE_MAHALANOBIS(accurate, f16, f64, _simsimd_f16_to_f64) // simsimd_mahalanobis_f16_accurate
148152

149-
SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_bilinear_bf16_accurate
150-
SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, SIMSIMD_BF16_TO_F32) // simsimd_mahalanobis_bf16_accurate
153+
SIMSIMD_MAKE_BILINEAR(accurate, bf16, f64, _simsimd_bf16_to_f64) // simsimd_bilinear_bf16_accurate
154+
SIMSIMD_MAKE_MAHALANOBIS(accurate, bf16, f64, _simsimd_bf16_to_f64) // simsimd_mahalanobis_bf16_accurate
151155

152156
#if _SIMSIMD_TARGET_ARM
153157
#if SIMSIMD_TARGET_NEON
@@ -313,7 +317,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
313317
simsimd_bf16_t const *c, simsimd_size_t n, simsimd_distance_t *result) {
314318
float32x4_t sum_vec = vdupq_n_f32(0);
315319
for (simsimd_size_t i = 0; i != n; ++i) {
316-
float32x4_t a_vec = vdupq_n_f32(simsimd_bf16_to_f32(a + i));
320+
simsimd_f32_t a_i;
321+
simsimd_bf16_to_f32(a + i, &a_i);
322+
float32x4_t a_vec = vdupq_n_f32(a_i);
317323
float32x4_t partial_sum_vec = vdupq_n_f32(0);
318324
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
319325
bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const *)(b + j));
@@ -329,7 +335,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_neon(simsimd_bf16_t const *a, simsimd_
329335
simsimd_size_t tail_start = n - tail_length;
330336
if (tail_length) {
331337
for (simsimd_size_t i = 0; i != n; ++i) {
332-
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
338+
simsimd_f32_t a_i;
339+
simsimd_bf16_to_f32(a + i, &a_i);
333340
bfloat16x8_t b_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
334341
bfloat16x8_t c_vec = _simsimd_partial_load_bf16x8_neon(c + i * n + tail_start, tail_length);
335342
simsimd_f32_t partial_sum = vaddvq_f32(vbfdotq_f32(vdupq_n_f32(0), b_vec, c_vec));
@@ -345,8 +352,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
345352
simsimd_distance_t *result) {
346353
float32x4_t sum_vec = vdupq_n_f32(0);
347354
for (simsimd_size_t i = 0; i != n; ++i) {
348-
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
349-
simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
355+
simsimd_f32_t a_i, b_i;
356+
simsimd_bf16_to_f32(a + i, &a_i);
357+
simsimd_bf16_to_f32(b + i, &b_i);
350358
float32x4_t diff_i_vec = vdupq_n_f32(a_i - b_i);
351359
float32x4_t partial_sum_vec = vdupq_n_f32(0);
352360
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
@@ -376,8 +384,9 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_neon(simsimd_bf16_t const *a, simsi
376384
simsimd_size_t tail_start = n - tail_length;
377385
if (tail_length) {
378386
for (simsimd_size_t i = 0; i != n; ++i) {
379-
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
380-
simsimd_f32_t b_i = simsimd_bf16_to_f32(b + i);
387+
simsimd_f32_t a_i, b_i;
388+
simsimd_bf16_to_f32(a + i, &a_i);
389+
simsimd_bf16_to_f32(b + i, &b_i);
381390
simsimd_f32_t diff_i = a_i - b_i;
382391
bfloat16x8_t a_j_vec = _simsimd_partial_load_bf16x8_neon(a + tail_start, tail_length);
383392
bfloat16x8_t b_j_vec = _simsimd_partial_load_bf16x8_neon(b + tail_start, tail_length);
@@ -489,7 +498,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
489498
__m256 sum_vec = _mm256_setzero_ps();
490499
for (simsimd_size_t i = 0; i != n; ++i) {
491500
// The `simsimd_bf16_to_f32` is cheaper than `_simsimd_bf16x8_to_f32x8_haswell`
492-
__m256 a_vec = _mm256_set1_ps(simsimd_bf16_to_f32(a + i));
501+
simsimd_f32_t a_i;
502+
simsimd_bf16_to_f32(a + i, &a_i);
503+
__m256 a_vec = _mm256_set1_ps(a_i);
493504
__m256 partial_sum_vec = _mm256_setzero_ps();
494505
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
495506
__m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell(_mm_lddqu_si128((__m128i const *)(b + j)));
@@ -505,7 +516,8 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_haswell(simsimd_bf16_t const *a, simsi
505516
simsimd_size_t tail_start = n - tail_length;
506517
if (tail_length) {
507518
for (simsimd_size_t i = 0; i != n; ++i) {
508-
simsimd_f32_t a_i = simsimd_bf16_to_f32(a + i);
519+
simsimd_f32_t a_i;
520+
simsimd_bf16_to_f32(a + i, &a_i);
509521
__m256 b_vec = _simsimd_bf16x8_to_f32x8_haswell( //
510522
_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length));
511523
__m256 c_vec = _simsimd_bf16x8_to_f32x8_haswell( //
@@ -523,9 +535,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
523535
simsimd_distance_t *result) {
524536
__m256 sum_vec = _mm256_setzero_ps();
525537
for (simsimd_size_t i = 0; i != n; ++i) {
526-
__m256 diff_i_vec = _mm256_sub_ps( //
527-
_mm256_set1_ps(simsimd_bf16_to_f32(a + i)), //
528-
_mm256_set1_ps(simsimd_bf16_to_f32(b + i)));
538+
simsimd_f32_t a_i, b_i;
539+
simsimd_bf16_to_f32(a + i, &a_i);
540+
simsimd_bf16_to_f32(b + i, &b_i);
541+
__m256 diff_i_vec = _mm256_set1_ps(a_i - b_i);
529542
__m256 partial_sum_vec = _mm256_setzero_ps();
530543
for (simsimd_size_t j = 0; j + 8 <= n; j += 8) {
531544
__m256 diff_j_vec = _mm256_sub_ps( //
@@ -543,7 +556,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_haswell(simsimd_bf16_t const *a, si
543556
simsimd_size_t tail_start = n - tail_length;
544557
if (tail_length) {
545558
for (simsimd_size_t i = 0; i != n; ++i) {
546-
simsimd_f32_t diff_i = simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i);
559+
simsimd_f32_t a_i, b_i;
560+
simsimd_bf16_to_f32(a + i, &a_i);
561+
simsimd_bf16_to_f32(b + i, &b_i);
562+
simsimd_f32_t diff_i = a_i - b_i;
547563
__m256 diff_j_vec = _mm256_sub_ps( //
548564
_simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(a + tail_start, tail_length)),
549565
_simsimd_bf16x8_to_f32x8_haswell(_simsimd_partial_load_bf16x8_haswell(b + tail_start, tail_length)));
@@ -651,7 +667,9 @@ SIMSIMD_PUBLIC void simsimd_bilinear_bf16_genoa(simsimd_bf16_t const *a, simsimd
651667
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
652668

653669
for (simsimd_size_t i = 0; i != n; ++i) {
654-
__m512 a_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i));
670+
simsimd_f32_t a_i;
671+
simsimd_bf16_to_f32(a + i, &a_i);
672+
__m512 a_vec = _mm512_set1_ps(a_i);
655673
__m512 partial_sum_vec = _mm512_setzero_ps();
656674
__m512i b_vec, c_vec;
657675
simsimd_size_t j = 0;
@@ -683,7 +701,10 @@ SIMSIMD_PUBLIC void simsimd_mahalanobis_bf16_genoa(simsimd_bf16_t const *a, sims
683701
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, tail_length);
684702

685703
for (simsimd_size_t i = 0; i != n; ++i) {
686-
__m512 diff_i_vec = _mm512_set1_ps(simsimd_bf16_to_f32(a + i) - simsimd_bf16_to_f32(b + i));
704+
simsimd_f32_t a_i, b_i;
705+
simsimd_bf16_to_f32(a + i, &a_i);
706+
simsimd_bf16_to_f32(b + i, &b_i);
707+
__m512 diff_i_vec = _mm512_set1_ps(a_i - b_i);
687708
__m512 partial_sum_vec = _mm512_setzero_ps();
688709
__m512i a_j_vec, b_j_vec, diff_j_vec, c_vec;
689710
simsimd_size_t j = 0;

0 commit comments

Comments (0)