@@ -14,20 +14,20 @@ function ChainRulesCore.rrule(::Type{T}, args...; kwargs...) where {T<:QDHT}
14
14
end
15
15
16
16
# # rules for fwd/rev transform
17
- ChainRulesCore. frule ((_, _, ΔA), :: typeof (* ), Q:: QDHT , A) = (Q * A, Q * ΔA )
18
- ChainRulesCore. frule ((_, _, ΔA), :: typeof (\ ), Q:: QDHT , A) = (Q \ A, Q \ ΔA )
17
+ ChainRulesCore. frule ((_, _, ΔA), :: typeof (* ), Q:: QDHT , A) = (Q * A, Q * unthunk (ΔA) )
18
+ ChainRulesCore. frule ((_, _, ΔA), :: typeof (\ ), Q:: QDHT , A) = (Q \ A, Q \ unthunk (ΔA) )
19
19
function ChainRulesCore. frule ((_, ΔY, _, ΔA), :: typeof (mul!), Y, Q:: QDHT , A)
20
- return mul! (Y, Q, A), mul! (ΔY , Q, ΔA )
20
+ return mul! (Y, Q, A), mul! (unthunk (ΔY) , Q, unthunk (ΔA) )
21
21
end
22
22
function ChainRulesCore. frule ((_, ΔY, _, ΔA), :: typeof (ldiv!), Y, Q:: QDHT , A)
23
- return ldiv! (Y, Q, A), ldiv! (ΔY , Q, ΔA )
23
+ return ldiv! (Y, Q, A), ldiv! (unthunk (ΔY) , Q, unthunk (ΔA) )
24
24
end
25
25
26
26
function ChainRulesCore. rrule (:: typeof (* ), Q:: QDHT , A)
27
27
Y = Q * A
28
28
function mul_pullback (ΔY)
29
29
∂Q = NoTangent ()
30
- ∂A = @thunk _mul_back (ΔY , Q, A, Q. scaleRK)
30
+ ∂A = _mul_back (unthunk (ΔY) , Q, A, Q. scaleRK)
31
31
return NoTangent (), ∂Q, ∂A
32
32
end
33
33
return Y, mul_pullback
@@ -37,7 +37,7 @@ function ChainRulesCore.rrule(::typeof(\), Q::QDHT, A)
37
37
Y = Q \ A
38
38
function ldiv_pullback (ΔY)
39
39
∂Q = NoTangent ()
40
- ∂A = @thunk _mul_back (ΔY , Q, A, inv (Q. scaleRK))
40
+ ∂A = _mul_back (unthunk (ΔY) , Q, A, inv (Q. scaleRK))
41
41
return NoTangent (), ∂Q, ∂A
42
42
end
43
43
return Y, ldiv_pullback
53
53
54
54
# # rules for integrateR/integrateK
55
55
function ChainRulesCore. frule ((_, ΔA, _), :: typeof (integrateR), A, Q:: QDHT ; kwargs... )
56
- return integrateR (A, Q; kwargs... ), integrateR (ΔA , Q; kwargs... )
56
+ return integrateR (A, Q; kwargs... ), integrateR (unthunk (ΔA) , Q; kwargs... )
57
57
end
58
58
59
59
function ChainRulesCore. frule ((_, ΔA, _), :: typeof (integrateK), A, Q:: QDHT ; kwargs... )
60
- return integrateK (A, Q; kwargs... ), integrateK (ΔA , Q; kwargs... )
60
+ return integrateK (A, Q; kwargs... ), integrateK (unthunk (ΔA) , Q; kwargs... )
61
61
end
62
62
63
63
function ChainRulesCore. rrule (:: typeof (integrateR), A, Q:: QDHT ; dim = 1 )
64
64
function integrateR_pullback (ΔΩ)
65
- ∂A = @thunk _integrateRK_back (ΔΩ , A, Q. scaleR; dim = dim)
65
+ ∂A = _integrateRK_back (unthunk (ΔΩ) , A, Q. scaleR; dim = dim)
66
66
return NoTangent (), ∂A, NoTangent ()
67
67
end
68
68
return integrateR (A, Q; dim = dim), integrateR_pullback
69
69
end
70
70
71
71
function ChainRulesCore. rrule (:: typeof (integrateK), A, Q:: QDHT ; dim = 1 )
72
72
function integrateK_pullback (ΔΩ)
73
- ∂A = @thunk _integrateRK_back (ΔΩ , A, Q. scaleK; dim = dim)
73
+ ∂A = _integrateRK_back (unthunk (ΔΩ) , A, Q. scaleK; dim = dim)
74
74
return NoTangent (), ∂A, NoTangent ()
75
75
end
76
76
return integrateK (A, Q; dim = dim), integrateK_pullback
0 commit comments