Skip to content

Commit 7b2d2d6

Browse files
committed
Fix format and bump version
1 parent 1e02ac0 commit 7b2d2d6

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TaylorDiff"
22
uuid = "b36ab563-344f-407b-a36a-4f200bebf99c"
33
authors = ["Songchen Tan <[email protected]>"]
4-
version = "0.1.3"
4+
version = "0.2.0"
55

66
[deps]
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

src/chainrules.jl

+25-11
Original file line numberDiff line numberDiff line change
@@ -55,28 +55,40 @@ end
5555

5656
ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
5757
(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
58-
ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} = ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
58+
function ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar}
59+
ProjectTo{AbstractArray}(; element = ProjectTo(zero(T)), axes = axes(x))
60+
end
5961
(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x
6062
accum_sum(xs::AbstractArray{T}; dims = :) where {T <: TaylorScalar} = sum(xs, dims = dims)
6163

62-
TaylorNumeric{T<:TaylorScalar} = Union{T, AbstractArray{<:T}}
64+
TaylorNumeric{T <: TaylorScalar} = Union{T, AbstractArray{<:T}}
6365

64-
@adjoint broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...) = broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
66+
@adjoint function broadcasted(::typeof(+), xs::Union{Numeric, TaylorNumeric}...)
67+
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
68+
end
6569

66-
struct TaylorOneElement{T,N,I,A} <: AbstractArray{T,N}
70+
struct TaylorOneElement{T, N, I, A} <: AbstractArray{T, N}
6771
val::T
6872
ind::I
6973
axes::A
70-
TaylorOneElement(val::T, ind::I, axes::A) where {T<:TaylorScalar, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
74+
function TaylorOneElement(val::T, ind::I,
75+
axes::A) where {T <: TaylorScalar, I <: NTuple{N, Int},
76+
A <: NTuple{N, AbstractUnitRange}} where {N}
77+
new{T, N, I, A}(val, ind, axes)
78+
end
7179
end
7280

7381
Base.size(A::TaylorOneElement) = map(length, A.axes)
7482
Base.axes(A::TaylorOneElement) = A.axes
75-
Base.getindex(A::TaylorOneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
83+
function Base.getindex(A::TaylorOneElement{T, N}, i::Vararg{Int, N}) where {T, N}
84+
ifelse(i == A.ind, A.val, zero(T))
85+
end
7686

77-
∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N} = dy -> begin
78-
dx = TaylorOneElement(dy, inds, axes(x))
79-
return (_project(x, dx), map(_->nothing, inds)...)
87+
function ∇getindex(x::AbstractArray{T, N}, inds) where {T <: TaylorScalar, N}
88+
dy -> begin
89+
dx = TaylorOneElement(dy, inds, axes(x))
90+
return (_project(x, dx), map(_ -> nothing, inds)...)
91+
end
8092
end
8193

8294
@generated function mul_adjoint::TaylorScalar{T, N}, x::TaylorScalar{T, N}) where {T, N}
@@ -93,12 +105,14 @@ rrule(::typeof(*), x::TaylorScalar) = rrule(identity, x)
93105
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar)
94106
function times_pullback2(Ω̇)
95107
ΔΩ = unthunk(Ω̇)
96-
return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)), ProjectTo(y)(mul_adjoint(ΔΩ, x)))
108+
return (NoTangent(), ProjectTo(x)(mul_adjoint(ΔΩ, y)),
109+
ProjectTo(y)(mul_adjoint(ΔΩ, x)))
97110
end
98111
return x * y, times_pullback2
99112
end
100113

101-
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
114+
function rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar,
115+
more::TaylorScalar...)
102116
Ω2, back2 = rrule(*, x, y)
103117
Ω3, back3 = rrule(*, Ω2, z)
104118
Ω4, back4 = rrule(*, Ω3, more...)

src/scalar.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ end
7676
# Number-like convention (I patched them after removing <: Number)
7777

7878
convert(::Type{TaylorScalar{T, N}}, x::TaylorScalar{T, N}) where {T, N} = x
79-
convert(::Type{TaylorScalar{T, N}}, x::S) where {T, S, N} = TaylorScalar{T, N}(convert(T, x))
79+
function convert(::Type{TaylorScalar{T, N}}, x::S) where {T, S, N}
80+
TaylorScalar{T, N}(convert(T, x))
81+
end
8082
for op in (:+, :-, :*, :/)
8183
@eval @inline $op(a::TaylorScalar, b::Number) = $op(promote(a, b)...)
8284
@eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)

0 commit comments

Comments
 (0)