@@ -3,17 +3,19 @@ use std::ops::{Deref, Index};
33
44use educe:: Educe ;
55use itertools:: Itertools ;
6- use num_traits:: Zero ;
6+ use num_traits:: { One , Zero } ;
77
88use super :: { CircleDomain , CirclePoly , PolyOps } ;
99use crate :: core:: backend:: cpu:: CpuCircleEvaluation ;
10+ use crate :: core:: backend:: simd:: m31:: N_LANES ;
11+ use crate :: core:: backend:: simd:: qm31:: PackedSecureField ;
1012use crate :: core:: backend:: simd:: SimdBackend ;
1113use crate :: core:: backend:: { Col , Column , ColumnOps , CpuBackend } ;
1214use crate :: core:: circle:: { CirclePoint , CirclePointIndex , Coset } ;
1315use crate :: core:: constraints:: { coset_vanishing, coset_vanishing_derivative, point_vanishing} ;
1416use crate :: core:: fields:: m31:: BaseField ;
1517use crate :: core:: fields:: qm31:: SecureField ;
16- use crate :: core:: fields:: ExtensionOf ;
18+ use crate :: core:: fields:: { batch_inverse_in_place , ExtensionOf } ;
1719use crate :: core:: poly:: circle:: CanonicCoset ;
1820use crate :: core:: poly:: twiddles:: TwiddleTree ;
1921use crate :: core:: poly:: { BitReversedOrder , NaturalOrder } ;
@@ -229,13 +231,108 @@ fn barycentric_eval_at_point(
229231 } )
230232}
231233
234+ // TODO(Gali): Remove.
235+ #[ allow( dead_code) ]
236+ /// Computes the weights for Barycentric Lagrange interpolation for point `p` on `coset`.
237+ /// `p` must not be in the domain. For more information, see [`barycentric_weights`].
238+ fn simd_barycentric_weights (
239+ coset : CanonicCoset ,
240+ p : CirclePoint < SecureField > ,
241+ ) -> Col < SimdBackend , SecureField > {
242+ let domain = coset. circle_domain ( ) ;
243+ let weights_vec_len = domain. size ( ) . div_ceil ( N_LANES ) ;
244+
245+ // S_i(i) is invariant under G_(n−1) and alternate under J, so we can calculate only 2 values
246+ let p_first_half_coset = domain. at ( 0 ) . into_ef :: < SecureField > ( ) ;
247+ let si_i_first_half = SecureField :: one ( )
248+ / ( ( p_first_half_coset. y * SecureField :: from ( -2 ) )
249+ * coset_vanishing_derivative (
250+ Coset :: new ( CirclePointIndex :: generator ( ) , domain. log_size ( ) ) ,
251+ p_first_half_coset,
252+ ) ) ;
253+ let si_i_second_half = -si_i_first_half;
254+
255+ // TODO(Gali): Optimize to a batched point_vanishing()
256+ let vi_p = ( 0 ..weights_vec_len)
257+ . map ( |i| {
258+ PackedSecureField :: from_array ( std:: array:: from_fn ( |j| {
259+ if domain. size ( ) <= i * N_LANES + j {
260+ SecureField :: one ( )
261+ } else {
262+ point_vanishing (
263+ domain. at ( i * N_LANES + j) . into_ef :: < SecureField > ( ) ,
264+ p. into_ef :: < SecureField > ( ) ,
265+ )
266+ }
267+ } ) )
268+ } )
269+ . collect_vec ( ) ;
270+ let mut vi_p_inverse = vec ! [ unsafe { std:: mem:: zeroed( ) } ; weights_vec_len] ;
271+
272+ batch_inverse_in_place ( & vi_p, & mut vi_p_inverse) ;
273+
274+ let vn_p: SecureField = coset_vanishing (
275+ CanonicCoset :: new ( domain. log_size ( ) ) . coset ,
276+ p. into_ef :: < SecureField > ( ) ,
277+ ) ;
278+
279+ // TODO(Gali): Change weights order to bit-reverse order.
280+ if weights_vec_len == 1 {
281+ return ( 0 ..N_LANES )
282+ . map ( |i| {
283+ if i >= domain. size ( ) {
284+ SecureField :: zero ( )
285+ } else {
286+ let vi_p_inverse = vi_p_inverse[ 0 ] . to_array ( ) ;
287+ if i < domain. size ( ) / 2 {
288+ vi_p_inverse[ i] * ( si_i_first_half * vn_p)
289+ } else {
290+ vi_p_inverse[ i] * ( si_i_second_half * vn_p)
291+ }
292+ }
293+ } )
294+ . collect ( ) ;
295+ }
296+
297+ let weights: Col < SimdBackend , SecureField > = ( 0 ..weights_vec_len)
298+ . map ( |i| {
299+ if i < weights_vec_len / 2 {
300+ vi_p_inverse[ i] * ( si_i_first_half * vn_p)
301+ } else {
302+ vi_p_inverse[ i] * ( si_i_second_half * vn_p)
303+ }
304+ } )
305+ . collect ( ) ;
306+
307+ weights
308+ }
309+
310+ // TODO(Gali): Remove.
311+ #[ allow( dead_code) ]
312+ /// Evaluates a polynomial at a point using the barycentric interpolation formula,
313+ /// given its evaluations on a circle domain and precomputed barycentric weights for the domain
314+ /// at the sampled point. For more information, see [`barycentric_eval_at_point`]
315+ fn simd_barycentric_eval_at_point (
316+ evals : & CircleEvaluation < SimdBackend , BaseField , BitReversedOrder > ,
317+ weights : & Col < SimdBackend , SecureField > ,
318+ ) -> SecureField {
319+ let evals = evals. clone ( ) . bit_reverse ( ) ;
320+ ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
321+ . fold ( PackedSecureField :: zero ( ) , |acc, i| {
322+ acc + ( weights. data [ i] * evals. values . data [ i] )
323+ } )
324+ . pointwise_sum ( )
325+ }
326+
232327#[ cfg( test) ]
233328mod tests {
234329 use super :: * ;
235330 use crate :: core:: backend:: cpu:: { CpuCircleEvaluation , CpuCirclePoly } ;
331+ use crate :: core:: backend:: simd:: column:: BaseColumn ;
332+ use crate :: core:: backend:: simd:: SimdBackend ;
236333 use crate :: core:: circle:: { CirclePoint , Coset } ;
237334 use crate :: core:: fields:: m31:: BaseField ;
238- use crate :: core:: poly:: circle:: CanonicCoset ;
335+ use crate :: core:: poly:: circle:: { CanonicCoset , CirclePoly } ;
239336 use crate :: core:: poly:: NaturalOrder ;
240337 use crate :: m31;
241338
@@ -281,7 +378,7 @@ mod tests {
281378 }
282379
283380 #[ test]
284- fn test_barycentric_evaluation ( ) {
381+ fn test_cpu_barycentric_evaluation ( ) {
285382 let poly = CpuCirclePoly :: new (
286383 [ 691 , 805673 , 5 , 435684 , 4832 , 23876431 , 197 , 897346068 ]
287384 . map ( BaseField :: from)
@@ -311,4 +408,65 @@ mod tests {
311408 "Barycentric evaluation should be equal to the polynomial evaluation"
312409 ) ;
313410 }
411+
412+ #[ test]
413+ fn test_simd_barycentric_evaluation ( ) {
414+ let poly = CirclePoly :: < SimdBackend > :: new ( BaseColumn :: from_cpu (
415+ [ 691 , 805673 , 5 , 435684 , 4832 , 23876431 , 197 , 897346068 ]
416+ . map ( BaseField :: from)
417+ . to_vec ( ) ,
418+ ) ) ;
419+ let s = CanonicCoset :: new ( 10 ) ;
420+ let domain = s. circle_domain ( ) ;
421+ let eval = poly. evaluate ( domain) ;
422+ let sampled_points = [
423+ CirclePoint :: get_point ( 348 ) ,
424+ CirclePoint :: get_point ( 9736524 ) ,
425+ CirclePoint :: get_point ( 13 ) ,
426+ CirclePoint :: get_point ( 346752 ) ,
427+ ] ;
428+ let sampled_values = sampled_points
429+ . iter ( )
430+ . map ( |point| poly. eval_at_point ( * point) )
431+ . collect_vec ( ) ;
432+
433+ let sampled_barycentric_values = sampled_points
434+ . iter ( )
435+ . map ( |point| {
436+ simd_barycentric_eval_at_point ( & eval, & simd_barycentric_weights ( s, * point) )
437+ } )
438+ . collect_vec ( ) ;
439+
440+ assert_eq ! (
441+ sampled_barycentric_values, sampled_values,
442+ "Barycentric evaluation should be equal to the polynomial evaluation"
443+ ) ;
444+ }
445+
446+ #[ test]
447+ fn test_simd_barycentric_weights ( ) {
448+ let s = CanonicCoset :: new ( 10 ) ;
449+ let sampled_points = [
450+ CirclePoint :: get_point ( 348 ) ,
451+ CirclePoint :: get_point ( 9736524 ) ,
452+ CirclePoint :: get_point ( 13 ) ,
453+ CirclePoint :: get_point ( 346752 ) ,
454+ ] ;
455+
456+ let cpu_barycentric_values = sampled_points
457+ . iter ( )
458+ . map ( |point| barycentric_weights ( s, * point) )
459+ . collect_vec ( ) ;
460+ let simd_barycentric_values = sampled_points
461+ . iter ( )
462+ . map ( |point| simd_barycentric_weights ( s, * point) )
463+ . collect_vec ( ) ;
464+
465+ cpu_barycentric_values
466+ . iter ( )
467+ . zip ( simd_barycentric_values. iter ( ) )
468+ . for_each ( |( cpu_weights, simd_weights) | {
469+ assert_eq ! ( * cpu_weights, simd_weights. to_cpu( ) ) ;
470+ } ) ;
471+ }
314472}
0 commit comments