|
1 | 1 | @non_differentiable validate_woodbury_arguments(A, D, S)
|
2 | 2 |
|
| 3 | +function _times_pullback(Ȳ::AbstractMatrix, A, B, proj) |
| 4 | + #Ā = @thunk(proj[:A](dot(Ȳ, B)')) |
| 5 | + #B̄ = @thunk(proj[:B](A' * Ȳ)) |
| 6 | + Ā = dot(Ȳ, B)' |
| 7 | + B̄ = A' * Ȳ |
| 8 | + return (NoTangent(), Ā, B̄) |
| 9 | +end |
| 10 | +_times_pullback(ȳ::AbstractThunk, A, B, proj) = _times_pullback(unthunk(ȳ), A, B, proj) |
| 11 | +function _times_pullback(Ȳ::Tangent{<:WoodburyPDMat}, A, B, proj) |
| 12 | + W = WoodburyPDMat(Ȳ.A, Ȳ.D, Ȳ.S) |
| 13 | + return _times_pullback(W, A, B, proj) |
| 14 | +end |
| 15 | + |
3 | 16 | function ChainRulesCore.rrule(::typeof(*), A::Real, B::WoodburyPDMat)
|
4 | 17 | project_A = ProjectTo(A)
|
5 | 18 | project_B = ProjectTo(B)
|
6 |
| - function times_pullback(Ȳ::AbstractMatrix) |
7 |
| - Ā = @thunk(project_A(dot(Ȳ, B)')) |
8 |
| - B̄ = @thunk(project_B(A' * Ȳ)) |
9 |
| - return (NoTangent(), Ā, B̄) |
10 |
| - end |
11 |
| - |
12 |
| - function times_pullback(Ȳ::Tangent{<:WoodburyPDMat}) |
13 |
| - Ā = dot(Ȳ.A * Ȳ.D * Ȳ.A' + Ȳ.S, B) |
14 |
| - B̄ = Ȳ.A * (A' * Ȳ.D) * Ȳ.A' + A' * Ȳ.S |
15 |
| - return ( |
16 |
| - NoTangent(), |
17 |
| - @thunk(project_A(Ā')), |
18 |
| - @thunk(project_B(B̄)), |
19 |
| - ) |
20 |
| - end |
| 19 | + times_pullback(ȳ) = _times_pullback(ȳ, A, B, (;A=project_A, B=project_B)) |
21 | 20 | return A * B, times_pullback
|
22 | 21 | end
|
23 | 22 |
|
|
0 commit comments