Skip to content

Commit 9e33eb5

Browse files
authored
Support ChainRulesCore v1 and ChainRulesTestUtils v1 (#27)
* Rename diffrules to chainrules * Increment supported version numbers * Drop test dependency on FD * Use ChainRulesTestUtils * Increment version number * Thunk and unthunk correctly * Revert "Use ChainRulesTestUtils" This reverts commit 5b1cf3d.
1 parent a0789b6 commit 9e33eb5

File tree

6 files changed

+16
-19
lines changed

6 files changed

+16
-19
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Hankel"
22
uuid = "74863788-d124-456e-a676-9b76578dd39e"
33
authors = ["chrisbrahms <[email protected]>"]
4-
version = "0.5.5"
4+
version = "0.5.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -11,7 +11,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
1111
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1212

1313
[compat]
14-
ChainRulesCore = "0.9.44, 0.10"
14+
ChainRulesCore = "0.9.44, 0.10, 1"
1515
GSL = "0.6, 1"
1616
Roots = "0.8, 1"
1717
SpecialFunctions = "0.10, 1"

src/Hankel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ const J₀₀ = besselj(0, 0)
1212

1313
include("utils.jl")
1414
include("qdht.jl")
15-
include("diffrules.jl")
15+
include("chainrules.jl")
1616

1717
end

src/diffrules.jl renamed to src/chainrules.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@ function ChainRulesCore.rrule(::Type{T}, args...; kwargs...) where {T<:QDHT}
1414
end
1515

1616
## rules for fwd/rev transform
17-
ChainRulesCore.frule((_, _, ΔA), ::typeof(*), Q::QDHT, A) = (Q * A, Q * ΔA)
18-
ChainRulesCore.frule((_, _, ΔA), ::typeof(\), Q::QDHT, A) = (Q \ A, Q \ ΔA)
17+
ChainRulesCore.frule((_, _, ΔA), ::typeof(*), Q::QDHT, A) = (Q * A, Q * unthunk(ΔA))
18+
ChainRulesCore.frule((_, _, ΔA), ::typeof(\), Q::QDHT, A) = (Q \ A, Q \ unthunk(ΔA))
1919
function ChainRulesCore.frule((_, ΔY, _, ΔA), ::typeof(mul!), Y, Q::QDHT, A)
20-
return mul!(Y, Q, A), mul!(ΔY, Q, ΔA)
20+
return mul!(Y, Q, A), mul!(unthunk(ΔY), Q, unthunk(ΔA))
2121
end
2222
function ChainRulesCore.frule((_, ΔY, _, ΔA), ::typeof(ldiv!), Y, Q::QDHT, A)
23-
return ldiv!(Y, Q, A), ldiv!(ΔY, Q, ΔA)
23+
return ldiv!(Y, Q, A), ldiv!(unthunk(ΔY), Q, unthunk(ΔA))
2424
end
2525

2626
function ChainRulesCore.rrule(::typeof(*), Q::QDHT, A)
2727
Y = Q * A
2828
function mul_pullback(ΔY)
2929
∂Q = NoTangent()
30-
∂A = @thunk _mul_back(ΔY, Q, A, Q.scaleRK)
30+
∂A = _mul_back(unthunk(ΔY), Q, A, Q.scaleRK)
3131
return NoTangent(), ∂Q, ∂A
3232
end
3333
return Y, mul_pullback
@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(\), Q::QDHT, A)
3737
Y = Q \ A
3838
function ldiv_pullback(ΔY)
3939
∂Q = NoTangent()
40-
∂A = @thunk _mul_back(ΔY, Q, A, inv(Q.scaleRK))
40+
∂A = _mul_back(unthunk(ΔY), Q, A, inv(Q.scaleRK))
4141
return NoTangent(), ∂Q, ∂A
4242
end
4343
return Y, ldiv_pullback
@@ -53,24 +53,24 @@ end
5353

5454
## rules for integrateR/integrateK
5555
function ChainRulesCore.frule((_, ΔA, _), ::typeof(integrateR), A, Q::QDHT; kwargs...)
56-
return integrateR(A, Q; kwargs...), integrateR(ΔA, Q; kwargs...)
56+
return integrateR(A, Q; kwargs...), integrateR(unthunk(ΔA), Q; kwargs...)
5757
end
5858

5959
function ChainRulesCore.frule((_, ΔA, _), ::typeof(integrateK), A, Q::QDHT; kwargs...)
60-
return integrateK(A, Q; kwargs...), integrateK(ΔA, Q; kwargs...)
60+
return integrateK(A, Q; kwargs...), integrateK(unthunk(ΔA), Q; kwargs...)
6161
end
6262

6363
function ChainRulesCore.rrule(::typeof(integrateR), A, Q::QDHT; dim = 1)
6464
function integrateR_pullback(ΔΩ)
65-
∂A = @thunk _integrateRK_back(ΔΩ, A, Q.scaleR; dim = dim)
65+
∂A = _integrateRK_back(unthunk(ΔΩ), A, Q.scaleR; dim = dim)
6666
return NoTangent(), ∂A, NoTangent()
6767
end
6868
return integrateR(A, Q; dim = dim), integrateR_pullback
6969
end
7070

7171
function ChainRulesCore.rrule(::typeof(integrateK), A, Q::QDHT; dim = 1)
7272
function integrateK_pullback(ΔΩ)
73-
∂A = @thunk _integrateRK_back(ΔΩ, A, Q.scaleK; dim = dim)
73+
∂A = _integrateRK_back(unthunk(ΔΩ), A, Q.scaleK; dim = dim)
7474
return NoTangent(), ∂A, NoTangent()
7575
end
7676
return integrateK(A, Q; dim = dim), integrateK_pullback

test/Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
[deps]
22
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
33
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
4-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
54
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
65
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
76
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
87
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
98
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
109

1110
[compat]
12-
ChainRulesCore = "0.9.44, 0.10"
13-
ChainRulesTestUtils = "0.7.8"
14-
FiniteDifferences = "0.12"
11+
ChainRulesCore = "0.9.44, 0.10, 1"
12+
ChainRulesTestUtils = "0.7.8, 1"
1513
HCubature = "1.4"
1614
SpecialFunctions = "0.10, 1"
1715
julia = "1"
File renamed without changes.

test/runtests.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ using Hankel
33
import LinearAlgebra: diagm, mul!, ldiv!
44
import SpecialFunctions: besseli, besselix, besselj
55
import HCubature: hquadrature
6-
using FiniteDifferences
76
using Random
87
using ChainRulesCore
98
using ChainRulesTestUtils
@@ -95,4 +94,4 @@ end
9594
end
9695

9796
include("qdht.jl")
98-
include("diffrules.jl")
97+
include("chainrules.jl")

0 commit comments

Comments
 (0)