Skip to content

Commit 3b7ac6d

Browse files
authored
Merge pull request #245 from wsmoses/main
Fix step size calculation if some results are non-deterministic (inf …
2 parents 87e0a26 + 09c6767 commit 3b7ac6d

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.32"
3+
version = "0.12.33"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/methods.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,17 +371,27 @@ function estimate_step(
371371
return _limit_step(m, x, step, acc)
372372
end
373373

374+
function finite_or_zero(fs::AbstractArray{<:Number})
375+
ifelse.(isfinite.(fs), fs, zero(fs))
376+
end
377+
378+
function finite_or_zero(fs::AbstractArray{<:AbstractArray})
379+
finite_or_zero.(fs)
380+
end
381+
374382
function _estimate_magnitudes(
375383
m::FiniteDifferenceMethod{P,Q}, f::TF, x::T,
376384
) where {P,Q,TF,T<:AbstractFloat}
377385
step = first(estimate_step(m, f, x))
378386
fs = _eval_function(m, f, x, step)
387+
fs = finite_or_zero(fs)
379388
# Estimate magnitude of `∇f` in a neighbourhood of `x`.
380389
∇fs = SVector{3}(
381390
_compute_estimate(m, fs, x, step, m.coefs_neighbourhood[1]),
382391
_compute_estimate(m, fs, x, step, m.coefs_neighbourhood[2]),
383392
_compute_estimate(m, fs, x, step, m.coefs_neighbourhood[3])
384393
)
394+
∇fs = finite_or_zero(∇fs)
385395
∇f_magnitude = maximum(maximum.(abs, ∇fs))
386396
# Estimate magnitude of `f` in a neighbourhood of `x`.
387397
f_magnitude = maximum(maximum.(abs, fs))

test/grad.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,26 @@ using FiniteDifferences: grad, jacobian, _jvp, jvp, j′vp, _j′vp, to_vec
217217
@test [real(ȳ), imag(ȳ)] Jy'z̄_vec
218218
end
219219
end
220+
221+
using LinearAlgebra
222+
223+
function partial_nan_returning(x)
224+
return Float64[NaN, x]
225+
end
226+
227+
randvar = 1
228+
function partial_nondet_returning(x)
229+
global randvar
230+
y = Float64[randvar, x]
231+
randvar += 1
232+
return y
233+
end
234+
235+
@testset "jvp: Estimate step correctly for when some terms are nan/infinite" begin
236+
fdm = FiniteDifferences.central_fdm(5, 1)
237+
res = jvp(fdm, partial_nan_returning, (3.1, 2.7))
238+
@test res[2] 2.7
239+
240+
res = jvp(fdm, partial_nondet_returning, (3.1, 2.7))
241+
@test res[2] 2.7
242+
end

0 commit comments

Comments
 (0)