Skip to content

Commit 209a21d

Browse files
SIMD Barycentric Evaluation and Test
1 parent ad1902f commit 209a21d

File tree

1 file changed

+162
-4
lines changed

1 file changed

+162
-4
lines changed

crates/prover/src/core/poly/circle/evaluation.rs

Lines changed: 162 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ use std::ops::{Deref, Index};
33

44
use educe::Educe;
55
use itertools::Itertools;
6-
use num_traits::Zero;
6+
use num_traits::{One, Zero};
77

88
use super::{CircleDomain, CirclePoly, PolyOps};
99
use crate::core::backend::cpu::CpuCircleEvaluation;
10+
use crate::core::backend::simd::m31::N_LANES;
11+
use crate::core::backend::simd::qm31::PackedSecureField;
1012
use crate::core::backend::simd::SimdBackend;
1113
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
1214
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
1315
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
1416
use crate::core::fields::m31::BaseField;
1517
use crate::core::fields::qm31::SecureField;
16-
use crate::core::fields::ExtensionOf;
18+
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
1719
use crate::core::poly::circle::CanonicCoset;
1820
use crate::core::poly::twiddles::TwiddleTree;
1921
use crate::core::poly::{BitReversedOrder, NaturalOrder};
@@ -229,13 +231,108 @@ fn barycentric_eval_at_point(
229231
})
230232
}
231233

234+
// TODO(Gali): Remove.
235+
#[allow(dead_code)]
236+
/// Computes the weights for Barycentric Lagrange interpolation for point `p` on `coset`.
237+
/// `p` must not be in the domain. For more information, see [`barycentric_weights`].
238+
fn simd_barycentric_weights(
239+
coset: CanonicCoset,
240+
p: CirclePoint<SecureField>,
241+
) -> Col<SimdBackend, SecureField> {
242+
let domain = coset.circle_domain();
243+
let weights_vec_len = domain.size().div_ceil(N_LANES);
244+
245+
// S_i(i) is invariant under G_(n−1) and alternate under J, so we can calculate only 2 values
246+
let p_first_half_coset = domain.at(0).into_ef::<SecureField>();
247+
let si_i_first_half = SecureField::one()
248+
/ ((p_first_half_coset.y * SecureField::from(-2))
249+
* coset_vanishing_derivative(
250+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
251+
p_first_half_coset,
252+
));
253+
let si_i_second_half = -si_i_first_half;
254+
255+
// TODO(Gali): Optimize to a batched point_vanishing()
256+
let vi_p = (0..weights_vec_len)
257+
.map(|i| {
258+
PackedSecureField::from_array(std::array::from_fn(|j| {
259+
if domain.size() <= i * N_LANES + j {
260+
SecureField::one()
261+
} else {
262+
point_vanishing(
263+
domain.at(i * N_LANES + j).into_ef::<SecureField>(),
264+
p.into_ef::<SecureField>(),
265+
)
266+
}
267+
}))
268+
})
269+
.collect_vec();
270+
let mut vi_p_inverse = vec![unsafe { std::mem::zeroed() }; weights_vec_len];
271+
272+
batch_inverse_in_place(&vi_p, &mut vi_p_inverse);
273+
274+
let vn_p: SecureField = coset_vanishing(
275+
CanonicCoset::new(domain.log_size()).coset,
276+
p.into_ef::<SecureField>(),
277+
);
278+
279+
// TODO(Gali): Change weights order to bit-reverse order.
280+
if weights_vec_len == 1 {
281+
return (0..N_LANES)
282+
.map(|i| {
283+
if i >= domain.size() {
284+
SecureField::zero()
285+
} else {
286+
let vi_p_inverse = vi_p_inverse[0].to_array();
287+
if i < domain.size() / 2 {
288+
vi_p_inverse[i] * (si_i_first_half * vn_p)
289+
} else {
290+
vi_p_inverse[i] * (si_i_second_half * vn_p)
291+
}
292+
}
293+
})
294+
.collect();
295+
}
296+
297+
let weights: Col<SimdBackend, SecureField> = (0..weights_vec_len)
298+
.map(|i| {
299+
if i < weights_vec_len / 2 {
300+
vi_p_inverse[i] * (si_i_first_half * vn_p)
301+
} else {
302+
vi_p_inverse[i] * (si_i_second_half * vn_p)
303+
}
304+
})
305+
.collect();
306+
307+
weights
308+
}
309+
310+
// TODO(Gali): Remove.
311+
#[allow(dead_code)]
312+
/// Evaluates a polynomial at a point using the barycentric interpolation formula,
313+
/// given its evaluations on a circle domain and precomputed barycentric weights for the domain
314+
/// at the sampled point. For more information, see [`barycentric_eval_at_point`]
315+
fn simd_barycentric_eval_at_point(
316+
evals: &CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>,
317+
weights: &Col<SimdBackend, SecureField>,
318+
) -> SecureField {
319+
let evals = evals.clone().bit_reverse();
320+
(0..evals.domain.size().div_ceil(N_LANES))
321+
.fold(PackedSecureField::zero(), |acc, i| {
322+
acc + (weights.data[i] * evals.values.data[i])
323+
})
324+
.pointwise_sum()
325+
}
326+
232327
#[cfg(test)]
233328
mod tests {
234329
use super::*;
235330
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
331+
use crate::core::backend::simd::column::BaseColumn;
332+
use crate::core::backend::simd::SimdBackend;
236333
use crate::core::circle::{CirclePoint, Coset};
237334
use crate::core::fields::m31::BaseField;
238-
use crate::core::poly::circle::CanonicCoset;
335+
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
239336
use crate::core::poly::NaturalOrder;
240337
use crate::m31;
241338

@@ -281,7 +378,7 @@ mod tests {
281378
}
282379

283380
#[test]
284-
fn test_barycentric_evaluation() {
381+
fn test_cpu_barycentric_evaluation() {
285382
let poly = CpuCirclePoly::new(
286383
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
287384
.map(BaseField::from)
@@ -311,4 +408,65 @@ mod tests {
311408
"Barycentric evaluation should be equal to the polynomial evaluation"
312409
);
313410
}
411+
412+
#[test]
413+
fn test_simd_barycentric_evaluation() {
414+
let poly = CirclePoly::<SimdBackend>::new(BaseColumn::from_cpu(
415+
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
416+
.map(BaseField::from)
417+
.to_vec(),
418+
));
419+
let s = CanonicCoset::new(10);
420+
let domain = s.circle_domain();
421+
let eval = poly.evaluate(domain);
422+
let sampled_points = [
423+
CirclePoint::get_point(348),
424+
CirclePoint::get_point(9736524),
425+
CirclePoint::get_point(13),
426+
CirclePoint::get_point(346752),
427+
];
428+
let sampled_values = sampled_points
429+
.iter()
430+
.map(|point| poly.eval_at_point(*point))
431+
.collect_vec();
432+
433+
let sampled_barycentric_values = sampled_points
434+
.iter()
435+
.map(|point| {
436+
simd_barycentric_eval_at_point(&eval, &simd_barycentric_weights(s, *point))
437+
})
438+
.collect_vec();
439+
440+
assert_eq!(
441+
sampled_barycentric_values, sampled_values,
442+
"Barycentric evaluation should be equal to the polynomial evaluation"
443+
);
444+
}
445+
446+
#[test]
447+
fn test_simd_barycentric_weights() {
448+
let s = CanonicCoset::new(10);
449+
let sampled_points = [
450+
CirclePoint::get_point(348),
451+
CirclePoint::get_point(9736524),
452+
CirclePoint::get_point(13),
453+
CirclePoint::get_point(346752),
454+
];
455+
456+
let cpu_barycentric_values = sampled_points
457+
.iter()
458+
.map(|point| barycentric_weights(s, *point))
459+
.collect_vec();
460+
let simd_barycentric_values = sampled_points
461+
.iter()
462+
.map(|point| simd_barycentric_weights(s, *point))
463+
.collect_vec();
464+
465+
cpu_barycentric_values
466+
.iter()
467+
.zip(simd_barycentric_values.iter())
468+
.for_each(|(cpu_weights, simd_weights)| {
469+
assert_eq!(*cpu_weights, simd_weights.to_cpu());
470+
});
471+
}
314472
}

0 commit comments

Comments
 (0)