WIP: Add ChainRules #21
@@ -0,0 +1,69 @@
@non_differentiable validate_woodbury_arguments(A, D, S)

# Rule for Woodbury * Real.
# Ignoring Complex version for now.
function ChainRulesCore.rrule(::typeof(*), A::WoodburyPDMat, B::Real)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    primal = A * B
    times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
    return primal, times_pullback
end

function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
    project_A = ProjectTo(A)
    project_B = ProjectTo(B)
    primal = A * B
    times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
    return primal, times_pullback
end

_times_pullback(ȳ::AbstractThunk, primal, A, B, proj) = _times_pullback(unthunk(ȳ), primal, A, B, proj)
# If the cotangent is a Matrix we first need to project down, otherwise ignore
_times_pullback(Ȳ::AbstractMatrix, primal, A, B, proj) = _times_pullback(ProjectTo(primal)(Ȳ), A, B, proj)
_times_pullback(ȳ::Tangent, primal, A, B, proj) = _times_pullback(ȳ, A, B, proj)

function _times_pullback(Ȳ::Tangent, A::T, B::Real, proj) where {T<:WoodburyPDMat}
    Ā = @thunk proj.A(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * B', S = Ȳ.S * B'))
    B̄ = @thunk proj.B(dot(Ȳ.D, A.D) + dot(Ȳ.S, A.S))
    return (NoTangent(), Ā, B̄)
end

function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
    Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
    B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
    return (NoTangent(), Ā, B̄)
end
Comment on lines +32 to +36:

Could we do something similar to the forward pass? Like:
Suggested change
(I don't think we have to worry about adjointing a Real.)

# Composite pullbacks
function ChainRulesCore.rrule(
    ::Type{T},
    A::AbstractMatrix,
    D::Diagonal,
    S::Diagonal,
) where {T<:WoodburyPDMat}
    return WoodburyPDMat(A, D, S), X̄ -> WoodburyPDMat_pullback(X̄, A, D, S)
end
WoodburyPDMat_pullback(X̄::Tangent, A, D, S) = (NoTangent(), X̄.A, X̄.D, X̄.S)
WoodburyPDMat_pullback(X̄::AbstractThunk, A, D, S) = WoodburyPDMat_pullback(unthunk(X̄), A, D, S)

function ChainRulesCore.ProjectTo(W::T) where {T<:WoodburyPDMat}
    fields = (A = W.A, D = W.D, S = W.S)
    ChainRulesCore.ProjectTo{T}(; fields...)
end

#
# Project the differential onto the Tangent{WoodburyPDMat}.
# This essentially computes the pullbacks for the components of the Woodbury,
# i.e. from the definition: W = ADA' + S
# dW = dA DA' + A dD A' + AD dA' + dS
# => Ā = (W̄ + W̄')AD, D̄ = A'W̄A, S̄ = W̄ (D̄ and S̄ restricted to their diagonals)
# A more precise formulation is available e.g. here:
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
function (project::ProjectTo{T})(X̄::AbstractMatrix) where {T<:WoodburyPDMat}

I don't understand what this does. It looks like we are trying to project an arbitrary matrix onto a Woodbury, which as I understand it is not possible if X is not positive semi-definite? (In that case we might explicitly check for PSD and throw an error?) If all is well and I just don't understand something, could we add a few lines of comments that explain what is going on, and maybe link a reference for the math?

Nope - we are projecting an arbitrary matrix onto a Woodbury tangent, i.e. instead of passing around W̄ […]. The matrix here can be anything. The projection here encodes the rules for the individual components of the Woodbury, i.e. from the definition […]. Ideally, these would only live in the constructor for the Woodbury; however, due to the way FD is set up, it seems we always have to project down to these components.

The Woodbury is a lazy low-rank matrix, and because of this any Woodbury primal has number of elements 2*(A1 + A2), where (A1, A2) is the size of the matrix A. The tangent has to be the same size (i.e. number of elements), i.e. their […]. This is why, when testing with a Matrix cotangent, I couldn't use CRTU and had to make a manual call to FD and densify the primal here.

If need be, I believe we can only ever return the natural differential Matrix (i.e. W̄) as the output of any arbitrary rrule with a Woodbury argument, as is done here, but as I say I believe that FD will not like this in some scenarios. The constraints imposed here, as I understand them, are that the primal output and output cotangent need to be of the same length, and that the input args and input adjoints are also all of the same length as each other (where length is that of the to_vec representation).

I'll add comments, but first I want to make sure that this makes sense.

Hmm, I am actually not sure what we should do here. I've opened JuliaDiff/ChainRulesCore.jl#442.

I'm reading that thread. The use case here aside, is projecting onto a […] I had thought the point of […]

Sure, we can map to […] ProjectTo(w::Tangent{Woodbury}) is fine, but ProjectTo(w::Woodbury) is not, and we are doing the second here. (Despite the fact that the projection actually returns the […])

How does this relate to the proposal in JuliaDiff/ChainRulesCore.jl#449?

I think Will thought that it is possible to find the inverse of Woodbury densification (for the subspace of dense matrices that can come from a Woodbury), but I may have misinterpreted that. This seems to be a weird edge case @willtebbutt. If I understand correctly, in this case we want the structural rather than the natural differential inside the pullback written in the rule (rather than the premise in 449, which expects AD to prefer structural tangents and rule writers to prefer natural tangents).

The proposal in 449 would basically replace […] Re inverse of densification -- I found that hard to implement for […]

Separately, what's the rationale for implementing […]

Originally, I was using the […] If I understand your RFC, the way I was using the ProjectTo here was more in line with what you were considering.

    Ā = ProjectTo(project.A)((X̄ + X̄') * (project.A * project.D))
    D̄ = ProjectTo(project.D)(project.A' * (X̄) * project.A)
    S̄ = ProjectTo(project.S)(X̄)
    return Tangent{WoodburyPDMat}(; A = Ā, D = D̄, S = S̄)
end
(project::ProjectTo{T})(W::Tangent) where {T<:WoodburyPDMat} = W
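
The cotangents computed by the `ProjectTo` method above can be checked with a short derivation. This is a sketch under stated assumptions rather than the PR author's own write-up: it assumes real entries, the Frobenius inner product $\langle X, Y \rangle = \operatorname{tr}(X^\top Y)$, a diagonal (hence symmetric) $D$, and that projecting onto a `Diagonal` keeps only the diagonal entries, written $\operatorname{diag}(\cdot)$ here.

```latex
\begin{aligned}
W &= A D A^\top + S, \\
dW &= dA\, D A^\top + A\, dD\, A^\top + A D\, dA^\top + dS, \\
\langle \bar W, dW \rangle
  &= \operatorname{tr}\!\left(\bar W^\top dW\right) \\
  &= \langle (\bar W + \bar W^\top) A D,\; dA \rangle
   + \langle A^\top \bar W A,\; dD \rangle
   + \langle \bar W,\; dS \rangle, \\
\bar A &= (\bar W + \bar W^\top) A D, \qquad
\bar D = \operatorname{diag}\!\left(A^\top \bar W A\right), \qquad
\bar S = \operatorname{diag}\!\left(\bar W\right).
\end{aligned}
```

The two terms involving $dA$ each contribute one of $\bar W A D$ and $\bar W^\top A D$, which is why the code uses `X̄ + X̄'` and does not assume a symmetric cotangent.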
@@ -0,0 +1,46 @@
@testset "ChainRules" begin

    A = randn(4, 2)
    D = Diagonal(randn(2).^2 .+ 1)
    S = Diagonal(randn(4).^2 .+ 1)

    W = WoodburyPDMat(A, D, S)
    R = 2.0
    Dmat = Diagonal(rand(4,))

    x = randn(size(A, 1))

    @testset "Constructors" begin
        test_rrule(WoodburyPDMat, W.A, W.D, W.S)
        # This is a gradient, should be able to deal with negative elements (does not have to be PSD like Woodbury itself)
        test_rrule(WoodburyPDMat, W.A, W.D, W.S;
            output_tangent=Tangent{WoodburyPDMat}(;
                A = rand(4,2), D = Diagonal(-1 * rand(2,)), S = Diagonal(-1 * rand(4,)))
        )
    end

    # The rrules already in ChainRules are sufficient for these to work. We just test an example here.
    @testset "*(Matrix-Woodbury)" begin

I don't see this rule anywhere. Does this work because ChainRules is loaded? It might be worth adding a comment.

Oh, I think these work because Zygote is loaded. I'll add ChainRules explicitly.

        test_rrule(*, Dmat, W)
        test_rrule(*, W, Dmat)
        test_rrule(*, rand(4,4), W)
    end

    @testset "*(Woodbury-Real)" begin
        test_rrule(*, W, R)
        test_rrule(*, R, W)

        # We can't test test_rrule(*, R, W; output_tangent = rand(size(W)...)) i.e. with a Matrix because
        # FD requires the primal and tangent to be the same size. However, we can just call FD directly and overload
        # the primal computation to return a Matrix:
        @testset "Matrix CoTangent" begin
            res, pb = ChainRulesCore.rrule(*, R, W)
            output_tangent = rand(size(W)...)
            f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> Matrix(*(x...)), output_tangent, (R, W))[1]
            @test unthunk(pb(output_tangent)[3]).A ≈ f_jvp[2].A
            @test unthunk(pb(output_tangent)[3]).D ≈ f_jvp[2].D
            @test unthunk(pb(output_tangent)[3]).S ≈ f_jvp[2].S
            @test unthunk(pb(output_tangent)[2]) ≈ f_jvp[1]
        end
    end
end
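
The idea behind the "Matrix CoTangent" testset, comparing the component cotangents against finite differences of the densified matrix, can be illustrated without `WoodburyPDMat`, FiniteDifferences, or ChainRulesTestUtils. The following is a minimal sketch using only `LinearAlgebra`; the helper names `dense`, `loss`, `component_cotangents`, and `fd_entry` are hypothetical and not part of the PR, and the closed-form expressions simply mirror the `ProjectTo` method in the diff.

```julia
using LinearAlgebra

# Dense form of W = A*D*A' + S (hypothetical helper, not from the PR).
dense(A, D, S) = A * D * A' + S

# Scalar function whose gradient w.r.t. (A, D, S) equals the pullback of Wbar.
loss(A, D, S, Wbar) = dot(Wbar, dense(A, D, S))

# Closed-form cotangents, mirroring the ProjectTo method shown in the diff.
function component_cotangents(A, D, S, Wbar)
    Abar = (Wbar + Wbar') * (A * D)
    Dbar = Diagonal(diag(A' * Wbar * A))
    Sbar = Diagonal(diag(Wbar))
    return Abar, Dbar, Sbar
end

# Central finite difference of `f` with respect to entry `i` of the array `x`.
function fd_entry(f, x, i; h=1e-6)
    xp = copy(x); xp[i] += h
    xm = copy(x); xm[i] -= h
    return (f(xp) - f(xm)) / (2h)
end

A = randn(4, 2)
d = randn(2) .^ 2 .+ 1      # diagonal of D
s = randn(4) .^ 2 .+ 1      # diagonal of S
Wbar = randn(4, 4)          # arbitrary cotangent: need not be symmetric or PSD

Abar, Dbar, Sbar = component_cotangents(A, Diagonal(d), Diagonal(s), Wbar)

# Finite-difference estimates, one entry at a time.
Abar_fd = [fd_entry(X -> loss(X, Diagonal(d), Diagonal(s), Wbar), A, i) for i in CartesianIndices(A)]
d_fd = [fd_entry(v -> loss(A, Diagonal(v), Diagonal(s), Wbar), d, i) for i in eachindex(d)]
s_fd = [fd_entry(v -> loss(A, Diagonal(d), Diagonal(v), Wbar), s, i) for i in eachindex(s)]
```

Comparing `Abar` with `Abar_fd`, `diag(Dbar)` with `d_fd`, and `diag(Sbar)` with `s_fd` using `isapprox` should show agreement up to finite-difference error, including for a non-symmetric `Wbar`.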

Sorry to hop in late here, but why do we need a rule for this, as opposed to just opting out? In particular, our implementation of this function ought to be straightforwardly differentiable by Zygote et al. Is it not sufficient just to `@opt_out`?

Yeah, this is probably true. AFAIK this works. Here, however, I was using this to test the case where the primal returns a Woodbury (and thus the output cotangent is a Tangent{Woodbury}), as I wanted to use it as a basis for sorting out the rule for `*(H::Diagonal, S::WoodburyPDMat)`. If you remember, we had issues with that and had to do `WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D)`. This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

I hadn't considered opt-outs... If the rule works, are they generally not more efficient than Zygote working it out?

Interesting -- again, I'd suggest just opting out here 🤷 Or am I missing the point?

No. In simple cases like this there's no reason that Zygote can't achieve optimal performance. If you have control flow (either via a for-loop or an `if-else-end` block) Zygote will generally struggle, but otherwise you're definitely better off letting Zygote figure stuff out.

My experience has consistently been that if you have the choice between writing a rule and writing AD-friendly code, always pick the latter. The fact that opting out is necessary is unfortunate, but it's the compromise we have for now.

Interesting - thanks. I hadn't considered this angle, or that workarounds like `WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D)` are advised over setting up a rule.
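
As a rough illustration of the opt-out approach discussed in this thread, the sketch below uses ChainRulesCore's `@opt_out` macro on a toy type. `ToyWoodbury` is a hypothetical stand-in for `WoodburyPDMat`, and the scalar multiplication shown is an assumption about the forward pass (scaling only `D` and `S`), not the package's actual implementation.

```julia
using ChainRulesCore
using LinearAlgebra

# Hypothetical stand-in for WoodburyPDMat: lazily represents A*D*A' + S.
struct ToyWoodbury{TA<:AbstractMatrix,TD<:Diagonal,TS<:Diagonal}
    A::TA
    D::TD
    S::TS
end

# Assumed forward pass: scaling touches only D and S, so the result stays lazy.
Base.:*(c::Real, W::ToyWoodbury) = ToyWoodbury(W.A, c * W.D, c * W.S)
Base.:*(W::ToyWoodbury, c::Real) = c * W

# Opt out of any more generic rrule for these signatures, so that an AD system
# differentiates through the method bodies above instead of using a rule.
@opt_out rrule(::typeof(*), ::Real, ::ToyWoodbury)
@opt_out rrule(::typeof(*), ::ToyWoodbury, ::Real)

W = ToyWoodbury(randn(4, 2), Diagonal(rand(2) .+ 1), Diagonal(rand(4) .+ 1))
scaled = 2.0 * W
```

With the opt-outs in place, `rrule(*, 2.0, W)` returns `nothing`, which signals to AD systems built on ChainRules that no rule applies and that they should trace the primal code directly.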