Skip to content

Commit 9c44f04

Browse files
Barycentric Evaluation in Poly Ops
1 parent baac125 commit 9c44f04

File tree

4 files changed

+240
-219
lines changed

4 files changed

+240
-219
lines changed

crates/prover/src/core/backend/cpu/circle.rs

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
use num_traits::Zero;
1+
use itertools::Itertools;
2+
use num_traits::{One, Zero};
23

34
use super::CpuBackend;
45
use crate::core::backend::cpu::bit_reverse;
5-
use crate::core::circle::{CirclePoint, Coset};
6+
use crate::core::backend::Col;
7+
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset};
8+
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
69
use crate::core::fft::{butterfly, ibutterfly};
710
use crate::core::fields::m31::BaseField;
811
use crate::core::fields::qm31::SecureField;
912
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
10-
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, CirclePoly, PolyOps};
13+
use crate::core::poly::circle::{
14+
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
15+
};
1116
use crate::core::poly::twiddles::TwiddleTree;
1217
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
1318
use crate::core::poly::BitReversedOrder;
19+
use crate::core::utils::bit_reverse_index;
1420

1521
impl PolyOps for CpuBackend {
1622
type Twiddles = Vec<BaseField>;
@@ -86,6 +92,84 @@ impl PolyOps for CpuBackend {
8692
fold(&poly.coeffs, &mappings)
8793
}
8894

95+
fn weights(log_size: u32, sample_point: CirclePoint<SecureField>) -> Col<Self, SecureField> {
96+
let domain = CanonicCoset::new(log_size).circle_domain();
97+
98+
for i in 0..domain.size() {
99+
if domain.at(i).into_ef() == sample_point {
100+
let mut weights = vec![SecureField::zero(); domain.size()];
101+
weights[i] = SecureField::one();
102+
return weights;
103+
}
104+
}
105+
106+
let p_0 = domain.at(0).into_ef::<SecureField>();
107+
let weights_first_half = SecureField::one()
108+
/ (-(p_0.y + p_0.y)
109+
* coset_vanishing_derivative(
110+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
111+
p_0,
112+
));
113+
let p_0_inverse = domain.at(domain.half_coset.size()).into_ef::<SecureField>();
114+
let weights_second_half = SecureField::one()
115+
/ (-(p_0_inverse.y + p_0_inverse.y)
116+
* coset_vanishing_derivative(
117+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
118+
p_0_inverse,
119+
));
120+
121+
let domain_points_vanishing_evaluated_at_point = (0..domain.size())
122+
.map(|i| {
123+
point_vanishing(
124+
domain.at(i).into_ef::<SecureField>(),
125+
sample_point.into_ef::<SecureField>(),
126+
)
127+
})
128+
.collect_vec();
129+
let mut inversed_domain_points_vanishing_evaluated_at_point =
130+
vec![unsafe { std::mem::zeroed() }; domain.size()];
131+
132+
batch_inverse_in_place(
133+
&domain_points_vanishing_evaluated_at_point,
134+
&mut inversed_domain_points_vanishing_evaluated_at_point,
135+
);
136+
137+
let coset_vanishing_evaluated_at_point: SecureField = coset_vanishing(
138+
CanonicCoset::new(domain.log_size()).coset,
139+
sample_point.into_ef::<SecureField>(),
140+
);
141+
142+
(0..domain.size())
143+
.map(|i| {
144+
if i < domain.half_coset.size() {
145+
weights_first_half
146+
* inversed_domain_points_vanishing_evaluated_at_point[i]
147+
* coset_vanishing_evaluated_at_point
148+
} else {
149+
weights_second_half
150+
* inversed_domain_points_vanishing_evaluated_at_point[i]
151+
* coset_vanishing_evaluated_at_point
152+
}
153+
})
154+
.collect_vec()
155+
}
156+
157+
fn barycentric_eval_at_point(
158+
evals: &CircleEvaluation<Self, BaseField, BitReversedOrder>,
159+
point: CirclePoint<SecureField>,
160+
weights: &Col<Self, SecureField>,
161+
) -> SecureField {
162+
for i in 0..evals.domain.size() {
163+
if point == evals.domain.at(i).into_ef() {
164+
return evals.values[bit_reverse_index(i, evals.domain.log_size())].into();
165+
}
166+
}
167+
168+
(0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
169+
acc + (evals.values[bit_reverse_index(i, evals.domain.log_size())] * weights[i])
170+
})
171+
}
172+
89173
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
90174
assert!(log_size >= poly.log_size());
91175
let mut coeffs = Vec::with_capacity(1 << log_size);

crates/prover/src/core/backend/simd/circle.rs

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::mem::transmute;
33
use std::simd::Simd;
44

55
use bytemuck::Zeroable;
6+
use itertools::Itertools;
7+
use num_traits::{One, Zero};
68
#[cfg(feature = "parallel")]
79
use rayon::prelude::*;
810

@@ -14,10 +16,11 @@ use crate::core::backend::cpu::circle::slow_precompute_twiddles;
1416
use crate::core::backend::simd::column::BaseColumn;
1517
use crate::core::backend::simd::m31::PackedM31;
1618
use crate::core::backend::{Col, Column, CpuBackend};
17-
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
19+
use crate::core::circle::{CirclePoint, CirclePointIndex, Coset, M31_CIRCLE_LOG_ORDER};
20+
use crate::core::constraints::{coset_vanishing, coset_vanishing_derivative, point_vanishing};
1821
use crate::core::fields::m31::BaseField;
1922
use crate::core::fields::qm31::SecureField;
20-
use crate::core::fields::{Field, FieldExpOps};
23+
use crate::core::fields::{batch_inverse_in_place, Field, FieldExpOps};
2124
use crate::core::poly::circle::{
2225
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
2326
};
@@ -221,6 +224,117 @@ impl PolyOps for SimdBackend {
221224
(sum * twiddle_lows).pointwise_sum()
222225
}
223226

227+
fn weights(log_size: u32, sample_point: CirclePoint<SecureField>) -> Col<Self, SecureField> {
228+
let domain = CanonicCoset::new(log_size).circle_domain();
229+
let weights_vec_len = domain.size().div_ceil(N_LANES);
230+
231+
for i in 0..domain.size() {
232+
if domain.at(i).into_ef() == sample_point {
233+
let mut weights = Col::<Self, SecureField>::zeros(domain.size());
234+
weights.set(i, SecureField::one());
235+
return weights;
236+
}
237+
}
238+
239+
let p_0 = domain.at(0).into_ef::<SecureField>();
240+
let weights_first_half = SecureField::one()
241+
/ (-(p_0.y + p_0.y)
242+
* coset_vanishing_derivative(
243+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
244+
p_0,
245+
));
246+
let p_0_inverse = domain.at(domain.half_coset.size()).into_ef::<SecureField>();
247+
let weights_second_half = SecureField::one()
248+
/ (-(p_0_inverse.y + p_0_inverse.y)
249+
* coset_vanishing_derivative(
250+
Coset::new(CirclePointIndex::generator(), domain.log_size()),
251+
p_0_inverse,
252+
));
253+
254+
let domain_points_vanishing_evaluated_at_point = (0..weights_vec_len)
255+
.map(|i| {
256+
PackedSecureField::from_array(std::array::from_fn(|j| {
257+
if domain.size() <= i * N_LANES + j {
258+
SecureField::one()
259+
} else {
260+
point_vanishing(
261+
domain.at(i * N_LANES + j).into_ef::<SecureField>(),
262+
sample_point.into_ef::<SecureField>(),
263+
)
264+
}
265+
}))
266+
})
267+
.collect_vec();
268+
let mut inversed_domain_points_vanishing_evaluated_at_point =
269+
vec![unsafe { std::mem::zeroed() }; weights_vec_len];
270+
271+
batch_inverse_in_place(
272+
&domain_points_vanishing_evaluated_at_point,
273+
&mut inversed_domain_points_vanishing_evaluated_at_point,
274+
);
275+
276+
let coset_vanishing_evaluated_at_point: SecureField = coset_vanishing(
277+
CanonicCoset::new(domain.log_size()).coset,
278+
sample_point.into_ef::<SecureField>(),
279+
);
280+
281+
if weights_vec_len == 1 {
282+
return (0..N_LANES)
283+
.map(|i| {
284+
let inversed_domain_points_vanishing_evaluated_at_point =
285+
inversed_domain_points_vanishing_evaluated_at_point[0].to_array();
286+
if i < domain.size() / 2 {
287+
inversed_domain_points_vanishing_evaluated_at_point[i]
288+
* (weights_first_half * coset_vanishing_evaluated_at_point)
289+
} else {
290+
inversed_domain_points_vanishing_evaluated_at_point[i]
291+
* (weights_second_half * coset_vanishing_evaluated_at_point)
292+
}
293+
})
294+
.collect();
295+
}
296+
297+
let weights: Col<Self, SecureField> = (0..weights_vec_len)
298+
.map(|i| {
299+
if i < weights_vec_len / 2 {
300+
inversed_domain_points_vanishing_evaluated_at_point[i]
301+
* (weights_first_half * coset_vanishing_evaluated_at_point)
302+
} else {
303+
inversed_domain_points_vanishing_evaluated_at_point[i]
304+
* (weights_second_half * coset_vanishing_evaluated_at_point)
305+
}
306+
})
307+
.collect();
308+
309+
weights
310+
}
311+
312+
fn barycentric_eval_at_point(
313+
evals: &CircleEvaluation<Self, BaseField, BitReversedOrder>,
314+
point: CirclePoint<SecureField>,
315+
weights: &Col<Self, SecureField>,
316+
) -> SecureField {
317+
for i in 0..evals.domain.size() {
318+
if point == evals.domain.at(i).into_ef() {
319+
return evals
320+
.values
321+
.at(bit_reverse_index(i, evals.domain.log_size()))
322+
.into();
323+
}
324+
}
325+
let evals = evals.clone().bit_reverse();
326+
if evals.domain.size() < N_LANES {
327+
return (0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
328+
acc + (weights.at(i) * evals.values.at(i))
329+
});
330+
}
331+
(0..evals.domain.size().div_ceil(N_LANES))
332+
.fold(PackedSecureField::zero(), |acc, i| {
333+
acc + (weights.data[i] * evals.values.data[i])
334+
})
335+
.pointwise_sum()
336+
}
337+
224338
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {
225339
// TODO(shahars): Get rid of extends.
226340
poly.evaluate(CanonicCoset::new(log_size).circle_domain())

0 commit comments

Comments
 (0)