Skip to content

Commit 98d1ff3

Browse files
Barycentric Evaluation in Poly Ops
1 parent 88e1c1c commit 98d1ff3

File tree

4 files changed

+211
-167
lines changed

4 files changed

+211
-167
lines changed

crates/prover/src/core/backend/cpu/circle.rs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1+
use itertools::Itertools;
12
use num_traits::Zero;
23

34
use super::CpuBackend;
45
use crate::core::backend::cpu::bit_reverse;
5-
use crate::core::circle::{CirclePoint, Coset};
6+
use crate::core::backend::Col;
7+
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
8+
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
69
use crate::core::fft::{butterfly, ibutterfly};
710
use crate::core::fields::m31::BaseField;
811
use crate::core::fields::qm31::SecureField;
912
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
10-
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
13+
use crate::core::poly::circle::{
14+
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
15+
};
1116
use crate::core::poly::twiddles::TwiddleTree;
1217
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
1318
use crate::core::poly::BitReversedOrder;
19+
use crate::core::utils::bit_reverse_index;
1420

1521
impl PolyOps for CpuBackend {
1622
type Twiddles = Vec<BaseField>;
@@ -86,6 +92,47 @@ impl PolyOps for CpuBackend {
8692
fold(&poly.coeffs, &mappings)
8793
}
8894

95+
fn barycentric_weights(
96+
coset: CanonicCoset,
97+
p: CirclePoint<SecureField>,
98+
) -> Col<CpuBackend, SecureField> {
99+
let domain = coset.circle_domain();
100+
101+
let (si_i, vi_p): (Vec<_>, Vec<_>) = (0..domain.size())
102+
.map(|i| {
103+
let coset_point = domain.at(i).into_ef::<SecureField>();
104+
let minus_two_coset_point_y = coset_point.y * SecureField::from(-2);
105+
(
106+
minus_two_coset_point_y
107+
* coset_vanishing_derivative(
108+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
109+
coset_point,
110+
),
111+
point_vanishing(coset_point, p.into_ef::<SecureField>()),
112+
)
113+
})
114+
.unzip();
115+
116+
let vn_p: SecureField = coset_vanishing(
117+
CanonicCoset::new(domain.log_size()).coset,
118+
p.into_ef::<SecureField>(),
119+
);
120+
121+
// TODO(Gali): Change weights order to bit-reverse order.
122+
(0..domain.size())
123+
.map(|i| vn_p / (si_i[i] * vi_p[i]))
124+
.collect_vec()
125+
}
126+
127+
fn barycentric_eval_at_point(
128+
evals: &CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>,
129+
weights: &Col<CpuBackend, SecureField>,
130+
) -> SecureField {
131+
(0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
132+
acc + (evals.values[bit_reverse_index(i, evals.domain.log_size())] * weights[i])
133+
})
134+
}
135+
89136
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
90137
assert!(log_size >= poly.log_size());
91138
let mut coeffs = Vec::with_capacity(1 << log_size);

crates/prover/src/core/backend/simd/circle.rs

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::mem::transmute;
33
use std::simd::Simd;
44

55
use bytemuck::Zeroable;
6+
use itertools::Itertools;
7+
use num_traits::{One, Zero};
68
#[cfg(feature = "parallel")]
79
use rayon::prelude::*;
810

@@ -14,10 +16,11 @@ use crate::core::backend::cpu::circle::slow_precompute_twiddles;
1416
use crate::core::backend::simd::column::BaseColumn;
1517
use crate::core::backend::simd::m31::PackedM31;
1618
use crate::core::backend::{Col, Column, CpuBackend};
17-
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
19+
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset, M31_CIRCLE_LOG_ORDER};
20+
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
1821
use crate::core::fields::m31::BaseField;
1922
use crate::core::fields::qm31::SecureField;
20-
use crate::core::fields::{Field, FieldExpOps};
23+
use crate::core::fields::{batch_inverse_in_place, Field, FieldExpOps};
2124
use crate::core::poly::circle::{
2225
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
2326
};
@@ -221,6 +224,84 @@ impl PolyOps for SimdBackend {
221224
(sum * twiddle_lows).pointwise_sum()
222225
}
223226

227+
fn barycentric_weights(
228+
coset: CanonicCoset,
229+
p: CirclePoint<SecureField>,
230+
) -> Col<SimdBackend, SecureField> {
231+
let domain = coset.circle_domain();
232+
let weights_vec_len = domain.size().div_ceil(N_LANES);
233+
if weights_vec_len == 1 {
234+
return Col::<SimdBackend, SecureField>::from_iter(CircleEvaluation::<
235+
CpuBackend,
236+
BaseField,
237+
BitReversedOrder,
238+
>::barycentric_weights(
239+
coset, p
240+
));
241+
}
242+
243+
let p_0 = domain.at(0).into_ef::<SecureField>();
244+
let si_0 = SecureField::one()
245+
/ ((p_0.y * SecureField::from(-2))
246+
* coset_vanishing_derivative(
247+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
248+
p_0,
249+
));
250+
251+
// TODO(Gali): Optimize to a batched point_vanishing()
252+
let vi_p = (0..weights_vec_len)
253+
.map(|i| {
254+
PackedSecureField::from_array(std::array::from_fn(|j| {
255+
point_vanishing(
256+
domain.at(i * N_LANES + j).into_ef::<SecureField>(),
257+
p.into_ef::<SecureField>(),
258+
)
259+
}))
260+
})
261+
.collect_vec();
262+
let mut vi_p_inverse = Vec::with_capacity(weights_vec_len);
263+
#[allow(clippy::uninit_vec)]
264+
unsafe {
265+
vi_p_inverse.set_len(weights_vec_len)
266+
};
267+
batch_inverse_in_place(&vi_p, &mut vi_p_inverse);
268+
269+
let vn_p: SecureField = coset_vanishing(
270+
CanonicCoset::new(domain.log_size()).coset,
271+
p.into_ef::<SecureField>(),
272+
);
273+
274+
let si_0_vn_p = PackedSecureField::broadcast(si_0 * vn_p);
275+
276+
// TODO(Gali): Change weights order to bit-reverse order.
277+
// S_i(i) is invariant under G_(n−1) and alternate under J, meaning the S_i(i) values are
278+
// the same for each half coset, and the second half coset values are the conjugate
279+
// of the first half coset values.
280+
let weights: Col<SimdBackend, SecureField> = (0..weights_vec_len)
281+
.map(|i| {
282+
if i < weights_vec_len / 2 {
283+
vi_p_inverse[i] * si_0_vn_p
284+
} else {
285+
vi_p_inverse[i] * -si_0_vn_p
286+
}
287+
})
288+
.collect();
289+
290+
weights
291+
}
292+
293+
fn barycentric_eval_at_point(
294+
evals: &CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>,
295+
weights: &Col<SimdBackend, SecureField>,
296+
) -> SecureField {
297+
let evals = evals.clone().bit_reverse();
298+
(0..evals.domain.size().div_ceil(N_LANES))
299+
.fold(PackedSecureField::zero(), |acc, i| {
300+
acc + (weights.data[i] * evals.values.data[i])
301+
})
302+
.pointwise_sum()
303+
}
304+
224305
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
225306
// TODO(shahars): Get rid of extends.
226307
poly.evaluate(CanonicCoset::new(log_size).circle_domain())

0 commit comments

Comments
 (0)