Skip to content

Commit faab86f

Browse files
SIMD Barycentric Evaluation and Test
1 parent ad1902f commit faab86f

File tree

1 file changed

+142
-4
lines changed

1 file changed

+142
-4
lines changed

crates/prover/src/core/poly/circle/evaluation.rs

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ use std::ops::{Deref, Index};
33

44
use educe::Educe;
55
use itertools::Itertools;
6-
use num_traits::Zero;
6+
use num_traits::{One, Zero};
77

88
use super::{CircleDomain, CirclePoly, PolyOps};
99
use crate::core::backend::cpu::CpuCircleEvaluation;
10+
use crate::core::backend::simd::m31::N_LANES;
11+
use crate::core::backend::simd::qm31::PackedSecureField;
1012
use crate::core::backend::simd::SimdBackend;
1113
use crate::core::backend::{Col, Column, ColumnOps, CpuBackend};
1214
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
1315
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
1416
use crate::core::fields::m31::BaseField;
1517
use crate::core::fields::qm31::SecureField;
16-
use crate::core::fields::ExtensionOf;
18+
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
1719
use crate::core::poly::circle::CanonicCoset;
1820
use crate::core::poly::twiddles::TwiddleTree;
1921
use crate::core::poly::{BitReversedOrder, NaturalOrder};
@@ -229,13 +231,115 @@ fn barycentric_eval_at_point(
229231
})
230232
}
231233

234+
// TODO(Gali): Remove.
235+
#[allow(dead_code)]
236+
/// Computes the weights for Barycentric Lagrange interpolation for point `p` on `coset`.
237+
/// `p` must not be in the domain. For more information, see [`barycentric_weights`].
238+
fn simd_barycentric_weights(
239+
coset: CanonicCoset,
240+
p: CirclePoint<SecureField>,
241+
) -> Col<SimdBackend, SecureField> {
242+
let domain = coset.circle_domain();
243+
let weights_vec_len = domain.size().div_ceil(N_LANES);
244+
245+
// S_i(i) is invariant under G_(n−1) and alternate under J, so we can calculate only 2 values
246+
let p_first_half_coset = domain.at(0).into_ef::<SecureField>();
247+
let si_i_first_half = SecureField::one()
248+
/ ((p_first_half_coset.y * SecureField::from(-2))
249+
* coset_vanishing_derivative(
250+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
251+
p_first_half_coset,
252+
));
253+
let p_second_half_coset = domain.at(domain.half_coset.size()).into_ef::<SecureField>();
254+
let si_i_second_half = SecureField::one()
255+
/ ((p_second_half_coset.y * SecureField::from(-2))
256+
* coset_vanishing_derivative(
257+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
258+
p_second_half_coset,
259+
));
260+
261+
// TODO(Gali): Optimize to a batched point_vanishing()
262+
let vi_p = (0..weights_vec_len)
263+
.map(|i| {
264+
PackedSecureField::from_array(std::array::from_fn(|j| {
265+
if domain.size() <= i * N_LANES + j {
266+
SecureField::one()
267+
} else {
268+
point_vanishing(
269+
domain.at(i * N_LANES + j).into_ef::<SecureField>(),
270+
p.into_ef::<SecureField>(),
271+
)
272+
}
273+
}))
274+
})
275+
.collect_vec();
276+
let mut vi_p_inverse = vec![unsafe { std::mem::zeroed() }; weights_vec_len];
277+
278+
batch_inverse_in_place(&vi_p, &mut vi_p_inverse);
279+
280+
let vn_p: SecureField = coset_vanishing(
281+
CanonicCoset::new(domain.log_size()).coset,
282+
p.into_ef::<SecureField>(),
283+
);
284+
285+
// TODO(Gali): Change weights order to bit-reverse order.
286+
if weights_vec_len == 1 {
287+
return (0..N_LANES)
288+
.map(|i| {
289+
let vi_p_inverse = vi_p_inverse[0].to_array();
290+
if i < domain.size() / 2 {
291+
vi_p_inverse[i] * (si_i_first_half * vn_p)
292+
} else {
293+
vi_p_inverse[i] * (si_i_second_half * vn_p)
294+
}
295+
})
296+
.collect();
297+
}
298+
299+
let weights: Col<SimdBackend, SecureField> = (0..weights_vec_len)
300+
.map(|i| {
301+
if i < weights_vec_len / 2 {
302+
vi_p_inverse[i] * (si_i_first_half * vn_p)
303+
} else {
304+
vi_p_inverse[i] * (si_i_second_half * vn_p)
305+
}
306+
})
307+
.collect();
308+
309+
weights
310+
}
311+
312+
// TODO(Gali): Remove.
313+
#[allow(dead_code)]
314+
/// Evaluates a polynomial at a point using the barycentric interpolation formula,
315+
/// given its evaluations on a circle domain and precomputed barycentric weights for the domain
316+
/// at the sampled point. For more information, see [`barycentric_eval_at_point`]
317+
fn simd_barycentric_eval_at_point(
318+
evals: &CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>,
319+
weights: &Col<SimdBackend, SecureField>,
320+
) -> SecureField {
321+
let evals = evals.clone().bit_reverse();
322+
if evals.domain.size() < N_LANES {
323+
return (0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
324+
acc + (weights.at(i) * evals.values.at(i))
325+
});
326+
}
327+
(0..evals.domain.size().div_ceil(N_LANES))
328+
.fold(PackedSecureField::zero(), |acc, i| {
329+
acc + (weights.data[i] * evals.values.data[i])
330+
})
331+
.pointwise_sum()
332+
}
333+
232334
#[cfg(test)]
233335
mod tests {
234336
use super::*;
235337
use crate::core::backend::cpu::{CpuCircleEvaluation, CpuCirclePoly};
338+
use crate::core::backend::simd::column::BaseColumn;
339+
use crate::core::backend::simd::SimdBackend;
236340
use crate::core::circle::{CirclePoint, Coset};
237341
use crate::core::fields::m31::BaseField;
238-
use crate::core::poly::circle::CanonicCoset;
342+
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
239343
use crate::core::poly::NaturalOrder;
240344
use crate::m31;
241345

@@ -281,7 +385,7 @@ mod tests {
281385
}
282386

283387
#[test]
284-
fn test_barycentric_evaluation() {
388+
fn test_cpu_barycentric_evaluation() {
285389
let poly = CpuCirclePoly::new(
286390
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
287391
.map(BaseField::from)
@@ -311,4 +415,38 @@ mod tests {
311415
"Barycentric evaluation should be equal to the polynomial evaluation"
312416
);
313417
}
418+
419+
#[test]
420+
fn test_simd_barycentric_evaluation() {
421+
let poly = CirclePoly::<SimdBackend>::new(BaseColumn::from_cpu(
422+
[691, 805673, 5, 435684, 4832, 23876431, 197, 897346068]
423+
.map(BaseField::from)
424+
.to_vec(),
425+
));
426+
let s = CanonicCoset::new(10);
427+
let domain = s.circle_domain();
428+
let eval = poly.evaluate(domain);
429+
let sampled_points = [
430+
CirclePoint::get_point(348),
431+
CirclePoint::get_point(9736524),
432+
CirclePoint::get_point(13),
433+
CirclePoint::get_point(346752),
434+
];
435+
let sampled_values = sampled_points
436+
.iter()
437+
.map(|point| poly.eval_at_point(*point))
438+
.collect_vec();
439+
440+
let sampled_barycentric_values = sampled_points
441+
.iter()
442+
.map(|point| {
443+
simd_barycentric_eval_at_point(&eval, &simd_barycentric_weights(s, *point))
444+
})
445+
.collect_vec();
446+
447+
assert_eq!(
448+
sampled_barycentric_values, sampled_values,
449+
"Barycentric evaluation should be equal to the polynomial evaluation"
450+
);
451+
}
314452
}

0 commit comments

Comments
 (0)