diff --git a/src/ApproximateGPs.jl b/src/ApproximateGPs.jl index 029be317..97a9d29b 100644 --- a/src/ApproximateGPs.jl +++ b/src/ApproximateGPs.jl @@ -23,4 +23,10 @@ include("deprecations.jl") include("TestUtils.jl") +import ChainRulesCore: ProjectTo, Tangent +using PDMats: ScalMat +ProjectTo(x::T) where {T<:ScalMat} = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value)) +(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value)) +(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value)) + end