Skip to content

Commit ca45e00

Browse files
author
Miha Zgubic
committed
passing Tangent{Woodbury} and type inference fixed
1 parent b3ff0c2 commit ca45e00

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

src/chainrules.jl

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
@non_differentiable validate_woodbury_arguments(A, D, S)
22

3+
function _times_pullback::AbstractMatrix, A, B, proj)
4+
#Ā = @thunk(proj[:A](dot(Ȳ, B)'))
5+
#B̄ = @thunk(proj[:B](A' * Ȳ))
6+
Ā = dot(Ȳ, B)'
7+
= A' * Ȳ
8+
return (NoTangent(), Ā, B̄)
9+
end
10+
_times_pullback(ȳ::AbstractThunk, A, B, proj) = _times_pullback(unthunk(ȳ), A, B, proj)
11+
function _times_pullback::Tangent{<:WoodburyPDMat}, A, B, proj)
12+
W = WoodburyPDMat(Ȳ.A, Ȳ.D, Ȳ.S)
13+
return _times_pullback(W, A, B, proj)
14+
end
15+
316
function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
417
project_A = ProjectTo(A)
518
project_B = ProjectTo(B)
6-
function times_pullback::AbstractMatrix)
7-
Ā = @thunk(project_A(dot(Ȳ, B)'))
8-
= @thunk(project_B(A' * Ȳ))
9-
return (NoTangent(), Ā, B̄)
10-
end
11-
12-
function times_pullback::Tangent{<:WoodburyPDMat})
13-
= dot(Ȳ.A *.D *.A' +.S, B)
14-
=.A * (A' *.D) *.A' + A' *.S
15-
return (
16-
NoTangent(),
17-
@thunk(project_A(Ā')),
18-
@thunk(project_B(B̄)),
19-
)
20-
end
19+
times_pullback(ȳ) = _times_pullback(ȳ, A, B, (;A=project_A, B=project_B))
2120
return A * B, times_pullback
2221
end
2322

0 commit comments

Comments
 (0)