55
55
56
56
ProjectTo (:: T ) where {T <: TaylorScalar } = ProjectTo {T} ()
57
57
(p:: ProjectTo{T} )(x:: T ) where {T <: TaylorScalar } = x
58
- ProjectTo (x:: AbstractArray{T} ) where {T <: TaylorScalar } = ProjectTo {AbstractArray} (; element= ProjectTo (zero (T)), axes= axes (x))
58
+ function ProjectTo (x:: AbstractArray{T} ) where {T <: TaylorScalar }
59
+ ProjectTo {AbstractArray} (; element = ProjectTo (zero (T)), axes = axes (x))
60
+ end
59
61
(p:: ProjectTo{AbstractArray{T}} )(x:: AbstractArray{T} ) where {T <: TaylorScalar } = x
60
62
accum_sum (xs:: AbstractArray{T} ; dims = :) where {T <: TaylorScalar } = sum (xs, dims = dims)
61
63
62
- TaylorNumeric{T<: TaylorScalar } = Union{T, AbstractArray{<: T }}
64
+ TaylorNumeric{T <: TaylorScalar } = Union{T, AbstractArray{<: T }}
63
65
64
- @adjoint broadcasted (:: typeof (+ ), xs:: Union{Numeric, TaylorNumeric} ...) = broadcast (+ , xs... ), ȳ -> (nothing , map (x -> unbroadcast (x, ȳ), xs)... )
66
+ @adjoint function broadcasted (:: typeof (+ ), xs:: Union{Numeric, TaylorNumeric} ...)
67
+ broadcast (+ , xs... ), ȳ -> (nothing , map (x -> unbroadcast (x, ȳ), xs)... )
68
+ end
65
69
66
- struct TaylorOneElement{T,N,I, A} <: AbstractArray{T,N}
70
+ struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
67
71
val:: T
68
72
ind:: I
69
73
axes:: A
70
- TaylorOneElement (val:: T , ind:: I , axes:: A ) where {T<: TaylorScalar , I<: NTuple{N,Int} , A<: NTuple{N,AbstractUnitRange} } where {N} = new {T,N,I,A} (val, ind, axes)
74
+ function TaylorOneElement (val:: T , ind:: I ,
75
+ axes:: A ) where {T <: TaylorScalar , I <: NTuple{N, Int} ,
76
+ A <: NTuple{N, AbstractUnitRange} } where {N}
77
+ new {T, N, I, A} (val, ind, axes)
78
+ end
71
79
end
72
80
73
81
Base. size (A:: TaylorOneElement ) = map (length, A. axes)
74
82
Base. axes (A:: TaylorOneElement ) = A. axes
75
- Base. getindex (A:: TaylorOneElement{T,N} , i:: Vararg{Int,N} ) where {T,N} = ifelse (i== A. ind, A. val, zero (T))
83
+ function Base. getindex (A:: TaylorOneElement{T, N} , i:: Vararg{Int, N} ) where {T, N}
84
+ ifelse (i == A. ind, A. val, zero (T))
85
+ end
76
86
77
- ∇getindex (x:: AbstractArray{T, N} , inds) where {T <: TaylorScalar , N} = dy -> begin
78
- dx = TaylorOneElement (dy, inds, axes (x))
79
- return (_project (x, dx), map (_-> nothing , inds)... )
87
+ function ∇getindex (x:: AbstractArray{T, N} , inds) where {T <: TaylorScalar , N}
88
+ dy -> begin
89
+ dx = TaylorOneElement (dy, inds, axes (x))
90
+ return (_project (x, dx), map (_ -> nothing , inds)... )
91
+ end
80
92
end
81
93
82
94
@generated function mul_adjoint (Ω:: TaylorScalar{T, N} , x:: TaylorScalar{T, N} ) where {T, N}
@@ -93,12 +105,14 @@ rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
93
105
function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar )
94
106
function times_pullback2 (Ω̇)
95
107
ΔΩ = unthunk (Ω̇)
96
- return (NoTangent (), ProjectTo (x)(mul_adjoint (ΔΩ, y)), ProjectTo (y)(mul_adjoint (ΔΩ, x)))
108
+ return (NoTangent (), ProjectTo (x)(mul_adjoint (ΔΩ, y)),
109
+ ProjectTo (y)(mul_adjoint (ΔΩ, x)))
97
110
end
98
111
return x * y, times_pullback2
99
112
end
100
113
101
- function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar , z:: TaylorScalar , more:: TaylorScalar... )
114
+ function rrule (:: typeof (* ), x:: TaylorScalar , y:: TaylorScalar , z:: TaylorScalar ,
115
+ more:: TaylorScalar... )
102
116
Ω2, back2 = rrule (* , x, y)
103
117
Ω3, back3 = rrule (* , Ω2, z)
104
118
Ω4, back4 = rrule (* , Ω3, more... )
0 commit comments