Skip to content

Commit 57745f8

Browse files
Parallel Barycentric and Weights
1 parent c716759 commit 57745f8

File tree

2 files changed

+131
-38
lines changed

2 files changed

+131
-38
lines changed

crates/prover/src/core/backend/cpu/circle.rs

Lines changed: 50 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,8 @@
1+
#[cfg(not(feature = "parallel"))]
12
use itertools::Itertools;
23
use num_traits::{One, Zero};
4+
#[cfg(feature = "parallel")]
5+
use rayon::prelude::*;
36

47
use super::CpuBackend;
58
use crate::core::backend::cpu::bit_reverse;
@@ -122,6 +125,7 @@ impl PolyOps for CpuBackend {
122125
));
123126

124127
// TODO(Gali): Optimize to a batched point_vanishing()
128+
#[cfg(not(feature = "parallel"))]
125129
let domain_points_vanishing_evaluated_at_point = (0..domain.size())
126130
.map(|i| {
127131
point_vanishing(
@@ -132,6 +136,20 @@ impl PolyOps for CpuBackend {
132136
)
133137
})
134138
.collect_vec();
139+
140+
#[cfg(feature = "parallel")]
141+
let domain_points_vanishing_evaluated_at_point: Vec<_> = (0..domain.size())
142+
.into_par_iter()
143+
.map(|i| {
144+
point_vanishing(
145+
domain
146+
.at(bit_reverse_index(i, log_size))
147+
.into_ef::<SecureField>(),
148+
sample_point.into_ef::<SecureField>(),
149+
)
150+
})
151+
.collect();
152+
135153
let mut inversed_domain_points_vanishing_evaluated_at_point =
136154
vec![unsafe { std::mem::zeroed() }; domain.size()];
137155

@@ -145,7 +163,24 @@ impl PolyOps for CpuBackend {
145163
sample_point.into_ef::<SecureField>(),
146164
);
147165

148-
(0..domain.size())
166+
#[cfg(not(feature = "parallel"))]
167+
let weights = (0..domain.size())
168+
.map(|i| {
169+
if i % 2 == 0 {
170+
weights_first_half
171+
* inversed_domain_points_vanishing_evaluated_at_point[i]
172+
* coset_vanishing_evaluated_at_point
173+
} else {
174+
weights_second_half
175+
* inversed_domain_points_vanishing_evaluated_at_point[i]
176+
* coset_vanishing_evaluated_at_point
177+
}
178+
})
179+
.collect_vec();
180+
181+
#[cfg(feature = "parallel")]
182+
let weights = (0..domain.size())
183+
.into_par_iter()
149184
.map(|i| {
150185
if i % 2 == 0 {
151186
weights_first_half
@@ -157,7 +192,9 @@ impl PolyOps for CpuBackend {
157192
* coset_vanishing_evaluated_at_point
158193
}
159194
})
160-
.collect_vec()
195+
.collect();
196+
197+
weights
161198
}
162199

163200
fn barycentric_eval_at_point(
@@ -171,9 +208,18 @@ impl PolyOps for CpuBackend {
171208
}
172209
}
173210

174-
(0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
211+
#[cfg(not(feature = "parallel"))]
212+
return (0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
175213
acc + (evals.values[i] * weights[i])
176-
})
214+
});
215+
216+
#[cfg(feature = "parallel")]
217+
return (0..evals.domain.size())
218+
.into_par_iter()
219+
.fold(SecureField::zero, |acc: SecureField, i: usize| {
220+
acc + (evals.values[i] * weights[i])
221+
})
222+
.sum::<SecureField>();
177223
}
178224

179225
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {

crates/prover/src/core/backend/simd/circle.rs

Lines changed: 81 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ use std::mem::transmute;
33
use std::simd::Simd;
44

55
use bytemuck::Zeroable;
6+
#[cfg(not(feature = "parallel"))]
67
use itertools::Itertools;
78
use num_traits::{One, Zero};
89
#[cfg(feature = "parallel")]
@@ -254,7 +255,17 @@ impl PolyOps for SimdBackend {
254255
p_0_inverse,
255256
));
256257

258+
let weights_half_and_half: PackedSecureField =
259+
PackedSecureField::from_array(std::array::from_fn(|i| {
260+
if i % 2 == 0 {
261+
weights_first_half
262+
} else {
263+
weights_second_half
264+
}
265+
}));
266+
257267
// TODO(Gali): Optimize to a batched point_vanishing()
268+
#[cfg(not(feature = "parallel"))]
258269
let domain_points_vanishing_evaluated_at_point = (0..weights_vec_len)
259270
.map(|i| {
260271
PackedSecureField::from_array(std::array::from_fn(|j| {
@@ -271,6 +282,26 @@ impl PolyOps for SimdBackend {
271282
}))
272283
})
273284
.collect_vec();
285+
#[cfg(feature = "parallel")]
286+
let domain_points_vanishing_evaluated_at_point: Vec<PackedSecureField> = (0
287+
..weights_vec_len)
288+
.into_par_iter()
289+
.map(|i| {
290+
PackedSecureField::from_array(std::array::from_fn(|j| {
291+
if domain.size() <= bit_reverse_index(i * N_LANES + j, log_size) {
292+
SecureField::one()
293+
} else {
294+
point_vanishing(
295+
domain
296+
.at(bit_reverse_index(i * N_LANES + j, log_size))
297+
.into_ef::<SecureField>(),
298+
sample_point.into_ef::<SecureField>(),
299+
)
300+
}
301+
}))
302+
})
303+
.collect();
304+
274305
let mut inversed_domain_points_vanishing_evaluated_at_point =
275306
vec![unsafe { std::mem::zeroed() }; weights_vec_len];
276307

@@ -279,44 +310,35 @@ impl PolyOps for SimdBackend {
279310
&mut inversed_domain_points_vanishing_evaluated_at_point,
280311
);
281312

282-
let coset_vanishing_evaluated_at_point: SecureField = coset_vanishing(
283-
CanonicCoset::new(domain.log_size()).coset,
284-
sample_point.into_ef::<SecureField>(),
285-
);
313+
let coset_vanishing_evaluated_at_point: PackedSecureField =
314+
PackedSecureField::broadcast(coset_vanishing(
315+
CanonicCoset::new(domain.log_size()).coset,
316+
sample_point.into_ef::<SecureField>(),
317+
));
286318

287-
if weights_vec_len == 1 {
288-
return (0..N_LANES)
289-
.map(|i| {
290-
let inversed_domain_points_vanishing_evaluated_at_point =
291-
inversed_domain_points_vanishing_evaluated_at_point[0].to_array();
292-
if i % 2 == 0 {
293-
inversed_domain_points_vanishing_evaluated_at_point[i]
294-
* (weights_first_half * coset_vanishing_evaluated_at_point)
295-
} else {
296-
inversed_domain_points_vanishing_evaluated_at_point[i]
297-
* (weights_second_half * coset_vanishing_evaluated_at_point)
298-
}
299-
})
300-
.collect();
301-
}
319+
#[cfg(not(feature = "parallel"))]
320+
let weights: Vec<PackedSecureField> = (0..weights_vec_len)
321+
.map(|i| {
322+
inversed_domain_points_vanishing_evaluated_at_point[i]
323+
* weights_half_and_half
324+
* coset_vanishing_evaluated_at_point
325+
})
326+
.collect();
302327

303-
let weights_half_and_half: PackedSecureField =
304-
PackedSecureField::from_array(std::array::from_fn(|i| {
305-
if i % 2 == 0 {
306-
weights_first_half
307-
} else {
308-
weights_second_half
309-
}
310-
}));
311-
let weights: Col<Self, SecureField> = (0..weights_vec_len)
328+
#[cfg(feature = "parallel")]
329+
let weights: Vec<PackedSecureField> = (0..weights_vec_len)
330+
.into_par_iter()
312331
.map(|i| {
313332
inversed_domain_points_vanishing_evaluated_at_point[i]
314333
* weights_half_and_half
315334
* coset_vanishing_evaluated_at_point
316335
})
317336
.collect();
318337

319-
weights
338+
Col::<Self, SecureField> {
339+
data: weights,
340+
length: domain.size(),
341+
}
320342
}
321343

322344
fn barycentric_eval_at_point(
@@ -334,15 +356,40 @@ impl PolyOps for SimdBackend {
334356
}
335357

336358
if evals.domain.size() < N_LANES {
359+
#[cfg(not(feature = "parallel"))]
337360
return (0..evals.domain.size()).fold(SecureField::zero(), |acc, i| {
338361
acc + (weights.at(i) * evals.values.at(i))
339362
});
363+
364+
#[cfg(feature = "parallel")]
365+
return (0..evals.domain.size())
366+
.into_par_iter()
367+
.fold(SecureField::zero, |acc: SecureField, i: usize| {
368+
acc + (weights.at(i) * evals.values.at(i))
369+
})
370+
.sum::<SecureField>();
371+
} else {
372+
#[cfg(not(feature = "parallel"))]
373+
return (0..evals.domain.size().div_ceil(N_LANES))
374+
.fold(PackedSecureField::zero(), |acc, i| {
375+
acc + (weights.data[i] * evals.values.data[i])
376+
})
377+
.pointwise_sum();
378+
379+
#[cfg(feature = "parallel")]
380+
return (0..evals.domain.size().div_ceil(N_LANES))
381+
.into_par_iter()
382+
.fold(
383+
PackedSecureField::zero,
384+
|acc: PackedSecureField, i: usize| {
385+
acc + (weights.data[i] * evals.values.data[i])
386+
},
387+
)
388+
.sum::<PackedSecureField>()
389+
.to_array()
390+
.into_par_iter()
391+
.sum::<SecureField>();
340392
}
341-
(0..evals.domain.size().div_ceil(N_LANES))
342-
.fold(PackedSecureField::zero(), |acc, i| {
343-
acc + (weights.data[i] * evals.values.data[i])
344-
})
345-
.pointwise_sum()
346393
}
347394

348395
fn extend(poly: &CirclePoly<Self>, log_size: u32) -> CirclePoly<Self> {

0 commit comments

Comments
 (0)