@@ -3,6 +3,8 @@ use std::mem::transmute;
33use std:: simd:: Simd ;
44
55use bytemuck:: Zeroable ;
6+ use itertools:: Itertools ;
7+ use num_traits:: { One , Zero } ;
68#[ cfg( feature = "parallel" ) ]
79use rayon:: prelude:: * ;
810
@@ -14,10 +16,11 @@ use crate::core::backend::cpu::circle::slow_precompute_twiddles;
1416use crate :: core:: backend:: simd:: column:: BaseColumn ;
1517use crate :: core:: backend:: simd:: m31:: PackedM31 ;
1618use crate :: core:: backend:: { Col , Column , CpuBackend } ;
17- use crate :: core:: circle:: { CirclePoint , Coset , M31_CIRCLE_LOG_ORDER } ;
19+ use crate :: core:: circle:: { CirclePoint , CirclePointIndex , Coset , M31_CIRCLE_LOG_ORDER } ;
20+ use crate :: core:: constraints:: { coset_vanishing, coset_vanishing_derivative, point_vanishing} ;
1821use crate :: core:: fields:: m31:: BaseField ;
1922use crate :: core:: fields:: qm31:: SecureField ;
20- use crate :: core:: fields:: { Field , FieldExpOps } ;
23+ use crate :: core:: fields:: { batch_inverse_in_place , Field , FieldExpOps } ;
2124use crate :: core:: poly:: circle:: {
2225 CanonicCoset , CircleDomain , CircleEvaluation , CirclePoly , PolyOps ,
2326} ;
@@ -221,6 +224,84 @@ impl PolyOps for SimdBackend {
221224 ( sum * twiddle_lows) . pointwise_sum ( )
222225 }
223226
227+ fn barycentric_weights (
228+ coset : CanonicCoset ,
229+ p : CirclePoint < SecureField > ,
230+ ) -> Col < SimdBackend , SecureField > {
231+ let domain = coset. circle_domain ( ) ;
232+ let weights_vec_len = domain. size ( ) . div_ceil ( N_LANES ) ;
233+ if weights_vec_len == 1 {
234+ return Col :: < SimdBackend , SecureField > :: from_iter ( CircleEvaluation :: <
235+ CpuBackend ,
236+ BaseField ,
237+ BitReversedOrder ,
238+ > :: barycentric_weights (
239+ coset, p
240+ ) ) ;
241+ }
242+
243+ let p_0 = domain. at ( 0 ) . into_ef :: < SecureField > ( ) ;
244+ let si_0 = SecureField :: one ( )
245+ / ( ( p_0. y * SecureField :: from ( -2 ) )
246+ * coset_vanishing_derivative (
247+ Coset :: new ( CirclePointIndex :: generator ( ) , domain. log_size ( ) ) ,
248+ p_0,
249+ ) ) ;
250+
251+ // TODO(Gali): Optimize to a batched point_vanishing()
252+ let vi_p = ( 0 ..weights_vec_len)
253+ . map ( |i| {
254+ PackedSecureField :: from_array ( std:: array:: from_fn ( |j| {
255+ point_vanishing (
256+ domain. at ( i * N_LANES + j) . into_ef :: < SecureField > ( ) ,
257+ p. into_ef :: < SecureField > ( ) ,
258+ )
259+ } ) )
260+ } )
261+ . collect_vec ( ) ;
262+ let mut vi_p_inverse = Vec :: with_capacity ( weights_vec_len) ;
263+ #[ allow( clippy:: uninit_vec) ]
264+ unsafe {
265+ vi_p_inverse. set_len ( weights_vec_len)
266+ } ;
267+ batch_inverse_in_place ( & vi_p, & mut vi_p_inverse) ;
268+
269+ let vn_p: SecureField = coset_vanishing (
270+ CanonicCoset :: new ( domain. log_size ( ) ) . coset ,
271+ p. into_ef :: < SecureField > ( ) ,
272+ ) ;
273+
274+ let si_0_vn_p = PackedSecureField :: broadcast ( si_0 * vn_p) ;
275+
276+ // TODO(Gali): Change weights order to bit-reverse order.
277+ // S_i(i) is invariant under G_(n−1) and alternate under J, meaning the S_i(i) values are
278+ // the same for each half coset, and the second half coset values are the conjugate
279+ // of the first half coset values.
280+ let weights: Col < SimdBackend , SecureField > = ( 0 ..weights_vec_len)
281+ . map ( |i| {
282+ if i < weights_vec_len / 2 {
283+ vi_p_inverse[ i] * si_0_vn_p
284+ } else {
285+ vi_p_inverse[ i] * -si_0_vn_p
286+ }
287+ } )
288+ . collect ( ) ;
289+
290+ weights
291+ }
292+
293+ fn barycentric_eval_at_point (
294+ evals : & CircleEvaluation < SimdBackend , BaseField , BitReversedOrder > ,
295+ weights : & Col < SimdBackend , SecureField > ,
296+ ) -> SecureField {
297+ let evals = evals. clone ( ) . bit_reverse ( ) ;
298+ ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
299+ . fold ( PackedSecureField :: zero ( ) , |acc, i| {
300+ acc + ( weights. data [ i] * evals. values . data [ i] )
301+ } )
302+ . pointwise_sum ( )
303+ }
304+
224305 fn extend ( poly : & CirclePoly < Self > , log_size : u32 ) -> CirclePoly < Self > {
225306 // TODO(shahars): Get rid of extends.
226307 poly. evaluate ( CanonicCoset :: new ( log_size) . circle_domain ( ) )
0 commit comments