Skip to content

Commit bdbf71c

Browse files
SIMD Barycentric Evaluation and Test
1 parent 8a83a9f commit bdbf71c

File tree

1 file changed

+164
-4
lines changed

1 file changed

+164
-4
lines changed

crates/prover/src/core/poly/circle/evaluation.rs

Lines changed: 164 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,126 @@ fn barycentric_eval_at_point(
233233
})
234234
}
235235

236+
fn simd_weights(log_size: u32, sample_point: CirclePoint<SecureField>) -> Col<Self, SecureField> {
237+
let domain = CanonicCoset::new(log_size).circle_domain();
238+
let weights_vec_len = domain.size().div_ceil(N_LANES);
239+
240+
for i in 0..domain.size() {
241+
if domain.at(i).into_ef() == sample_point {
242+
let mut weights = Col::<Self, SecureField>::zeros(domain.size());
243+
weights.set(i, SecureField::one());
244+
return weights;
245+
}
246+
}
247+
248+
let p_0 = domain.at(0).into_ef::<SecureField>();
249+
let weights_first_half = SecureField::one()
250+
/ (-(p_0.y + p_0.y)
251+
* coset_vanishing_derivative(
252+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
253+
p_0,
254+
));
255+
let p_0_inverse = domain.at(domain.half_coset.size()).into_ef::<SecureField>();
256+
let weights_second_half = SecureField::one()
257+
/ (-(p_0_inverse.y + p_0_inverse.y)
258+
* coset_vanishing_derivative(
259+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
260+
p_0_inverse,
261+
));
262+
263+
let domain_points_vanishing_evaluated_at_point = (0..weights_vec_len)
264+
.map(|i| {
265+
PackedSecureField::from_array(std::array::from_fn(|j| {
266+
if domain.size() <= i * N_LANES + j {
267+
SecureField::one()
268+
} else {
269+
point_vanishing(
270+
domain.at(i * N_LANES + j).into_ef::<SecureField>(),
271+
sample_point.into_ef::<SecureField>(),
272+
)
273+
}
274+
}))
275+
})
276+
.collect_vec();
277+
let mut inversed_domain_points_vanishing_evaluated_at_point =
278+
vec![unsafe { std::mem::zeroed() }; weights_vec_len];
279+
280+
batch_inverse_in_place(
281+
&domain_points_vanishing_evaluated_at_point,
282+
&mut inversed_domain_points_vanishing_evaluated_at_point,
283+
);
284+
285+
let coset_vanishing_evaluated_at_point: SecureField = coset_vanishing(
286+
CanonicCoset::new(domain.log_size()).coset,
287+
sample_point.into_ef::<SecureField>(),
288+
);
289+
290+
if weights_vec_len == 1 {
291+
return (0..N_LANES)
292+
.map(|i| {
293+
let inversed_domain_points_vanishing_evaluated_at_point =
294+
inversed_domain_points_vanishing_evaluated_at_point[0].to_array();
295+
if i < domain.size() / 2 {
296+
inversed_domain_points_vanishing_evaluated_at_point[i]
297+
* (weights_first_half * coset_vanishing_evaluated_at_point)
298+
} else {
299+
inversed_domain_points_vanishing_evaluated_at_point[i]
300+
* (weights_second_half * coset_vanishing_evaluated_at_point)
301+
}
302+
})
303+
.collect();
304+
}
305+
306+
let weights: Col<Self, SecureField> = (0..weights_vec_len)
307+
.map(|i| {
308+
if i < weights_vec_len / 2 {
309+
inversed_domain_points_vanishing_evaluated_at_point[i]
310+
* (weights_first_half * coset_vanishing_evaluated_at_point)
311+
} else {
312+
inversed_domain_points_vanishing_evaluated_at_point[i]
313+
* (weights_second_half * coset_vanishing_evaluated_at_point)
314+
}
315+
})
316+
.collect();
317+
318+
weights
319+
}
320+
321+
fn simd_barycentric_eval_at_point(
322+
evals: &CircleEvaluation<Self, BaseField, BitReversedOrder>,
323+
point: CirclePoint<SecureField>,
324+
weights: &Col<Self, SecureField>,
325+
) -> SecureField {
326+
for i in 0..evals.domain.size() {
327+
if point == evals.domain.at(i).into_ef() {
328+
return evals
329+
.values
330+
.at(bit_reverse_index(i, evals.domain.log_size()))
331+
.into();
332+
}
333+
}
334+
let evals = evals.clone().bit_reverse();
335+
if evals.domain.size() < N_LANES {
336+
return (0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
337+
acc + (weights.at(i) * evals.values.at(i))
338+
});
339+
}
340+
(0..evals.domain.size().div_ceil(N_LANES))
341+
.fold(PackedSecureField::zero(), |acc, i| {
342+
acc + (weights.data[i] * evals.values.data[i])
343+
})
344+
.pointwise_sum()
345+
}
346+
236347
#[cfg(test)]
237348
mod tests {
238349
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
350+
use crate::core::backend::simd::column::BaseColumn;
351+
use crate::core::backend::simd::SimdBackend;
239352
use crate::core::backend::CpuBackend;
240353
use crate::core::circle::{CirclePoint, Coset};
241354
use crate::core::fields::m31::BaseField;
242-
use crate::core::poly::circle::CanonicCoset;
355+
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
243356
use crate::core::poly::{BitReversedOrder, NaturalOrder};
244357
use crate::m31;
245358

@@ -285,7 +398,7 @@ mod tests {
285398
}
286399

287400
#[test]
288-
fn test_barycentric_evaluation() {
401+
fn test_cpu_barycentric_evaluation() {
289402
let poly = CpuCirclePoly::new(
290403
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
291404
.map(BaseField::from)
@@ -310,9 +423,56 @@ mod tests {
310423
let sampled_barycentric_values = sampled_points
311424
.iter()
312425
.map(|point| {
313-
eval.barycentric_eval_at_point(
426+
super::barycentric_eval_at_point(
427+
eval,
428+
*point,
429+
&super::weights(
430+
eval.domain.log_size(),
431+
*point,
432+
),
433+
)
434+
})
435+
.collect::<Vec<_>>();
436+
437+
assert_eq!(
438+
sampled_barycentric_values, sampled_values,
439+
"Barycentric evaluation should be equal to the polynomial evaluation"
440+
);
441+
}
442+
443+
#[test]
444+
fn test_simd_barycentric_evaluation() {
445+
let poly = CirclePoly::<SimdBackend>::new(BaseColumn::from_cpu(
446+
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
447+
.map(BaseField::from)
448+
.to_vec(),
449+
));
450+
let s = CanonicCoset::new(3);
451+
let domain = s.circle_domain();
452+
let eval = poly.evaluate(domain);
453+
let sampled_points = [
454+
CirclePoint::get_point(348),
455+
CirclePoint::get_point(9736524),
456+
CirclePoint::get_point(13),
457+
CirclePoint::get_point(346752),
458+
domain.at(0).into_ef(),
459+
domain.at(3).into_ef(),
460+
];
461+
let sampled_values = sampled_points
462+
.iter()
463+
.map(|point| poly.eval_at_point(*point))
464+
.collect::<Vec<_>>();
465+
466+
let sampled_barycentric_values = sampled_points
467+
.iter()
468+
.map(|point| {
469+
super::simd_barycentric_eval_at_point(
470+
eval,
314471
*point,
315-
&super::weights(eval.domain.log_size(), *point),
472+
&super::simd_weights(
473+
eval.domain.log_size(),
474+
*point,
475+
),
316476
)
317477
})
318478
.collect::<Vec<_>>();

0 commit comments

Comments
 (0)