@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138
138
BS<: BatchSizeSettings ,
139
139
S<: AbstractVector{<:NTuple} ,
140
140
R<: AbstractVector{<:NTuple} ,
141
+ SE<: NTuple ,
141
142
E<: PushforwardPrep ,
142
143
} <: StandardJacobianPrep{SIG}
143
144
_sig:: Val{SIG}
144
145
batch_size_settings:: BS
145
146
batched_seeds:: S
146
147
batched_results:: R
148
+ seed_example:: SE
147
149
pushforward_prep:: E
148
150
end
149
151
@@ -152,12 +154,14 @@ struct PullbackJacobianPrep{
152
154
BS<: BatchSizeSettings ,
153
155
S<: AbstractVector{<:NTuple} ,
154
156
R<: AbstractVector{<:NTuple} ,
157
+ SE<: NTuple ,
155
158
E<: PullbackPrep ,
156
159
} <: StandardJacobianPrep{SIG}
157
160
_sig:: Val{SIG}
158
161
batch_size_settings:: BS
159
162
batched_seeds:: S
160
163
batched_results:: R
164
+ seed_example:: SE
161
165
pullback_prep:: E
162
166
end
163
167
@@ -211,11 +215,17 @@ function _prepare_jacobian_aux(
211
215
ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
212
216
]
213
217
batched_results = [ntuple (b -> similar (y), Val (B)) for _ in batched_seeds]
218
+ seed_example = ntuple (b -> zero (x), Val (B))
214
219
pushforward_prep = prepare_pushforward_nokwarg (
215
- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
220
+ strict, f_or_f!y... , backend, x, seed_example , contexts...
216
221
)
217
222
return PushforwardJacobianPrep (
218
- _sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223
+ _sig,
224
+ batch_size_settings,
225
+ batched_seeds,
226
+ batched_results,
227
+ seed_example,
228
+ pushforward_prep,
219
229
)
220
230
end
221
231
@@ -236,11 +246,17 @@ function _prepare_jacobian_aux(
236
246
ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
237
247
]
238
248
batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
249
+ seed_example = ntuple (b -> zero (y), Val (B))
239
250
pullback_prep = prepare_pullback_nokwarg (
240
- strict, f_or_f!y... , backend, x, batched_seeds[ 1 ] , contexts...
251
+ strict, f_or_f!y... , backend, x, seed_example , contexts...
241
252
)
242
253
return PullbackJacobianPrep (
243
- _sig, batch_size_settings, batched_seeds, batched_results, pullback_prep
254
+ _sig,
255
+ batch_size_settings,
256
+ batched_seeds,
257
+ batched_results,
258
+ seed_example,
259
+ pullback_prep,
244
260
)
245
261
end
246
262
@@ -363,11 +379,11 @@ function _jacobian_aux(
363
379
x,
364
380
contexts:: Vararg{Context,C} ,
365
381
) where {FY,SIG,B,aligned,C}
366
- (; batch_size_settings, batched_seeds, pushforward_prep) = prep
382
+ (; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
367
383
(; A, B_last) = batch_size_settings
368
384
369
385
pushforward_prep_same = prepare_pushforward_same_point (
370
- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
386
+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
371
387
)
372
388
373
389
jac = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -419,11 +435,11 @@ function _jacobian_aux(
419
435
x,
420
436
contexts:: Vararg{Context,C} ,
421
437
) where {FY,SIG,B,aligned,C}
422
- (; batch_size_settings, batched_seeds, pullback_prep) = prep
438
+ (; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep
423
439
(; A, B_last) = batch_size_settings
424
440
425
441
pullback_prep_same = prepare_pullback_same_point (
426
- f_or_f!y... , prep . pullback_prep, backend, x, batched_seeds[ 1 ] , contexts...
442
+ f_or_f!y... , pullback_prep, backend, x, seed_example , contexts...
427
443
)
428
444
429
445
jac = mapreduce (vcat, eachindex (batched_seeds)) do a
@@ -451,11 +467,13 @@ function _jacobian_aux!(
451
467
x,
452
468
contexts:: Vararg{Context,C} ,
453
469
) where {FY,SIG,B,C}
454
- (; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470
+ (;
471
+ batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472
+ ) = prep
455
473
(; N) = batch_size_settings
456
474
457
475
pushforward_prep_same = prepare_pushforward_same_point (
458
- f_or_f!y... , pushforward_prep, backend, x, batched_seeds[ 1 ] , contexts...
476
+ f_or_f!y... , pushforward_prep, backend, x, seed_example , contexts...
459
477
)
460
478
461
479
for a in eachindex (batched_seeds, batched_results)
@@ -487,11 +505,12 @@ function _jacobian_aux!(
487
505
x,
488
506
contexts:: Vararg{Context,C} ,
489
507
) where {FY,SIG,B,C}
490
- (; batch_size_settings, batched_seeds, batched_results, pullback_prep) = prep
508
+ (; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509
+ prep
491
510
(; N) = batch_size_settings
492
511
493
512
pullback_prep_same = prepare_pullback_same_point (
494
- f_or_f!y... , pullback_prep, backend, x, batched_seeds[ 1 ] , contexts...
513
+ f_or_f!y... , pullback_prep, backend, x, seed_example , contexts...
495
514
)
496
515
497
516
for a in eachindex (batched_seeds, batched_results)
0 commit comments