WIP: Add ChainRules #21

Closed
wants to merge 40 commits into from
Conversation

AlexRobson
Member

@AlexRobson AlexRobson commented Jul 27, 2021

Adds support for ChainRules 1.0.

This tests

1. *(::AbstractVecOrMat, ::Woodbury)::Matrix
2. *(::Real, ::Woodbury)::WoodburyPDMat

and adds in some constructor tests too.

@codecov

codecov bot commented Jul 27, 2021

Codecov Report

Merging #21 (8ce1498) into master (11955cd) will increase coverage by 7.45%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master      #21      +/-   ##
==========================================
+ Coverage   70.40%   77.86%   +7.45%     
==========================================
  Files           4        5       +1     
  Lines          98      131      +33     
==========================================
+ Hits           69      102      +33     
  Misses         29       29              
Impacted Files            Coverage Δ
src/PDMatsExtras.jl       100.00% <ø> (ø)
src/woodbury_pd_mat.jl    93.75% <ø> (-0.19%) ⬇️
src/chainrules.jl         100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@mzgubic
Contributor

mzgubic commented Jul 28, 2021

what's the error for 2. ? I will try to take a look at the end of the day

Alex Robson and others added 5 commits July 28, 2021 17:01
@AlexRobson
Member Author

> what's the error for 2. ? I will try to take a look at the end of the day

It is numerically very very wrong.

@mzgubic
Contributor

mzgubic commented Jul 29, 2021

Not sure why passing the Tangent was failing (the math looked fine), but I solved it by creating an intermediate Woodbury. Also, putting the times_pullback outside of the rrule somehow solves the inference (we have seen this before, but I don't understand it).

I still don't know why the following fails though:

julia> W = WoodburyPDMat(rand(3,2), Diagonal(rand(2,)), Diagonal(rand(3,)))
3×3 WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}:
 1.26364   0.194865   0.14236
 0.194865  0.553561   0.0738048
 0.14236   0.0738048  0.366308

julia> T = Tangent{WoodburyPDMat}(;A=W.A, D=W.D, S=W.S)
Tangent{WoodburyPDMat}(A = [0.6529299263287578 0.9705937318111584; 0.2742328923999924 0.5824550169874951; 0.14327303956778725 0.6355464664765937], D = [0.6900525901684949 0.0; 0.0 0.12613565656887138], S = [0.8506296051551665 0.0 0.0; 0.0 0.4588742330641957 0.0; 0.0 0.0 0.3011946555949865])

julia> test_rrule(*, 2.0, W ⊢ W; output_tangent=W)
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}: Test Failed at /Users/mzgubic/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
  Expression: isapprox(actual, expected; kwargs...)
   Evaluated: isapprox(2.1647628223399655, 1.5169372885519636; rtol = 1.0e-9, atol = 1.0e-9)
Stacktrace:
 [1] test_approx(actual::Float64, expected::Float64, msg::String; kwargs::Base.Iterators.Pairs{Symbol, Float64, Tuple{Symbol, Symbol}, NamedTuple{(:rtol, :atol), Tuple{Float64, Float64}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/check_result.jl:24
 [2] macro expansion
   @ ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:238 [inlined]
 [3] macro expansion
   @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/Test/src/Test.jl:1151 [inlined]
 [4] test_rrule(::ChainRulesTestUtils.ADviaRuleConfig, ::typeof(*), ::Float64, ::Vararg{Any, N} where N; output_tangent::WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}}, check_thunked_output_tangent::Bool, fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, rrule_f::Function, check_inferred::Bool, fkwargs::NamedTuple{(), Tuple{}}, rtol::Float64, atol::Float64, kwargs::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
   @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/8380y/src/testers.jl:194
Test Summary:                                                                                                                            | Pass  Fail  Total
test_rrule: * on Float64,WoodburyPDMat{Float64, Matrix{Float64}, Diagonal{Float64, Vector{Float64}}, Diagonal{Float64, Vector{Float64}}} |    8     1      9
ERROR: Some tests did not pass: 8 passed, 1 failed, 0 errored, 0 broken.

Also, I don't think we have fully figured out our story for testing arbitrary tangents (like passing a Matrix instead of the Woodbury as output_tangent). I will write an issue about this.

@AlexRobson
Member Author

AlexRobson commented Jul 29, 2021

OK, I understand a bit more about what's going on now.

Consider this snippet (the tests pass, at least locally; note that I am using the commit prior to the one just added, where the WoodburyPDMat constructor is added to the pullback):


            primal = R * W
            # Generate the Tangent as ChainRulesTestUtils would do
            ∂primal = rand_tangent(Random.GLOBAL_RNG, collect(primal))
            T = ProjectTo(primal)(∂primal)
            f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> (*(x...)), T, (R, W))[1]

            # Expected
            R̄ = ProjectTo(R)(dot(T, W'))
            W̄ = ProjectTo(W)(conj(R) * T)
            
            @test R̄ ≈ f_jvp[1]
            @test W̄.A ≈ f_jvp[2].A
            @test W̄.D ≈ f_jvp[2].D
            @test W̄.S ≈ f_jvp[2].S

ProjectTo is pushing the tangent W̄ into its constituents, A, D and S. It is not the case that W̄ = ∂W.A * ∂W.D * ∂W.A' + ∂W.S. We want .

EDIT: Added in some of the variables that were implicit. Removed the call to the rrule which at the moment won't work. This is just testing the derivatives.

@AlexRobson
Member Author

AlexRobson commented Jul 30, 2021

I'm not sure what is intended is possible in the present set-up. If we consider the pullback and the projection:

function _times_pullback(Ȳ::AbstractMatrix, A, B, proj)
    Ā = proj.A(dot(Ȳ, B)')
    B̄ = proj.B(A' * Ȳ)
    return (NoTangent(), Ā, B̄)
end

function (W::ProjectTo{T})(W̄) where {T<:WoodburyPDMat}
    Ā(W̄) = ProjectTo(W.A)((W̄ + W̄') * (W.A * W.D))
    D̄(W̄) = ProjectTo(W.D)(W.A' * (W̄) * W.A)
    S̄(W̄) = ProjectTo(W.S)(W̄)
    return Tangent{T}(; A = Ā(W̄), D = D̄(W̄), S = S̄(W̄))
end

We want the pullback to support Ȳ::Tangent{WoodburyPDMat}. The snippet above shows W̄ being passed through (∂primal in that snippet). However, upon projection, this will have the fields A, D, S with their associated pullbacks due to the projection, not W̄.

Some options to consider:

  • We can't recover W̄ from the projection (none of the entries are invertible; in general, projections are likely to be lossy). Put another way, the structured differential here can't recover the natural differential (if this is the right language to use, idk). This may imply that this projection isn't what we want.

  • We could extend the tangent representation to include the natural differential. However, after some playing with canonicalize, this isn't possible, as CRC assumes that for any structured matrix the tangent fields match the primal fields.

  • Another option would be to overload rand_tangent to return just the matrix and pass that in. However, this means that FD will error, because the size of the primal (which is a Woodbury) will be different from that of the tangent due to the inconsistent representations.

  • Another option would be to densify the to_vec representation of the Woodbury for the purposes of testing, and densify the primals, but that means we need to densify everywhere iiuc, which largely defeats the purpose of having a lazy matrix.

Though I've probably misunderstood something somewhere.

@mzgubic
Contributor

mzgubic commented Jul 30, 2021

I was thinking about the last option. I only think it is possible to go from Woodbury to a vector, but not back (as far as I know it is not possible to "factorise" an arbitrary dense matrix to a Woodbury?).

It might still solve our problems, because in some cases only going to a vector is required (and not going back). "Some cases" is I think the case where we vectorise the tangent of the primal output.

@AlexRobson
Member Author

I've cleaned up this MR now - it should be ready to go.

As there were several discussion points previously, I'll give an overview:

  • Updates to use ChainRulesCore v1.

  • Implements a ProjectTo for the Woodbury that takes a dx::AbstractMatrix and projects that pushback onto the constituent elements of the Woodbury. Essentially, as we know that W = A D A' + S, we know how Ā, D̄, S̄ can be constructed from W̄. So if W̄ is a Matrix, Project is used to produce the Tangent{WoodburyPDMat}(A = Ā, D = D̄, S = S̄). If W̄ is already a Tangent, Project just passes this through, as it is in the correct subspace.

  • The rrules for Diagonal * Woodbury (and reverse), and the rrules for Real * Woodbury (and reverse) are implemented. These capture two different scenarios for the rrules. In the first case, abbreviated to Dmat * W this returns a dense Matrix in the primal, and thus the cotangent will be a Matrix. In the second case, abbreviated to R * W, this returns another Woodbury, and thus the cotangent will be of type Tangent{WoodburyPDMat}. So we have:

    • (D * W)::Matrix - This is covered by existing rrules as the cotangent will be a Matrix and we have ProjectTo specified.
    • (R * W)::WoodburyPDMat - This requires custom rrules because we need to work with a cotangent of type Tangent{WoodburyPDMat}.
  • We don't need to test (D * W) with a Tangent{WoodburyPDMat} cotangent because the primal type is a Matrix. I have added a (R * W) with a Matrix cotangent test but we can't use test_rrule for this directly, as commented.

  • The WoodburyLike thing is removed. In retrospect I don't particularly like the Woodbury being force-constructed from the constituent gradients, but it does mean that the existing tests can just be used (i.e. the output of f_jvp does create a Woodbury, but it's unlikely to be a valid one).

Will take out of WIP.

@AlexRobson AlexRobson changed the title WIP: Add ChainRules Add ChainRules Aug 15, 2021
@AlexRobson AlexRobson requested a review from mzgubic August 15, 2021 15:48
ChainRulesCore.ProjectTo{T}(; fields...)
end

function (project::ProjectTo{T})(X̄::AbstractMatrix) where {T<:WoodburyPDMat}

Contributor

I don't understand what this does. It looks like we are trying to project an arbitrary matrix onto a Woodbury, which as I understand is not possible in case X is not positive semi-definite? (In that case we might explicitly check for PSD and throw an error?)

If all is well and I just don't understand something, could we add a few lines of comments that explain what is going on and maybe link a reference for the math?

Member Author

> trying to project an arbitrary matrix onto a Woodbury

Nope - we are projecting an arbitrary matrix onto a Woodbury tangent. I.e. instead of passing around W̄, we are passing around Ā, D̄, S̄.

The matrix here can be anything. Project here is encoding the rules for the individual components of the Woodbury, i.e. from the definition W = ADA' + S -> dW = dA D A' + A dD A' + A D dA' + dS => Ā = (W̄ + W̄')AD (which is 2W̄AD for symmetric W̄), D̄ = A'W̄A, S̄ = W̄, with their appropriate projections down to e.g. Diagonal. X̄ here is W̄.
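
The step described above can be written out as a short derivation. This is a sketch only, assuming real-valued A, diagonal D and S, and the Frobenius inner product ⟨X, Y⟩ = tr(XᵀY); the symbol Π_diag (projection onto diagonal matrices) is notation introduced here for illustration and does not appear in the PR.

```latex
\begin{align*}
W &= A D A^\top + S, \\
dW &= dA\, D A^\top + A\, dD\, A^\top + A D\, dA^\top + dS, \\
\langle \bar W, dW \rangle
  &= \langle (\bar W + \bar W^\top) A D,\; dA \rangle
   + \langle A^\top \bar W A,\; dD \rangle
   + \langle \bar W,\; dS \rangle, \\
\bar A &= (\bar W + \bar W^\top) A D, \qquad
\bar D = \Pi_{\mathrm{diag}}\!\left(A^\top \bar W A\right), \qquad
\bar S = \Pi_{\mathrm{diag}}\!\left(\bar W\right).
\end{align*}
```

These expressions agree with the `ProjectTo{<:WoodburyPDMat}` method posted earlier in the thread, which computes `(W̄ + W̄') * (W.A * W.D)`, `W.A' * W̄ * W.A` and `W̄` before projecting each onto the corresponding field.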

Ideally, these would only be in the constructor for the Woodbury; however, due to the way FD is set up, it seems we always have to project down to these components. The Woodbury is a lazy low-rank matrix, and because of this any Woodbury primal has number of elements 2*(A1 + A2), where (A1, A2) is the size of the matrix A. The tangent has to be the same size (i.e. number of elements); that is, their to_vec representations must be equal in length.

This is why, when testing with a Matrix cotangent, I couldn't use CRTU and had to do a manual call to FD and densify the primal here.

If need be, I believe we can only ever return the natural differential Matrix (i.e. W̄) as the output of any arbitrary rrule with a Woodbury argument, as is done here, but as I say I believe that FD will not like this in some scenarios.

The constraints imposed here, as I understand them, are that the primal output and output cotangent need to be of the same length, and that the input args and input adjoints are also all of the same length as each other (where length is that of the to_vec representation).

I'll add comments, but first I want to make sure that that makes sense

Contributor

Hmm, I am actually not sure what we should do here. I've opened JuliaDiff/ChainRulesCore.jl#442 with the PositiveReal example you have shown me.

Member Author

@AlexRobson AlexRobson Aug 18, 2021

I'm reading that thread.

The use case here aside, is projecting onto a Tangent not an intended use of ProjectTo?
I'm sort of getting that impression from reading these threads, and it is not one that I had realised, if that is the case. I now note that nothing in the projection.jl in CRC does map to Tangent for some arbitrary struct.

I had thought the point of Project was to take a differential and map it appropriately onto the correct differential representation, in this case it being a Tangent{WoodburyPDMat}.

Contributor

Sure, we can map to Tangent because Tangent is a differential type. The difference is that

ProjectTo(w::Tangent{Woodbury})

is fine, but

ProjectTo(w::Woodbury)

is not, and we are doing the second here. (Despite the fact that the projection actually returns the Tangent{Woodbury}, which is inconsistent since we said we are projecting onto a Woodbury)

Member Author

How does this relate to the proposal in JuliaDiff/ChainRulesCore.jl#449?

Contributor

I think Will thought that it is possible to find the inverse of Woodbury densification (for the subspace of dense matrices that can come from a Woodbury) but I may have misinterpreted that.

This seems to be a weird edge case @willtebbutt. If I understand correctly, in this case we want the structural rather than the natural differential inside the pullback written in the rule (rather than the premise in the 449 which expects AD to prefer structural tangents, and rule writers to prefer natural tangents)

Member

The proposal in 449 would basically replace ProjectTo / change its semantics to eat a natural tangent / Matrix and produce a structural tangent. You obtain the new thing by implementing the pullback for collect, and that does everything you need.

Re inverse of densification -- I found that hard to implement for WoodburyPDMat -- see here but the pullback for collect (pullback_of_destructure -- I wanted to give it a different name) is totally fine.

Member

Separately, what's the rationale for implementing ProjectTo? Is it to make it possible to do *(A::Matrix, B::WoodburyPDMat)? If so, would it make more sense / be easier to implement a specialised method for the primal (since the generic fallback will be horribly slow) that is differentiable, and @opt_out of the generic rrule to let AD do the work of deriving the rule?

Member Author

> Separately, what's the rationale for implementing ProjectTo

Originally, I was using the ProjectTo here to project the natural tangent onto its structural tangent, i.e. to map a given W̄ (the natural tangent, as I understand the semantics) to its structural (Ā, D̄, S̄). It turned out this wasn't permitted under the semantics of ProjectTo, which is what led to the 'allow projectto onto tangent' issue. I paused looking at this MR at that point, but saw your RFC being posted and pinged this thread on how these two related. As far as I understand the resolution to #442 in CRC, it has added an identity, but I don't think that changes anything w.r.t. using it as a tangent, so I still wouldn't know how to continue this.

If I understand your RFC, the way I was using the projectto here was more in line with what you were considering.

Comment on lines +34 to +38
function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
return (NoTangent(), Ā, B̄)
end
Contributor

could we do something similar to the forward pass? like

Suggested change
- function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
-     Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
-     B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
-     return (NoTangent(), Ā, B̄)
- end
+ function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
+     df, db, da = _times_pullback(Ȳ, B, A, (A=proj.B, B=proj.A))
+     return df, db, da
+ end

(I don't think we have to worry about adjointing a Real)
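
For reference, the expressions in the `_times_pullback(Ȳ::Tangent, A::Real, B::WoodburyPDMat, proj)` method under review can be checked by hand. The following is a sketch under stated assumptions, not the author's derivation: real scalars and matrices, the scalar written as c and the Woodbury as W = A D Aᵀ + S, the fields of the structural cotangent written as subscripts (Ȳ_A, Ȳ_D, Ȳ_S), and Π_diag as illustrative notation for projection onto diagonal matrices.

```latex
\begin{align*}
Y &= cW = A\,(cD)\,A^\top + cS
  \quad\Rightarrow\quad Y_A = A,\; Y_D = cD,\; Y_S = cS, \\
\bar c &= \langle \bar Y_D,\, D \rangle + \langle \bar Y_S,\, S \rangle, \\
\bar W_A &= \bar Y_A, \qquad \bar W_D = c\,\bar Y_D, \qquad \bar W_S = c\,\bar Y_S .
\end{align*}
```

As a consistency check, suppose the structural cotangent came from a dense natural cotangent Ȳ, so that Ȳ_D = Π_diag(Aᵀ Ȳ A) and Ȳ_S = Π_diag(Ȳ). The scalar cotangent then reduces to the familiar dense rule:

```latex
\begin{align*}
\bar c
  = \operatorname{tr}\!\left(A^\top \bar Y A\, D\right) + \operatorname{tr}\!\left(\bar Y S\right)
  = \langle \bar Y,\; A D A^\top + S \rangle
  = \langle \bar Y,\, W \rangle .
\end{align*}
```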

)
end

@testset "*(Matrix-Woodbury)" begin
Contributor

I don't see this rule anywhere, does this work because ChainRules are loaded? might be worth adding a comment

Member Author

Oh I think these are because Zygote is loaded.

I'll add ChainRules explicitly

@AlexRobson AlexRobson changed the title Add ChainRules WIP: Add ChainRules Aug 18, 2021
Member

@willtebbutt willtebbutt left a comment

Sorry for jumping in late. Saw this linked in 449.


# Rule for Woodbury * Real.
# Ignoring Complex version for now.
function ChainRulesCore.rrule(::typeof(*), A::WoodburyPDMat, B::Real)
Member

Sorry to hop in late here, but why do we need a rule for this, as opposed to just opting out?

In particular, our implementation of this function ought to be straightforwardly differentiable by Zygote et al:

*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
*(c::Real, a::WoodburyPDMat) = a * c

Is it not sufficient just to @opt_out?

Member Author

@AlexRobson AlexRobson Sep 22, 2021

Yeah this is probably true. AFAIK this works - here however I was using this to test where the primal returns a woodbury (and thus your output co-tangent is a tangent{woodbury}) as I wanted to use this as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

I hadn't considered opt-outs... if the rule works are they generally not more efficient than zygote working it out?

Member

@willtebbutt willtebbutt Sep 23, 2021

> Yeah this is probably true. AFAIK this works - here however I was using this to test where the primal returns a woodbury (and thus your output co-tangent is a tangent{woodbury}) as I wanted to use this as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

Interesting -- again, I'd suggest just opting out here 🤷 Or am I missing the point?

> I hadn't considered opt-outs... if the rule works are they generally not more efficient than zygote working it out?

No. In simple cases like this there's no reason that Zygote can't achieve optimal performance. If you have control flow (either via a for-loop or an if-else-end block) Zygote will generally struggle, but otherwise you're definitely better off letting Zygote figure stuff out.

My experience has consistently been that if you have the choice between writing a rule and writing AD-friendly code, always pick the latter. The fact that opting-out is necessary is unfortunate, but it's the compromise we have for now.

Member Author

> between writing a rule and writing AD-friendly code, always pick the latter.

Interesting - thanks. I hadn't considered this angle, and that workarounds like WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D) are advised over setting up a rule.

@AlexRobson
Member Author

I'm just going to close this MR. I had largely set out to try to understand CR through implementing rules for the Woodbury (and to follow up with a resolution to the Woodbury * Diagonal AD workaround that was necessary), with the presumption that the rules would be more performant (I never got to benchmarking these), but this seems to be the wrong view.

There may be some follow-ups (such as opt-outs and supporting CRC 1) but it seems best to do that with a clean slate.

@AlexRobson AlexRobson closed this Sep 23, 2021
3 participants