diff --git a/Project.toml b/Project.toml index 374a9a7..955f7a2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Hankel" uuid = "74863788-d124-456e-a676-9b76578dd39e" authors = ["chrisbrahms <38351086+chrisbrahms@users.noreply.github.com>"] -version = "0.5.5" +version = "0.5.6" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,7 +11,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [compat] -ChainRulesCore = "0.9.44, 0.10" +ChainRulesCore = "0.9.44, 0.10, 1" GSL = "0.6, 1" Roots = "0.8, 1" SpecialFunctions = "0.10, 1" diff --git a/src/Hankel.jl b/src/Hankel.jl index 40b4610..6cdb9dd 100644 --- a/src/Hankel.jl +++ b/src/Hankel.jl @@ -12,6 +12,6 @@ const J₀₀ = besselj(0, 0) include("utils.jl") include("qdht.jl") -include("diffrules.jl") +include("chainrules.jl") end diff --git a/src/diffrules.jl b/src/chainrules.jl similarity index 81% rename from src/diffrules.jl rename to src/chainrules.jl index aa38797..41405f5 100644 --- a/src/diffrules.jl +++ b/src/chainrules.jl @@ -14,20 +14,20 @@ function ChainRulesCore.rrule(::Type{T}, args...; kwargs...) where {T<:QDHT} end ## rules for fwd/rev transform -ChainRulesCore.frule((_, _, ΔA), ::typeof(*), Q::QDHT, A) = (Q * A, Q * ΔA) -ChainRulesCore.frule((_, _, ΔA), ::typeof(\), Q::QDHT, A) = (Q \ A, Q \ ΔA) +ChainRulesCore.frule((_, _, ΔA), ::typeof(*), Q::QDHT, A) = (Q * A, Q * unthunk(ΔA)) +ChainRulesCore.frule((_, _, ΔA), ::typeof(\), Q::QDHT, A) = (Q \ A, Q \ unthunk(ΔA)) function ChainRulesCore.frule((_, ΔY, _, ΔA), ::typeof(mul!), Y, Q::QDHT, A) - return mul!(Y, Q, A), mul!(ΔY, Q, ΔA) + return mul!(Y, Q, A), mul!(unthunk(ΔY), Q, unthunk(ΔA)) end function ChainRulesCore.frule((_, ΔY, _, ΔA), ::typeof(ldiv!), Y, Q::QDHT, A) - return ldiv!(Y, Q, A), ldiv!(ΔY, Q, ΔA) + return ldiv!(Y, Q, A), ldiv!(unthunk(ΔY), Q, unthunk(ΔA)) end function ChainRulesCore.rrule(::typeof(*), Q::QDHT, A) Y = Q * A function mul_pullback(ΔY) ∂Q = NoTangent() - ∂A = @thunk _mul_back(ΔY, Q, A, Q.scaleRK) + ∂A = _mul_back(unthunk(ΔY), Q, A, Q.scaleRK) return NoTangent(), ∂Q, ∂A end return Y, mul_pullback @@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(\), Q::QDHT, A) Y = Q \ A function ldiv_pullback(ΔY) ∂Q = NoTangent() - ∂A = @thunk _mul_back(ΔY, Q, A, inv(Q.scaleRK)) + ∂A = _mul_back(unthunk(ΔY), Q, A, inv(Q.scaleRK)) return NoTangent(), ∂Q, ∂A end return Y, ldiv_pullback @@ -53,16 +53,16 @@ end ## rules for integrateR/integrateK function ChainRulesCore.frule((_, ΔA, _), ::typeof(integrateR), A, Q::QDHT; kwargs...) - return integrateR(A, Q; kwargs...), integrateR(ΔA, Q; kwargs...) + return integrateR(A, Q; kwargs...), integrateR(unthunk(ΔA), Q; kwargs...) end function ChainRulesCore.frule((_, ΔA, _), ::typeof(integrateK), A, Q::QDHT; kwargs...) - return integrateK(A, Q; kwargs...), integrateK(ΔA, Q; kwargs...) + return integrateK(A, Q; kwargs...), integrateK(unthunk(ΔA), Q; kwargs...) end function ChainRulesCore.rrule(::typeof(integrateR), A, Q::QDHT; dim = 1) function integrateR_pullback(ΔΩ) - ∂A = @thunk _integrateRK_back(ΔΩ, A, Q.scaleR; dim = dim) + ∂A = _integrateRK_back(unthunk(ΔΩ), A, Q.scaleR; dim = dim) return NoTangent(), ∂A, NoTangent() end return integrateR(A, Q; dim = dim), integrateR_pullback @@ -70,7 +70,7 @@ end function ChainRulesCore.rrule(::typeof(integrateK), A, Q::QDHT; dim = 1) function integrateK_pullback(ΔΩ) - ∂A = @thunk _integrateRK_back(ΔΩ, A, Q.scaleK; dim = dim) + ∂A = _integrateRK_back(unthunk(ΔΩ), A, Q.scaleK; dim = dim) return NoTangent(), ∂A, NoTangent() end return integrateK(A, Q; dim = dim), integrateK_pullback diff --git a/test/Project.toml b/test/Project.toml index f8d171d..b0864a8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -9,9 +8,8 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -ChainRulesCore = "0.9.44, 0.10" -ChainRulesTestUtils = "0.7.8" -FiniteDifferences = "0.12" +ChainRulesCore = "0.9.44, 0.10, 1" +ChainRulesTestUtils = "0.7.8, 1" HCubature = "1.4" SpecialFunctions = "0.10, 1" julia = "1" diff --git a/test/diffrules.jl b/test/chainrules.jl similarity index 100% rename from test/diffrules.jl rename to test/chainrules.jl diff --git a/test/runtests.jl b/test/runtests.jl index c65c282..07eaa12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,6 @@ using Hankel import LinearAlgebra: diagm, mul!, ldiv! import SpecialFunctions: besseli, besselix, besselj import HCubature: hquadrature -using FiniteDifferences using Random using ChainRulesCore using ChainRulesTestUtils @@ -95,4 +94,4 @@ end end include("qdht.jl") -include("diffrules.jl") +include("chainrules.jl")