WIP: Add ChainRules #21
Conversation
Codecov Report
@@            Coverage Diff            @@
##           master      #21     +/-  ##
=========================================
+ Coverage   70.40%   77.86%   +7.45%
=========================================
  Files           4        5       +1
  Lines          98      131      +33
=========================================
+ Hits           69      102      +33
  Misses         29       29
Continue to review full report at Codecov.
Force-pushed from e67fe36 to 5dce0e9.
What's the error for 2.? I will try to take a look at the end of the day.
It is numerically very, very wrong.
Not sure why passing the Tangent was failing (the math looked fine), but I solved it by creating an intermediate Woodbury, also putting the WoodburyPDMat constructor in the pullback. I still don't know why the following fails though:

julia> W = WoodburyPDMat(rand(3,2), Diagonal(rand(2,)), Diagonal(rand(3,)))
3×3 WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}:
1.26364 0.194865 0.14236
0.194865 0.553561 0.0738048
0.14236 0.0738048 0.366308
julia> T = Tangent{WoodburyPDMat}(;A=W.A, D=W.D, S=W.S)
Tangent{WoodburyPDMat}(A = [0.6529299263287578 0.9705937318111584; 0.2742328923999924 0.5824550169874951; 0.14327303956778725 0.6355464664765937], D = [0.6900525901684949 0.0; 0.0 0.12613565656887138], S = [0.8506296051551665 0.0 0.0; 0.0 0.4588742330641957 0.0; 0.0 0.0 0.3011946555949865])
julia> test_rrule(*, 2.0, W ⊢ W; output_tangent=W)
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}: Test Failed at /Users/mzgubic/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
Expression: isapprox(actual, expected; kwargs...)
Evaluated: isapprox(2.1647628223399655, 1.5169372885519636; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
[1] test_approx(actual::Float64, expected::Float64, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
[2] macro expansion
@ ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:238 [inlined]
[3] macro expansion
@ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
[4] test_rrule(::ChainRulesTestUtils.ADviaRuleConfig, ::typeof(*), ::Float64, ::Vararg{Any, N} where N; output_tangent::WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:194
Test Summary: | Pass Fail Total
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}} | 8 1 9
ERROR: Some tests did not pass: 8 passed, 1 failed, 0 errored, 0 broken.

Also, I don't think we have fully figured out our story on testing arbitrary tangents (like passing a Matrix instead of the Woodbury).
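For reference, a rough sketch (mine, not code from the PR) of the kind of manual finite-difference check mentioned further down the thread, applied to this case. It reuses W from the REPL session above and only standard FiniteDifferences calls; the comparison value dot(W̄, Matrix(W)) is my own derivation of the scalar cotangent:

using FiniteDifferences, LinearAlgebra

# W comes from the REPL session above; densify everything so FiniteDifferences
# only ever sees plain arrays.
fdm = central_fdm(5, 1)
c = 2.0
W̄ = Matrix(W)                         # use the dense W as the output cotangent
c̄_fd = only(j′vp(fdm, x -> Matrix(x * W), W̄, c))
c̄_analytic = dot(W̄, Matrix(W))        # d/dc ⟨W̄, c*W⟩ = ⟨W̄, W⟩
isapprox(c̄_fd, c̄_analytic; rtol = 1e-6)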
OK, I understand a bit more about what's going on now. Consider this snippet (the tests pass, at least locally; note I am using the commit prior to the one just added, where the WoodburyPDMat constructor is added to the pullback):

ProjectTo is pushing the tangent W̄ into its constituents A, D and S.

EDIT: Added in some of the variables that were implicit. Removed the call to the rrule, which at the moment won't work. This is just testing the derivatives.
Force-pushed from ca45e00 to f6db1da.
I'm not sure that what is intended is possible in the present set-up. If we consider the pullback and the projection:
We want the pullback to support
Some options to consider:
Though I've probably misunderstood something somewhere.
I was thinking about the last option. I think it is only possible to go from a Woodbury to a vector, but not back (as far as I know it is not possible to "factorise" an arbitrary dense matrix into a Woodbury?). It might still solve our problems, because in some cases only going to a vector is required (and not going back). "Some cases" is, I think, the case where we vectorise the tangent of the primal output.
I've cleaned up this MR now; it should be ready to go. As there were several discussion points previously, I'll give an overview:
Will take it out of WIP.
ChainRulesCore.ProjectTo{T}(; fields...)
end

function (project::ProjectTo{T})(X̄::AbstractMatrix) where {T<:WoodburyPDMat}
I don't understand what this does. It looks like we are trying to project an arbitrary matrix onto a Woodbury, which as I understand it is not possible when X is not positive semi-definite? (In that case we might explicitly check for PSD and throw an error?)

If all is well and I just don't understand something, could we add a few lines of comments that explain what is going on and maybe link a reference for the math?
trying to project an arbitrary matrix onto a Woodbury

Nope, we are projecting an arbitrary matrix onto a Woodbury tangent. I.e. instead of passing around W̄, we are passing around Ā, D̄, S̄.
The matrix here can be anything. Project here is encoding how the cotangent is distributed to the individual components of the Woodbury, i.e. from the definition W = ADA' + S -> dW = dA D A' + A D dA' + A dD A' + dS => Ā = 2W̄AD, D̄ = A'W̄A, S̄ = W̄, with their appropriate projections down to e.g. Diagonal. X̄ here is W̄.
Ideally, these would only be in the constructor for the Woodbury; however, due to the way FD is set up, it seems we always have to project down to these components. The Woodbury is a lazy low-rank matrix, and because of this any Woodbury primal has as many free elements as its fields (A, D and S) hold, not the n² elements of the dense matrix it represents. The tangent has to be the same size (i.e. number of elements), i.e. their to_vec representations must be equal in length.

This is why, when testing with a Matrix cotangent, I couldn't use CRTU and had to do a manual call to FD and densify the primal here.

If need be, I believe we can only ever return the natural differential Matrix (i.e. W̄) as the output of any arbitrary rrule with a Woodbury argument, as is done here, but as I say I believe that FD will not like this in some scenarios.

The constraints imposed here, as I understand them, are that the primal output and output cotangent need to be of the same length, and that the input args and input adjoints are also all of the same length as each other (where length is the to_vec representation).

I'll add comments, but first I want to make sure that this makes sense.
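For concreteness, a minimal sketch of that field-wise pullback, assuming W = A*D*A' + S with diagonal D and S and a symmetric cotangent; the helper name and the NamedTuple return value (standing in for a Tangent{WoodburyPDMat}) are illustrative only, not the PR's code:

using LinearAlgebra

# Map a dense cotangent W̄ onto field-wise cotangents for W = A*D*A' + S,
# with D and S diagonal. Illustrative helper, not part of the package.
function woodbury_cotangents(W̄::AbstractMatrix, A::AbstractMatrix, D::Diagonal)
    W̄s = (W̄ + W̄') / 2                # symmetrise; the primal W is symmetric
    Ā = 2 * W̄s * A * D               # from the dA*D*A' and A*D*dA' terms
    D̄ = Diagonal(diag(A' * W̄s * A))  # from the A*dD*A' term, projected to Diagonal
    S̄ = Diagonal(diag(W̄s))           # from the dS term, projected to Diagonal
    return (A = Ā, D = D̄, S = S̄)
end

A quick sanity check is that dot(W̄s, ΔA*D*A' + A*D*ΔA' + A*ΔD*A' + ΔS) equals dot(Ā, ΔA) + dot(D̄, ΔD) + dot(S̄, ΔS) for any ΔA and diagonal ΔD, ΔS.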
Hmm, I am actually not sure what we should do here; I've opened JuliaDiff/ChainRulesCore.jl#442 with the PositiveReal example you showed me.
I'm reading that thread. The use case here aside, is projecting onto a Tangent not an intended use of ProjectTo? I'm sort of getting that impression from reading these threads, and it is not one that I had realised, if that is the case. I now note that nothing in projection.jl in CRC maps to a Tangent for some arbitrary struct.

I had thought the point of Project was to take a differential and map it appropriately onto the correct differential representation, in this case a Tangent{WoodburyPDMat}.
Sure, we can map to Tangent because Tangent is a differential type. The difference is that

ProjectTo(w::Tangent{Woodbury})

is fine, but

ProjectTo(w::Woodbury)

is not, and we are doing the second here. (Despite the fact that the projection actually returns the Tangent{Woodbury}, which is inconsistent since we said we are projecting onto a Woodbury.)
How does this relate to the proposal in JuliaDiff/ChainRulesCore.jl#449?
I think Will thought that it is possible to find the inverse of Woodbury densification (for the subspace of dense matrices that can come from a Woodbury), but I may have misinterpreted that.

This seems to be a weird edge case @willtebbutt. If I understand correctly, in this case we want the structural rather than the natural differential inside the pullback written in the rule (rather than the premise in 449, which expects AD to prefer structural tangents and rule writers to prefer natural tangents).
The proposal in 449 would basically replace ProjectTo / change its semantics to eat a natural tangent / Matrix and produce a structural tangent. You obtain the new thing by implementing the pullback for collect, and that does everything you need.

Re the inverse of densification -- I found that hard to implement for WoodburyPDMat -- see here, but the pullback for collect (pullback_of_destructure -- I wanted to give it a different name) is totally fine.
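As a rough illustration of the "pullback for collect" idea (a sketch only, under the assumption that WoodburyPDMat has fields A, D, S with W == A*D*A' + S and diagonal D and S; this is not code from the PR or from 449):

using LinearAlgebra
using ChainRulesCore

# Sketch: collect densifies the Woodbury; its pullback maps a dense cotangent
# back to a structural Tangent over the fields, mirroring the formulas above.
function ChainRulesCore.rrule(::typeof(collect), W::WoodburyPDMat)
    function collect_pullback(W̄dense)
        W̄ = (W̄dense + W̄dense') / 2
        Ā = 2 * W̄ * W.A * W.D
        D̄ = Diagonal(diag(W.A' * W̄ * W.A))
        S̄ = Diagonal(diag(W̄))
        return NoTangent(), Tangent{typeof(W)}(; A = Ā, D = D̄, S = S̄)
    end
    return collect(W), collect_pullback
end

Under the 449 framing this logic would sit behind pullback_of_destructure rather than ProjectTo.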
Separately, what's the rationale for implementing ProjectTo? Is it to make it possible to do *(A::Matrix, B::WoodburyPDMat)? If so, would it make more sense / be easier to implement a specialised method for the primal (since the generic fallback will be horribly slow) that is differentiable, and @opt_out of the generic rrule to let AD do the work of deriving the rule?
Separately, what's the rationale for implementing ProjectTo

Originally, I was using the ProjectTo here to project the natural tangent onto its structural tangent, i.e. for a given W̄ (the natural tangent, as I understand the semantics), to its structural (Ā, D̄, S̄), but it turned out this wasn't permitted under the semantics of ProjectTo, which is what led to the 'allow ProjectTo onto Tangent' issue. I paused looking at this MR at that point, but saw your RFC being posted and pinged this thread on how the two related. As far as I understand, the resolution to #442 in CRC added an identity, but I don't think that changes anything w.r.t. using it as a tangent, so I still wouldn't know how to continue this.

If I understand your RFC, the way I was using ProjectTo here was more in line with what you were considering.
function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
    Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
    B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
    return (NoTangent(), Ā, B̄)
end
Could we do something similar to the forward pass? Like:

Suggested change (replacing the method above):

function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
    df, db, da = _times_pullback(Ȳ, B, A, (A = proj.B, B = proj.A))
    return df, db, da
end

(I don't think we have to worry about adjointing a Real.)
)
end

@testset "*(Matrix-Woodbury)" begin
I don't see this rule anywhere; does this work because ChainRules is loaded? Might be worth adding a comment.
Oh, I think these are because Zygote is loaded. I'll add ChainRules explicitly.
Sorry for jumping in late. Saw this linked in 449.
# Rule for Woodbury * Real.
# Ignoring Complex version for now.
function ChainRulesCore.rrule(::typeof(*), A::WoodburyPDMat, B::Real)
Sorry to hop in late here, but why do we need a rule for this, as opposed to just opting out? In particular, our implementation of this function ought to be straightforwardly differentiable by Zygote et al:

*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
*(c::Real, a::WoodburyPDMat) = a * c

Is it not sufficient just to @opt_out?
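For reference, a minimal sketch of what that opt-out could look like (the exact generic signatures to opt out of are my assumption, not something stated in this thread):

using ChainRulesCore

# Hypothetical opt-outs: tell rule-aware ADs to ignore any generic rrule for
# these signatures and differentiate the WoodburyPDMat primal methods directly.
ChainRulesCore.@opt_out rrule(::typeof(*), ::WoodburyPDMat, ::Real)
ChainRulesCore.@opt_out rrule(::typeof(*), ::Real, ::WoodburyPDMat)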
Yeah, this is probably true. AFAIK this works; here, however, I was using this to test the case where the primal returns a Woodbury (and thus your output cotangent is a Tangent{Woodbury}), as I wanted to use it as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember, we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

I hadn't considered opt-outs... if the rule works, is it generally not more efficient than Zygote working it out?
Yeah, this is probably true. AFAIK this works; here, however, I was using this to test the case where the primal returns a Woodbury (and thus your output cotangent is a Tangent{Woodbury}), as I wanted to use it as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember, we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

Interesting -- again, I'd suggest just opting out here 🤷 Or am I missing the point?

I hadn't considered opt-outs... if the rule works, is it generally not more efficient than Zygote working it out?

No. In simple cases like this there's no reason that Zygote can't achieve optimal performance. If you have control flow (either via a for-loop or an if-else-end block) Zygote will generally struggle, but otherwise you're definitely better off letting Zygote figure stuff out.

My experience has consistently been that if you have the choice between writing a rule and writing AD-friendly code, always pick the latter. The fact that opting out is necessary is unfortunate, but it's the compromise we have for now.
between writing a rule and writing AD-friendly code, always pick the latter.

Interesting, thanks. I hadn't considered this angle, nor that workarounds like WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D) are advised over setting up a rule.
I'm just going to close this MR. I had largely set out to try to understand CR through implementing rules for the Woodbury (and follow up with a resolution to the Woodbury * Diagonal AD workaround that was necessary), with the presumption that the rules would be more performant (I never got to benchmarking these), but this seems the wrong view. There may be some follow-ups (such as opt-outs and supporting CRC 1), but it seems best to do that with a clean slate.
Adds support for ChainRules 1.0.
This tests
and adds in some constructor tests too.