@@ -3,17 +3,19 @@ use std::ops::{Deref, Index};
33
44use educe:: Educe ;
55use itertools:: Itertools ;
6- use num_traits:: Zero ;
6+ use num_traits:: { One , Zero } ;
77
88use super :: { CircleDomain , CirclePoly , PolyOps } ;
99use crate :: core:: backend:: cpu:: CpuCircleEvaluation ;
10+ use crate :: core:: backend:: simd:: m31:: N_LANES ;
11+ use crate :: core:: backend:: simd:: qm31:: PackedSecureField ;
1012use crate :: core:: backend:: simd:: SimdBackend ;
1113use crate :: core:: backend:: { Col , Column , ColumnOps , CpuBackend } ;
1214use crate :: core:: circle:: { CirclePoint , CirclePointIndex , Coset } ;
1315use crate :: core:: constraints:: { coset_vanishing, coset_vanishing_derivative, point_vanishing} ;
1416use crate :: core:: fields:: m31:: BaseField ;
1517use crate :: core:: fields:: qm31:: SecureField ;
16- use crate :: core:: fields:: ExtensionOf ;
18+ use crate :: core:: fields:: { batch_inverse_in_place , ExtensionOf } ;
1719use crate :: core:: poly:: circle:: CanonicCoset ;
1820use crate :: core:: poly:: twiddles:: TwiddleTree ;
1921use crate :: core:: poly:: { BitReversedOrder , NaturalOrder } ;
@@ -229,13 +231,115 @@ fn barycentric_eval_at_point(
229231 } )
230232}
231233
234+ // TODO(Gali): Remove.
235+ #[ allow( dead_code) ]
236+ /// Computes the weights for Barycentric Lagrange interpolation for point `p` on `coset`.
237+ /// `p` must not be in the domain. For more information, see [`barycentric_weights`].
238+ fn simd_barycentric_weights (
239+ coset : CanonicCoset ,
240+ p : CirclePoint < SecureField > ,
241+ ) -> Col < SimdBackend , SecureField > {
242+ let domain = coset. circle_domain ( ) ;
243+ let weights_vec_len = domain. size ( ) . div_ceil ( N_LANES ) ;
244+
245+ // S_i(i) is invariant under G_(n−1) and alternate under J, so we can calculate only 2 values
246+ let p_first_half_coset = domain. at ( 0 ) . into_ef :: < SecureField > ( ) ;
247+ let si_i_first_half = SecureField :: one ( )
248+ / ( ( p_first_half_coset. y * SecureField :: from ( -2 ) )
249+ * coset_vanishing_derivative (
250+ Coset :: new ( CirclePointIndex :: generator ( ) , domain. log_size ( ) ) ,
251+ p_first_half_coset,
252+ ) ) ;
253+ let p_second_half_coset = domain. at ( domain. half_coset . size ( ) ) . into_ef :: < SecureField > ( ) ;
254+ let si_i_second_half = SecureField :: one ( )
255+ / ( ( p_second_half_coset. y * SecureField :: from ( -2 ) )
256+ * coset_vanishing_derivative (
257+ Coset :: new ( CirclePointIndex :: generator ( ) , domain. log_size ( ) ) ,
258+ p_second_half_coset,
259+ ) ) ;
260+
261+ // TODO(Gali): Optimize to a batched point_vanishing()
262+ let vi_p = ( 0 ..weights_vec_len)
263+ . map ( |i| {
264+ PackedSecureField :: from_array ( std:: array:: from_fn ( |j| {
265+ if domain. size ( ) <= i * N_LANES + j {
266+ SecureField :: one ( )
267+ } else {
268+ point_vanishing (
269+ domain. at ( i * N_LANES + j) . into_ef :: < SecureField > ( ) ,
270+ p. into_ef :: < SecureField > ( ) ,
271+ )
272+ }
273+ } ) )
274+ } )
275+ . collect_vec ( ) ;
276+ let mut vi_p_inverse = vec ! [ unsafe { std:: mem:: zeroed( ) } ; weights_vec_len] ;
277+
278+ batch_inverse_in_place ( & vi_p, & mut vi_p_inverse) ;
279+
280+ let vn_p: SecureField = coset_vanishing (
281+ CanonicCoset :: new ( domain. log_size ( ) ) . coset ,
282+ p. into_ef :: < SecureField > ( ) ,
283+ ) ;
284+
285+ // TODO(Gali): Change weights order to bit-reverse order.
286+ if weights_vec_len == 1 {
287+ return ( 0 ..N_LANES )
288+ . map ( |i| {
289+ let vi_p_inverse = vi_p_inverse[ 0 ] . to_array ( ) ;
290+ if i < domain. size ( ) / 2 {
291+ vi_p_inverse[ i] * ( si_i_first_half * vn_p)
292+ } else {
293+ vi_p_inverse[ i] * ( si_i_second_half * vn_p)
294+ }
295+ } )
296+ . collect ( ) ;
297+ }
298+
299+ let weights: Col < SimdBackend , SecureField > = ( 0 ..weights_vec_len)
300+ . map ( |i| {
301+ if i < weights_vec_len / 2 {
302+ vi_p_inverse[ i] * ( si_i_first_half * vn_p)
303+ } else {
304+ vi_p_inverse[ i] * ( si_i_second_half * vn_p)
305+ }
306+ } )
307+ . collect ( ) ;
308+
309+ weights
310+ }
311+
312+ // TODO(Gali): Remove.
313+ #[ allow( dead_code) ]
314+ /// Evaluates a polynomial at a point using the barycentric interpolation formula,
315+ /// given its evaluations on a circle domain and precomputed barycentric weights for the domain
316+ /// at the sampled point. For more information, see [`barycentric_eval_at_point`]
317+ fn simd_barycentric_eval_at_point (
318+ evals : & CircleEvaluation < SimdBackend , BaseField , BitReversedOrder > ,
319+ weights : & Col < SimdBackend , SecureField > ,
320+ ) -> SecureField {
321+ let evals = evals. clone ( ) . bit_reverse ( ) ;
322+ if evals. domain . size ( ) < N_LANES {
323+ return ( 0 ..evals. domain . size ( ) ) . fold ( SecureField :: zero ( ) , |acc, i| {
324+ acc + ( weights. at ( i) * evals. values . at ( i) )
325+ } ) ;
326+ }
327+ ( 0 ..evals. domain . size ( ) . div_ceil ( N_LANES ) )
328+ . fold ( PackedSecureField :: zero ( ) , |acc, i| {
329+ acc + ( weights. data [ i] * evals. values . data [ i] )
330+ } )
331+ . pointwise_sum ( )
332+ }
333+
232334#[ cfg( test) ]
233335mod tests {
234336 use super :: * ;
235337 use crate :: core:: backend:: cpu:: { CpuCircleEvaluation , CpuCirclePoly } ;
338+ use crate :: core:: backend:: simd:: column:: BaseColumn ;
339+ use crate :: core:: backend:: simd:: SimdBackend ;
236340 use crate :: core:: circle:: { CirclePoint , Coset } ;
237341 use crate :: core:: fields:: m31:: BaseField ;
238- use crate :: core:: poly:: circle:: CanonicCoset ;
342+ use crate :: core:: poly:: circle:: { CanonicCoset , CirclePoly } ;
239343 use crate :: core:: poly:: NaturalOrder ;
240344 use crate :: m31;
241345
@@ -281,7 +385,7 @@ mod tests {
281385 }
282386
283387 #[ test]
284- fn test_barycentric_evaluation ( ) {
388+ fn test_cpu_barycentric_evaluation ( ) {
285389 let poly = CpuCirclePoly :: new (
286390 [ 691 , 805673 , 5 , 435684 , 4832 , 23876431 , 197 , 897346068 ]
287391 . map ( BaseField :: from)
@@ -311,4 +415,38 @@ mod tests {
311415 "Barycentric evaluation should be equal to the polynomial evaluation"
312416 ) ;
313417 }
418+
419+ #[ test]
420+ fn test_simd_barycentric_evaluation ( ) {
421+ let poly = CirclePoly :: < SimdBackend > :: new ( BaseColumn :: from_cpu (
422+ [ 691 , 805673 , 5 , 435684 , 4832 , 23876431 , 197 , 897346068 ]
423+ . map ( BaseField :: from)
424+ . to_vec ( ) ,
425+ ) ) ;
426+ let s = CanonicCoset :: new ( 10 ) ;
427+ let domain = s. circle_domain ( ) ;
428+ let eval = poly. evaluate ( domain) ;
429+ let sampled_points = [
430+ CirclePoint :: get_point ( 348 ) ,
431+ CirclePoint :: get_point ( 9736524 ) ,
432+ CirclePoint :: get_point ( 13 ) ,
433+ CirclePoint :: get_point ( 346752 ) ,
434+ ] ;
435+ let sampled_values = sampled_points
436+ . iter ( )
437+ . map ( |point| poly. eval_at_point ( * point) )
438+ . collect_vec ( ) ;
439+
440+ let sampled_barycentric_values = sampled_points
441+ . iter ( )
442+ . map ( |point| {
443+ simd_barycentric_eval_at_point ( & eval, & simd_barycentric_weights ( s, * point) )
444+ } )
445+ . collect_vec ( ) ;
446+
447+ assert_eq ! (
448+ sampled_barycentric_values, sampled_values,
449+ "Barycentric evaluation should be equal to the polynomial evaluation"
450+ ) ;
451+ }
314452}
0 commit comments