@@ -84,13 +84,15 @@ struct HVPGradientHessianPrep{
84
84
BS<: BatchSizeSettings ,
85
85
S<: AbstractVector{<:NTuple} ,
86
86
R<: AbstractVector{<:NTuple} ,
87
+ SE<: NTuple ,
87
88
E2<: HVPPrep ,
88
89
E1<: GradientPrep ,
89
90
} <: HessianPrep{SIG}
90
91
_sig:: Val{SIG}
91
92
batch_size_settings:: BS
92
93
batched_seeds:: S
93
94
batched_results:: R
95
+ seed_example:: SE
94
96
hvp_prep:: E2
95
97
gradient_prep:: E1
96
98
end
@@ -119,10 +121,17 @@ function _prepare_hessian_aux(
119
121
ntuple (b -> seeds[1 + ((a - 1 ) * B + (b - 1 )) % N], Val (B)) for a in 1 : A
120
122
]
121
123
batched_results = [ntuple (b -> similar (x), Val (B)) for _ in batched_seeds]
122
- hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, batched_seeds[1 ], contexts... )
124
+ seed_example = ntuple (b -> basis (x), Val (B))
125
+ hvp_prep = prepare_hvp_nokwarg (strict, f, backend, x, seed_example, contexts... )
123
126
gradient_prep = prepare_gradient_nokwarg (strict, f, inner (backend), x, contexts... )
124
127
return HVPGradientHessianPrep (
125
- _sig, batch_size_settings, batched_seeds, batched_results, hvp_prep, gradient_prep
128
+ _sig,
129
+ batch_size_settings,
130
+ batched_seeds,
131
+ batched_results,
132
+ seed_example,
133
+ hvp_prep,
134
+ gradient_prep,
126
135
)
127
136
end
128
137
@@ -150,11 +159,11 @@ function hessian(
150
159
contexts:: Vararg{Context,C} ,
151
160
) where {F,SIG,B,aligned,C}
152
161
check_prep (f, prep, backend, x, contexts... )
153
- (; batch_size_settings, batched_seeds, hvp_prep) = prep
162
+ (; batch_size_settings, batched_seeds, seed_example, hvp_prep) = prep
154
163
(; A, B_last) = batch_size_settings
155
164
156
165
hvp_prep_same = prepare_hvp_same_point (
157
- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
166
+ f, hvp_prep, backend, x, seed_example , contexts...
158
167
)
159
168
160
169
hess = mapreduce (hcat, eachindex (batched_seeds)) do a
@@ -178,11 +187,11 @@ function hessian!(
178
187
contexts:: Vararg{Context,C} ,
179
188
) where {F,SIG,B,C}
180
189
check_prep (f, prep, backend, x, contexts... )
181
- (; batch_size_settings, batched_seeds, batched_results, hvp_prep) = prep
190
+ (; batch_size_settings, batched_seeds, batched_results, seed_example, hvp_prep) = prep
182
191
(; N) = batch_size_settings
183
192
184
193
hvp_prep_same = prepare_hvp_same_point (
185
- f, hvp_prep, backend, x, batched_seeds[ 1 ] , contexts...
194
+ f, hvp_prep, backend, x, seed_example , contexts...
186
195
)
187
196
188
197
for a in eachindex (batched_seeds, batched_results)
0 commit comments