@@ -3,6 +3,7 @@ use std::mem::transmute;
33use std:: simd:: Simd ;
44
55use bytemuck:: Zeroable ;
6+ #[ cfg( not( feature = "parallel" ) ) ]
67use itertools:: Itertools ;
78use num_traits:: { One , Zero } ;
89#[ cfg( feature = "parallel" ) ]
@@ -254,7 +255,17 @@ impl PolyOps for SimdBackend {
254255 p_0_inverse,
255256 ) ) ;
256257
258+ let weights_half_and_half: PackedSecureField =
259+ PackedSecureField :: from_array ( std:: array:: from_fn ( |i| {
260+ if i % 2 == 0 {
261+ weights_first_half
262+ } else {
263+ weights_second_half
264+ }
265+ } ) ) ;
266+
257267 // TODO(Gali): Optimize to a batched point_vanishing()
268+ #[ cfg( not( feature = "parallel" ) ) ]
258269 let domain_points_vanishing_evaluated_at_point = ( 0 ..weights_vec_len)
259270 . map ( |i| {
260271 PackedSecureField :: from_array ( std:: array:: from_fn ( |j| {
@@ -271,6 +282,26 @@ impl PolyOps for SimdBackend {
271282 } ) )
272283 } )
273284 . collect_vec ( ) ;
285+ #[ cfg( feature = "parallel" ) ]
286+ let domain_points_vanishing_evaluated_at_point: Vec < PackedSecureField > = ( 0
287+ ..weights_vec_len)
288+ . into_par_iter ( )
289+ . map ( |i| {
290+ PackedSecureField :: from_array ( std:: array:: from_fn ( |j| {
291+ if domain. size ( ) <= bit_reverse_index ( i * N_LANES + j, log_size) {
292+ SecureField :: one ( )
293+ } else {
294+ point_vanishing (
295+ domain
296+ . at ( bit_reverse_index ( i * N_LANES + j, log_size) )
297+ . into_ef :: < SecureField > ( ) ,
298+ sample_point. into_ef :: < SecureField > ( ) ,
299+ )
300+ }
301+ } ) )
302+ } )
303+ . collect ( ) ;
304+
274305 let mut inversed_domain_points_vanishing_evaluated_at_point =
275306 vec ! [ unsafe { std:: mem:: zeroed( ) } ; weights_vec_len] ;
276307
@@ -279,44 +310,35 @@ impl PolyOps for SimdBackend {
279310 & mut inversed_domain_points_vanishing_evaluated_at_point,
280311 ) ;
281312
282- let coset_vanishing_evaluated_at_point: SecureField = coset_vanishing (
283- CanonicCoset :: new ( domain. log_size ( ) ) . coset ,
284- sample_point. into_ef :: < SecureField > ( ) ,
285- ) ;
313+ let coset_vanishing_evaluated_at_point: PackedSecureField =
314+ PackedSecureField :: broadcast ( coset_vanishing (
315+ CanonicCoset :: new ( domain. log_size ( ) ) . coset ,
316+ sample_point. into_ef :: < SecureField > ( ) ,
317+ ) ) ;
286318
287- if weights_vec_len == 1 {
288- return ( 0 ..N_LANES )
289- . map ( |i| {
290- let inversed_domain_points_vanishing_evaluated_at_point =
291- inversed_domain_points_vanishing_evaluated_at_point[ 0 ] . to_array ( ) ;
292- if i % 2 == 0 {
293- inversed_domain_points_vanishing_evaluated_at_point[ i]
294- * ( weights_first_half * coset_vanishing_evaluated_at_point)
295- } else {
296- inversed_domain_points_vanishing_evaluated_at_point[ i]
297- * ( weights_second_half * coset_vanishing_evaluated_at_point)
298- }
299- } )
300- . collect ( ) ;
301- }
319+ #[ cfg( not( feature = "parallel" ) ) ]
320+ let weights: Vec < PackedSecureField > = ( 0 ..weights_vec_len)
321+ . map ( |i| {
322+ inversed_domain_points_vanishing_evaluated_at_point[ i]
323+ * weights_half_and_half
324+ * coset_vanishing_evaluated_at_point
325+ } )
326+ . collect ( ) ;
302327
303- let weights_half_and_half: PackedSecureField =
304- PackedSecureField :: from_array ( std:: array:: from_fn ( |i| {
305- if i % 2 == 0 {
306- weights_first_half
307- } else {
308- weights_second_half
309- }
310- } ) ) ;
311- let weights: Col < Self , SecureField > = ( 0 ..weights_vec_len)
328+ #[ cfg( feature = "parallel" ) ]
329+ let weights: Vec < PackedSecureField > = ( 0 ..weights_vec_len)
330+ . into_par_iter ( )
312331 . map ( |i| {
313332 inversed_domain_points_vanishing_evaluated_at_point[ i]
314333 * weights_half_and_half
315334 * coset_vanishing_evaluated_at_point
316335 } )
317336 . collect ( ) ;
318337
319- weights
338+ Col :: < Self , SecureField > {
339+ data : weights,
340+ length : domain. size ( ) ,
341+ }
320342 }
321343
322344 fn barycentric_eval_at_point (
@@ -334,15 +356,40 @@ impl PolyOps for SimdBackend {
334356 }
335357
336358 if evals. domain . size ( ) < N_LANES {
359+ #[ cfg( not( feature = "parallel" ) ) ]
337360 return ( 0 ..evals. domain . size ( ) ) . fold ( SecureField :: zero ( ) , |acc, i| {
338361 acc + ( weights. at ( i) * evals. values . at ( i) )
339362 } ) ;
363+
364+ #[ cfg( feature = "parallel" ) ]
365+ return ( 0 ..evals. domain . size ( ) )
366+ . into_par_iter ( )
367+ . fold ( SecureField :: zero, |acc : SecureField , i : usize | {
368+ acc + ( weights. at ( i) * evals. values . at ( i) )
369+ } )
370+ . sum :: < SecureField > ( ) ;
371+ } else {
372+ #[ cfg( not( feature = "parallel" ) ) ]
373+ return ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
374+ . fold ( PackedSecureField :: zero ( ) , |acc, i| {
375+ acc + ( weights. data [ i] * evals. values . data [ i] )
376+ } )
377+ . pointwise_sum ( ) ;
378+
379+ #[ cfg( feature = "parallel" ) ]
380+ return ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
381+ . into_par_iter ( )
382+ . fold (
383+ PackedSecureField :: zero,
384+ |acc : PackedSecureField , i : usize | {
385+ acc + ( weights. data [ i] * evals. values . data [ i] )
386+ } ,
387+ )
388+ . sum :: < PackedSecureField > ( )
389+ . to_array ( )
390+ . into_par_iter ( )
391+ . sum :: < SecureField > ( ) ;
340392 }
341- ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
342- . fold ( PackedSecureField :: zero ( ) , |acc, i| {
343- acc + ( weights. data [ i] * evals. values . data [ i] )
344- } )
345- . pointwise_sum ( )
346393 }
347394
348395 fn extend ( poly : & CirclePoly < Self > , log_size : u32 ) -> CirclePoly < Self > {
0 commit comments