Skip to content

Commit 8c8e87f

Browse files
committed
No longer <: Number
1 parent bd53fab commit 8c8e87f

File tree

3 files changed

+19
-29
lines changed

3 files changed

+19
-29
lines changed

src/chainrules.jl

+7-26
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,6 @@ function contract(a::TaylorScalar{T, N}, b::TaylorScalar{S, N}) where {T, S, N}
55
mapreduce(*, +, value(a), value(b))
66
end
77

8-
NONLINEAR_UNARY_FUNCTIONS = Function[exp, exp2, exp10, expm1,
9-
log, log2, log10, log1p,
10-
inv, sqrt, cbrt,
11-
sin, cos, tan, cot, sec, csc,
12-
asin, acos, atan, acot, asec, acsc,
13-
sinh, cosh, tanh, coth, sech, csch,
14-
asinh, acosh, atanh, acoth, asech, acsch]
15-
16-
for func in NONLINEAR_UNARY_FUNCTIONS
17-
@eval @opt_out rrule(::typeof($func), ::TaylorScalar)
18-
end
19-
20-
NONLINEAR_BINARY_FUNCTIONS = Function[*, /, ^]
21-
22-
for func in NONLINEAR_BINARY_FUNCTIONS
23-
@eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::TaylorScalar)
24-
@eval @opt_out rrule(::typeof($func), ::TaylorScalar, ::Number)
25-
@eval @opt_out rrule(::typeof($func), ::Number, ::TaylorScalar)
26-
end
27-
28-
# Other special cases
29-
30-
@opt_out rrule(::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar, ::Val{p}) where {p}
31-
@opt_out rrule(::RuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::TaylorScalar,
32-
::Val{p}) where {p}
33-
348
function rrule(::Type{TaylorScalar{T, N}}, v::NTuple{N, T}) where {N, T <: Number}
359
taylor_scalar_pullback(t̄) = NoTangent(), value(t̄)
3610
return TaylorScalar(v), taylor_scalar_pullback
@@ -75,3 +49,10 @@ end
7549
end
7650

7751
(project::ProjectTo{T})(dx::TaylorScalar{T, N}) where {N, T <: Number} = primal(dx)
52+
53+
# Not-a-number patches
54+
55+
ProjectTo(::T) where {T <: TaylorScalar} = ProjectTo{T}()
56+
(p::ProjectTo{T})(x::T) where {T <: TaylorScalar} = x
57+
ProjectTo(x::AbstractArray{T}) where {T <: TaylorScalar} = ProjectTo{AbstractArray}(; element=ProjectTo(zero(T)), axes=axes(x))
58+
(p::ProjectTo{AbstractArray{T}})(x::AbstractArray{T}) where {T <: TaylorScalar} = x

src/primitive.jl

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
import Base: abs, abs2
32
import Base: exp, exp2, exp10, expm1, log, log2, log10, log1p, inv, sqrt, cbrt
43
import Base: sin, cos, tan, cot, sec, csc, sinh, cosh, tanh, coth, sech, csch

src/scalar.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import Base: zero, one, adjoint, conj
1+
import Base: zero, one, adjoint, conj, transpose
2+
import Base: +, -, *, /
23
import Base: convert, promote_rule
34

45
export TaylorScalar
@@ -12,7 +13,7 @@ Representation of Taylor polynomials.
1213
1314
- `value::NTuple{N, T}`: i-th element of this stores the (i-1)-th derivative
1415
"""
15-
struct TaylorScalar{T <: Number, N} <: Number
16+
struct TaylorScalar{T <: Number, N}
1617
value::NTuple{N, T}
1718
end
1819

@@ -69,3 +70,12 @@ function promote_rule(::Type{TaylorScalar{T, N}},
6970
::Type{S}) where {T <: Number, S <: Number, N}
7071
TaylorScalar{promote_type(T, S), N}
7172
end
73+
74+
# Number-like convention (I patched them after removing <: Number)
75+
76+
convert(::Type{TaylorScalar{T, N}}, x::Number) where {T, N} = TaylorScalar{T, N}(x)
77+
for op in (:+, :-, :*, :/)
78+
@eval @inline $op(a::TaylorScalar, b::Number) = $op(promote(a, b)...)
79+
@eval @inline $op(a::Number, b::TaylorScalar) = $op(promote(a, b)...)
80+
end
81+
transpose(t::TaylorScalar) = t

0 commit comments

Comments
 (0)