WIP: Add ChainRules #21

Closed
wants to merge 40 commits into from
40 commits
d2c7537
Add and test rrule for Woodbury
Jul 27, 2021
c586180
Add manifest for Zygote branch
Jul 27, 2021
10bc061
Tidy up tests. White space clear
Jul 27, 2021
13054b4
Remove ZYgote from deps
Jul 27, 2021
d149676
Switch to Zygote master now it supports cr1 until tagged
Jul 27, 2021
49f1ee6
Extend Distirbutions for compat reasons
Jul 27, 2021
89947f6
Update manifest to use zygote#master until tagged
Jul 27, 2021
e81dff8
Add Zygote again to resolve CI issues
Jul 27, 2021
5dce0e9
Generate manifest on 1.6
Jul 27, 2021
2e31400
Readd diagonal tests after rebase removal
Jul 28, 2021
18a4b99
Rebuild manifest on exact CI version...
Jul 28, 2021
5f91685
Compat Bounds
Jul 28, 2021
bbfe9b2
MR comments1
Jul 28, 2021
c79f177
Update src/chainrules.jl
AlexRobson Jul 28, 2021
49f6795
Update src/chainrules.jl
AlexRobson Jul 28, 2021
64e7780
Update test/chainrules.jl
AlexRobson Jul 28, 2021
b3ff0c2
Fix up pullback
Jul 28, 2021
ca45e00
passing Tangent{Woodbury} and type inference fixed
Jul 29, 2021
f6db1da
Use Functor for ProjectTo. Update test
Jul 30, 2021
ed3ee75
Remove irrelevant Diagonal testse
Jul 30, 2021
0edc740
Merge branch 'ar/chainrules' of https://github.com/invenia/PDMatsExtr…
Jul 30, 2021
b6e94f8
Remove merge mess
Jul 30, 2021
dde2b46
Rework chainrules tests. Add constructor rrule
Aug 15, 2021
73f6bbc
Remove ChainRules from test deps
Aug 15, 2021
251d183
White space deletion
Aug 15, 2021
31cf7e1
Remove Zygote from deps
Aug 15, 2021
a9900c8
Remove ChainRules from extras
Aug 15, 2021
f75ed2e
Remove ChainRules from compat
Aug 15, 2021
955b2c1
Refactor long line. Add comment
Aug 15, 2021
2e7babe
Add projections and thunks into _times_pullback
Aug 15, 2021
e1bd5f9
Tangent{T} T<:Woodbury -> just Tangent
Aug 15, 2021
9431ead
Add in extra test that test a Matrix input to tangent
Aug 15, 2021
8ba6f0d
Add Matrix tangent tests into a seperate test set
Aug 15, 2021
ff8d903
Update src/chainrules.jl
AlexRobson Aug 16, 2021
eb23510
Update src/chainrules.jl
AlexRobson Aug 16, 2021
9535406
Update test/chainrules.jl
AlexRobson Aug 16, 2021
a9cdd38
Add (and remove) comments
Aug 16, 2021
eb194c5
Readd CHainRules as a test dep
Aug 16, 2021
c7e98c1
ChainRules added explicitely to runtests
Aug 16, 2021
8ce1498
Add comment vefore test set
Aug 16, 2021
12 changes: 8 additions & 4 deletions Project.toml
@@ -11,14 +11,18 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

 [compat]
-ChainRulesCore = "0.9.17, 0.10"
-Distributions = "0.23, 0.24"
+ChainRules = "1"
+ChainRulesCore = "1"
+ChainRulesTestUtils = "1"
+Distributions = "0.23, 0.24, 0.25"
 FiniteDifferences = "0.11, 0.12"
 PDMats = "0.9, 0.10, 0.11"
-Zygote = "0.5.5"
+Zygote = "0.6"
 julia = "1"

 [extras]
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,4 +31,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [targets]
-test = ["Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
+test = ["ChainRules", "ChainRulesTestUtils", "Distributions", "FiniteDifferences", "Random", "SuiteSparse", "Test", "Zygote"]
1 change: 1 addition & 0 deletions src/PDMatsExtras.jl
@@ -12,5 +12,6 @@ export submat
 include("psd_mat.jl")
 include("woodbury_pd_mat.jl")
 include("utils.jl")
+include("chainrules.jl")

 end
69 changes: 69 additions & 0 deletions src/chainrules.jl
@@ -0,0 +1,69 @@
@non_differentiable validate_woodbury_arguments(A, D, S)

# Rule for Woodbury * Real.
# Ignoring Complex version for now.
function ChainRulesCore.rrule(::typeof(*), A::WoodburyPDMat, B::Real)
Member

Sorry to hop in late here, but why do we need a rule for this, as opposed to just opting out?

In particular, our implementation of this function ought to be straightforwardly differentiable by Zygote et al:

*(a::WoodburyPDMat, c::Real) = WoodburyPDMat(a.A, a.D * c, a.S * c)
*(c::Real, a::WoodburyPDMat) = a * c

Is it not sufficient just to @opt_out?
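
For illustration, opting out might look roughly like the sketch below. It assumes a ChainRulesCore release that provides `@opt_out`, and that `WoodburyPDMat` is an `AbstractMatrix` so that generic matrix-times-scalar rules would otherwise match. The exact signatures to opt out of are an assumption, not something taken from this PR.

```julia
using ChainRulesCore
import ChainRulesCore: rrule   # imported so the macro-generated method extends it
using LinearAlgebra
using PDMatsExtras: WoodburyPDMat

# Opt out of any generic rrule matching these signatures, so that the AD system
# differentiates the WoodburyPDMat-specific `*` method itself.
@opt_out rrule(::typeof(*), ::WoodburyPDMat, ::Real)
@opt_out rrule(::typeof(*), ::Real, ::WoodburyPDMat)

# Illustrative instance (sizes chosen arbitrarily): a lookup for an opted-out
# signature returns `nothing` instead of a (primal, pullback) pair.
W = WoodburyPDMat(randn(4, 2), Diagonal(rand(2) .+ 1), Diagonal(rand(4) .+ 1))
opted_out = rrule(*, W, 2.0) === nothing
```

The trade-off is that correctness then rests on the primal method being AD-friendly, which the two-line definition quoted above appears to be.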

Member Author

@AlexRobson AlexRobson Sep 22, 2021

Yeah this is probably true. AFAIK this works - here however I was using this to test the case where the primal returns a Woodbury (and thus your output cotangent is a Tangent{WoodburyPDMat}), as I wanted to use this as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

I hadn't considered opt-outs... if the rule works, is it generally not more efficient than Zygote working it out?

Member

@willtebbutt willtebbutt Sep 23, 2021

> Yeah this is probably true. AFAIK this works - here however I was using this to test the case where the primal returns a Woodbury (and thus your output cotangent is a Tangent{WoodburyPDMat}), as I wanted to use this as a basis to sort out how to work on the rule for *(H::Diagonal, S::WoodburyPDMat). If you remember we had issues with that and had to do WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D). This was actually the direction I wanted to take this MR, but possibly a lot has moved on since that issue.

Interesting -- again, I'd suggest just opting out here 🤷 Or am I missing the point?

> I hadn't considered opt-outs... if the rule works are they generally not more efficient than zygote working it out?

No. In simple cases like this there's no reason that Zygote can't achieve optimal performance. If you have control flow (either via a for-loop or an if-else-end block) Zygote will generally struggle, but otherwise you're definitely better off letting Zygote figure stuff out.

My experience has consistently been that if you have the choice between writing a rule and writing AD-friendly code, always pick the latter. The fact that opting-out is necessary is unfortunate, but it's the compromise we have for now.

Member Author

> between writing a rule and writing AD-friendly code, always pick the latter.

Interesting - thanks. I hadn't considered this angle, and that workarounds like WoodburyPDMat(H * S.A, S.D, H * S.S * H' + D) are advised over setting up a rule.

project_A = ProjectTo(A)
project_B = ProjectTo(B)
primal = A * B
times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
return primal, times_pullback
end

function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
primal = A * B
times_pullback(ȳ) = _times_pullback(ȳ, primal, A, B, (;A=project_A, B=project_B))
return primal, times_pullback
end

_times_pullback(ȳ::AbstractThunk, primal, A, B, proj) = _times_pullback(unthunk(ȳ), primal, A, B, proj)
# If the cotangent is a Matrix we first need to project down, otherwise ignore
_times_pullback(Ȳ::AbstractMatrix, primal, A, B, proj) = _times_pullback(ProjectTo(primal)(Ȳ), A, B, proj)
_times_pullback(ȳ::Tangent, primal, A, B, proj) = _times_pullback(ȳ, A, B, proj)

function _times_pullback(Ȳ::Tangent, A::T, B::Real, proj) where {T<:WoodburyPDMat}
Ā = @thunk proj.A(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * B', S = Ȳ.S * B'))
B̄ = @thunk proj.B(dot(Ȳ.D, A.D) + dot(Ȳ.S, A.S))
return (NoTangent(), Ā, B̄)
end

function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
return (NoTangent(), Ā, B̄)
end
Comment on lines +32 to +36
Contributor

could we do something similar to the forward pass? like

Suggested change
-function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
-Ā = @thunk proj.A(dot(Ȳ.D, B.D) + dot(Ȳ.S, B.S))
-B̄ = @thunk proj.B(Tangent{WoodburyPDMat}(; A = Ȳ.A, D = Ȳ.D * A, S = Ȳ.S * A))
-return (NoTangent(), Ā, B̄)
-end
+function _times_pullback(Ȳ::Tangent, A::Real, B::T, proj) where {T<:WoodburyPDMat}
+df, db, da = _times_pullback(Ȳ, B, A, (A=proj.B, B=proj.A))
+return df, db, da
+end

(I don't think we have to worry about adjointing a Real)
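
One detail to watch when one argument order delegates to the other: the tangents still have to come back in the order of the outer call's arguments, so the delegated results need swapping before they are returned. A minimal sketch of that pattern follows, using dense matrices as hypothetical stand-ins for the Woodbury types; `times_pullback` and the variable names are illustrative only.

```julia
using LinearAlgebra

# Hypothetical dense stand-in for the matrix-first pullback of Y = M * c.
# Tangents come back in argument order: (function, M, c).
function times_pullback(Ybar::AbstractMatrix, M::AbstractMatrix, c::Real)
    Mbar = Ybar * c
    cbar = dot(Ybar, M)
    return (nothing, Mbar, cbar)
end

# Scalar-first method: delegate to the method above, then reorder the results
# so they line up with (function, c, M).
function times_pullback(Ybar::AbstractMatrix, c::Real, M::AbstractMatrix)
    df, dM, dc = times_pullback(Ybar, M, c)
    return (df, dc, dM)
end

M = [1.0 2.0; 3.0 4.0]
Ybar = ones(2, 2)
df, dc, dM = times_pullback(Ybar, 2.0, M)
```

Here `dc` is the scalar cotangent and `dM` the matrix cotangent, matching the `(c, M)` argument order of the outer call.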


# Composite pullbacks
function ChainRulesCore.rrule(
::Type{T},
A::AbstractMatrix,
D::Diagonal,
S::Diagonal,
) where {T<:WoodburyPDMat}
return WoodburyPDMat(A, D, S), X̄ -> WoodburyPDMat_pullback(X̄, A, D, S)
end
WoodburyPDMat_pullback(X̄::Tangent, A, D, S) = (NoTangent(), X̄.A, X̄.D, X̄.S)
WoodburyPDMat_pullback(X̄::AbstractThunk, A, D, S) = WoodburyPDMat_pullback(unthunk(X̄), A, D, S)

function ChainRulesCore.ProjectTo(W::T) where {T<:WoodburyPDMat}
fields = (A = W.A, D = W.D, S = W.S)
ChainRulesCore.ProjectTo{T}(; fields...)
end

#
# Project the differential onto the Tangent{WoodburyPDMat}.
# This essentially computes the pullbacks for the components of the Woodbury
# i.e. from the definition: W = ADA' + S
# dW = ADdA' + AdDA' + dS
# => Ā = 2ADW̄, D̄=AW̄A', S̄ = W̄
# More precise formulation available e.g. here:
# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
function (project::ProjectTo{T})(X̄::AbstractMatrix) where {T<:WoodburyPDMat}
Contributor

I don't understand what this does. It looks like we are trying to project an arbitrary matrix onto a Woodbury, which as I understand is not possible in case X is not positive semi-definite? (In that case we might explicitly check for PSD and throw an error?)

If all is well and I just don't understand something, could we add a few lines of comments that explain what is going on and maybe link a reference for the math?

Member Author

> trying to project an arbitrary matrix onto a Woodbury

Nope - we are projecting an arbitrary matrix onto a Woodbury tangent. I.e. instead of passing around W̄, we are passing around Ā, D̄, S̄.

The matrix here can be anything. ProjectTo here encodes the rules for the individual components of the Woodbury, i.e. from the definition W = ADA' + S -> dW = ADdA' + AdDA' + dS => Ā = 2ADW̄, D̄=AW̄A', S̄ = W̄, with their appropriate projections down to e.g. Diagonal. X̄ here is W̄.

Ideally, these would only be in the constructor for the Woodbury; however, due to the way FD is set up, it seems we always have to project down to these components. The Woodbury is a lazy low-rank matrix, and because of this any Woodbury primal has number of elements 2*(A1 + A2), where (A1, A2) is the size of the matrix A. The tangent has to be the same size (i.e. number of elements), i.e. their to_vec representations must be equal in length.

This is why when testing with a Matrix cotangent I couldn't use CRTU and had to do a manual call to FD and densify the primal here

If need be, I believe we can only ever return the natural differential Matrix (i.e. W̄) as the output of any arbitrary rrule with a Woodbury argument, as is done here, but as I say I believe that FD will not like this in some scenarios.

The constraints imposed here, as I understand them, are that the primal output and output cotangent need to be of the same length, and that the input args and input adjoints are also all of the same length as each other (where length is that of the to_vec representation).

I'll add comments, but first I want to make sure that that makes sense
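
For reference, the differential and the resulting cotangents can be written out in full. This is a sketch under the assumptions that D and S are diagonal and that the inner product is the Frobenius one; it agrees with the expressions used in the projection code, (X̄ + X̄')AD and A'X̄A, which appears to be what the shorthand above intends.

```latex
\begin{aligned}
W &= A D A^\top + S \\
\mathrm{d}W &= \mathrm{d}A\, D A^\top + A\, \mathrm{d}D\, A^\top + A D\, \mathrm{d}A^\top + \mathrm{d}S \\
\langle \bar{W}, \mathrm{d}W \rangle
  &= \langle (\bar{W} + \bar{W}^\top) A D,\ \mathrm{d}A \rangle
   + \langle A^\top \bar{W} A,\ \mathrm{d}D \rangle
   + \langle \bar{W},\ \mathrm{d}S \rangle \\
\bar{A} &= (\bar{W} + \bar{W}^\top) A D, \qquad
\bar{D} = \operatorname{diag}(A^\top \bar{W} A), \qquad
\bar{S} = \operatorname{diag}(\bar{W})
\end{aligned}
```

When W̄ happens to be symmetric, the first expression reduces to Ā = 2W̄AD. Nothing in the derivation requires W̄ to be positive semi-definite.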

Contributor

Hmm, I am actually not sure what we should do here. I've opened JuliaDiff/ChainRulesCore.jl#442 with the PositiveReal example you have shown me.

Member Author

@AlexRobson AlexRobson Aug 18, 2021

I'm reading that thread.

The use case here aside, is projecting onto a Tangent not an intended use of ProjectTo?
I'm sort of getting that impression from reading these threads, and it is not one that I had realised, if that is the case. I now note that nothing in the projection.jl in CRC does map to Tangent for some arbitrary struct.

I had thought the point of Project was to take a differential and map it appropriately onto the correct differential representation, in this case it being a Tangent{WoodburyPDMat}.

Contributor

Sure, we can map to Tangent because Tangent is a differential type. The difference is that

ProjectTo(w::Tangent{Woodbury})

is fine, but

ProjectTo(w::Woodbury)

is not, and we are doing the second here. (Despite the fact that the projection actually returns the Tangent{Woodbury}, which is inconsistent since we said we are projecting onto a Woodbury)

Member Author

How does this relate to the proposal in JuliaDiff/ChainRulesCore.jl#449?

Contributor

I think Will thought that it is possible to find the inverse of Woodbury densification (for the subspace of dense matrices that can come from a Woodbury) but I may have misinterpreted that.

This seems to be a weird edge case @willtebbutt. If I understand correctly, in this case we want the structural rather than the natural differential inside the pullback written in the rule (rather than the premise in the 449 which expects AD to prefer structural tangents, and rule writers to prefer natural tangents)

Member

The proposal in 449 would basically replace ProjectTo / change its semantics to eat a natural tangent / Matrix and produce a structural tangent. You obtain the new thing by implementing the pullback for collect, and that does everything you need.

Re inverse of densification -- I found that hard to implement for WoodburyPDMat -- see here but the pullback for collect (pullback_of_destructure -- I wanted to give it a different name) is totally fine.

Member

Separately, what's the rationale for implementing ProjectTo? Is it to make it possible to do *(A::Matrix, B::WoodburyPDMat)? If so, would it make more sense / be easier to implement a specialised method for the primal (since the generic fallback will be horribly slow) that is differentiable, and @opt_out of the generic rrule to let AD do the work of deriving the rule?

Member Author

> Separately, what's the rationale for implementing ProjectTo

Originally, I was using the ProjectTo here to project the natural tangent onto its structural tangent, i.e. to map a given W̄ (the natural tangent, as I understand the semantics) to its structural (Ā, D̄, S̄), but it turned out this wasn't permitted under the semantics of ProjectTo, which is what led to the 'allow projectto onto tangent' issue. I paused looking at this MR at that point, but saw your RFC being posted and pinged this thread on how these two related. As far as I understand the resolution to #442 in CRC, it's added an identity, but I don't think that changes anything w.r.t. using it as a tangent, so I still wouldn't know how to continue this.

If I understand your RFC, the way I was using the projectto here was more in line with what you were considering.

Ā = ProjectTo(project.A)((X̄ + X̄') * (project.A * project.D))
D̄ = ProjectTo(project.D)(project.A' * (X̄) * project.A)
S̄ = ProjectTo(project.S)(X̄)
return Tangent{WoodburyPDMat}(; A = Ā, D = D̄, S = S̄)
end
(project::ProjectTo{T})(W::Tangent) where {T<:WoodburyPDMat} = W
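
As a sanity check on the three expressions above, they can be compared against a finite difference of the dense map W = A*D*A' + S using only the standard library. This is a sketch with made-up inputs; `densify`, `structural_cotangent` and `fd_directional` are illustrative names, not functions from the package.

```julia
using LinearAlgebra

# Dense form of the Woodbury definition W = A*D*A' + S.
densify(A, D, S) = A * D * A' + S

# Structural cotangents for a dense cotangent Wbar, mirroring the expressions
# in the projection method (illustrative helper, not package API).
function structural_cotangent(Wbar, A, D)
    Abar = (Wbar + Wbar') * (A * D)
    Dbar = Diagonal(A' * Wbar * A)   # keep only the diagonal, since D is Diagonal
    Sbar = Diagonal(Wbar)            # likewise for S
    return (A = Abar, D = Dbar, S = Sbar)
end

# Central finite difference of s -> <Wbar, W(A + s*dA, D + s*dD, S + s*dS)> at s = 0.
function fd_directional(Wbar, A, D, S, dA, dD, dS; h = 1e-6)
    f(s) = dot(Wbar, densify(A + s * dA, D + s * dD, S + s * dS))
    return (f(h) - f(-h)) / (2h)
end

A = [1.0 2.0; 0.5 -1.0; 3.0 0.0; -2.0 1.0]
D = Diagonal([1.5, 2.0])
S = Diagonal([1.0, 2.0, 3.0, 4.0])
# The cotangent is deliberately neither symmetric nor positive semi-definite.
Wbar = [0.3 -1.0 0.2 0.0; 0.5 0.1 -0.4 1.0; -0.7 0.6 0.9 0.2; 1.1 0.0 -0.3 -0.8]
dA = [0.1 -0.2; 0.3 0.4; -0.5 0.6; 0.7 -0.8]
dD = Diagonal([0.2, -0.1])
dS = Diagonal([0.1, -0.2, 0.3, -0.4])

tangent = structural_cotangent(Wbar, A, D)
analytic = dot(tangent.A, dA) + dot(tangent.D, dD) + dot(tangent.S, dS)
numeric = fd_directional(Wbar, A, D, S, dA, dD, dS)
```

The two numbers should agree up to finite-difference error, which is the same property the "Matrix CoTangent" test set checks through FiniteDifferences.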
2 changes: 0 additions & 2 deletions src/woodbury_pd_mat.jl
@@ -62,8 +62,6 @@ function validate_woodbury_arguments(A, D, S)
 end
 end

-@non_differentiable validate_woodbury_arguments(A, D, S)
-
 function LinearAlgebra.logdet(W::WoodburyPDMat)
 C_S = cholesky(W.S)
 B = C_S.U' \ (W.A * cholesky(W.D).U')
46 changes: 46 additions & 0 deletions test/chainrules.jl
@@ -0,0 +1,46 @@
@testset "ChainRules" begin

A = randn(4, 2)
D = Diagonal(randn(2).^2 .+ 1)
S = Diagonal(randn(4).^2 .+ 1)

W = WoodburyPDMat(A, D, S)
R = 2.0
Dmat = Diagonal(rand(4,))

x = randn(size(A, 1))

@testset "Constructors" begin
test_rrule(WoodburyPDMat, W.A, W.D, W.S)
# This is a gradient, should be able to deal with negative elements (does not have to be PSD like Woodbury itself)
test_rrule(WoodburyPDMat, W.A, W.D, W.S;
output_tangent=Tangent{WoodburyPDMat}(;
A = rand(4,2), D = Diagonal(-1 * rand(2,)), S = Diagonal(-1 * rand(4,)))
)
end

# The rrules already in ChainRules are sufficient for these to work. We just test an example here.
@testset "*(Matrix-Woodbury)" begin
Contributor

I don't see this rule anywhere, does this work because ChainRules are loaded? might be worth adding a comment

Member Author

Oh I think these are because Zygote is loaded.

I'll add ChainRules explicitly

test_rrule(*, Dmat, W)
test_rrule(*, W, Dmat)
test_rrule(*, rand(4,4), W)
end

@testset "*(Woodbury-Real)" begin
test_rrule(*, W, R)
test_rrule(*, R, W)

# We can't test test_rrule(*, R, W; output_tangent = rand(size(W)...)) i.e. with a Matrix because
# FD requires the primal and tangent to be the same size. However, we can just call FD directly and overload
# the primal computation to return a Matrix:
@testset "Matrix CoTangent" begin
res, pb = ChainRulesCore.rrule(*, R, W)
output_tangent = rand(size(W)...)
f_jvp = j′vp(ChainRulesTestUtils._fdm, x -> Matrix(*(x...)), output_tangent, (R, W))[1]
@test unthunk(pb(output_tangent)[3]).A ≈ f_jvp[2].A
@test unthunk(pb(output_tangent)[3]).D ≈ f_jvp[2].D
@test unthunk(pb(output_tangent)[3]).S ≈ f_jvp[2].S
@test unthunk(pb(output_tangent)[2]) ≈ f_jvp[1]
end
end
end
3 changes: 3 additions & 0 deletions test/runtests.jl
@@ -1,5 +1,7 @@
 using PDMatsExtras
+using ChainRules
 using ChainRulesCore
+using ChainRulesTestUtils
 using Distributions
 using FiniteDifferences
 using LinearAlgebra
@@ -33,6 +35,7 @@ const TEST_MATRICES = Dict(
 include("test_ad.jl")

 include("psd_mat.jl")
+include("chainrules.jl")
 include("woodbury_pd_mat.jl")
 include("utils.jl")
 end