
Commit e06c8fd

Support Enzyme

1 parent 7ff795b commit e06c8fd

File tree: 7 files changed (+72 -49 lines)

Project.toml (+1 -1)

@@ -1,7 +1,7 @@
 name = "TaylorDiff"
 uuid = "b36ab563-344f-407b-a36a-4f200bebf99c"
 authors = ["Songchen Tan <[email protected]>"]
-version = "0.2.4"
+version = "0.2.5"

 [deps]
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

benchmark/groups/pinn.jl (+3 -3)

@@ -1,10 +1,10 @@
-using Lux, Random, Zygote
+using Lux, Zygote

 const input = 2
 const hidden = 16

-model = Chain(Dense(input => hidden, exp),
-    Dense(hidden => hidden, exp),
+model = Chain(Dense(input => hidden, Lux.relu),
+    Dense(hidden => hidden, Lux.relu),
     Dense(hidden => 1),
     first)
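For context, a hedged sketch (not part of this commit) of how the benchmark model above would be instantiated and run; the rng choice and batch size are illustrative, and `Lux.setup` is Lux's standard initialization API:

# Illustrative only: initialize the benchmark MLP and run one forward pass.
using Lux, Random

ps, st = Lux.setup(Random.default_rng(), model)  # parameters and state
x = rand(Float32, input, 8)                      # features × batch
y, _ = model(x, ps, st)                          # trailing `first` layer yields a scalar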

src/chainrules.jl (+10)

@@ -86,3 +86,13 @@ for f in (
         @eval @opt_out rrule(::typeof($f), x::$tlhs, y::$trhs)
     end
 end
+
+# Multi-argument functions
+
+@opt_out frule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+@opt_out rrule(::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar)
+
+@opt_out frule(
+    ::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
+@opt_out rrule(
+    ::typeof(*), x::TaylorScalar, y::TaylorScalar, z::TaylorScalar, more::TaylorScalar...)
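These opt-outs tell the AD backend not to apply ChainRules' generic rules to products of three or more `TaylorScalar`s, so differentiation falls through to TaylorDiff's own overloads. A minimal hedged sketch of the case they cover, using the scalar `derivative(f, x, order)` API exercised in the tests below; the function name is illustrative:

# Sketch only: `x * x * x` parses as the single n-ary call *(x, x, x),
# which is exactly the multi-argument case opted out of above.
using TaylorDiff, Zygote

cube(x) = x * x * x

# Zygote over TaylorDiff: differentiate the first Taylor derivative of cube.
Zygote.gradient(x -> derivative(cube, x, 1), 0.5)  # expect (6 * 0.5,) == (3.0,)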

test/Project.toml (+8 -1)

@@ -1,5 +1,12 @@
 [deps]
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
+FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[compat]
+Enzyme = "0.13"

test/downstream.jl (+49, new file)

@@ -0,0 +1,49 @@
+using LinearAlgebra
+import DifferentiationInterface
+using DifferentiationInterface: AutoZygote, AutoEnzyme
+import Zygote, Enzyme
+using FiniteDiff: finite_difference_derivative
+
+DI = DifferentiationInterface
+backend = AutoZygote()
+# backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)
+
+@testset "Zygote-over-TaylorDiff on same variable" begin
+    # Scalar functions
+    some_number = 0.7
+    some_numbers = [0.3, 0.4, 0.1]
+    for f in (exp, log, sqrt, sin, asin, sinh, asinh, x -> x^3)
+        @test DI.derivative(x -> derivative(f, x, 2), backend, some_number) ≈
+              derivative(f, some_number, 3)
+        @test DI.jacobian(x -> derivative.(f, x, 2), backend, some_numbers) ≈
+              diagm(derivative.(f, some_numbers, 3))
+    end
+
+    # Vector functions
+    g(x) = x[1] * x[1] + x[2] * x[2]
+    @test DI.gradient(x -> derivative(g, x, [1.0, 0.0], 1), backend, [1.0, 2.0]) ≈
+          [2.0, 0.0]
+
+    # Matrix functions
+    some_matrix = [0.7 0.1; 0.4 0.2]
+    f(x) = sum(exp.(x), dims = 1)
+    dfdx1(x) = derivative(f, x, [1.0, 0.0], 1)
+    dfdx2(x) = derivative(f, x, [0.0, 1.0], 1)
+    res(x) = sum(dfdx1(x) .+ 2 * dfdx2(x))
+    grad = DI.gradient(res, backend, some_matrix)
+    @test grad ≈ [1 0; 0 2] * exp.(some_matrix)
+end
+
+@testset "Zygote-over-TaylorDiff on different variable" begin
+    linear_model(x, p, b) = exp.(b + p * x + b)[1]
+    loss_taylor(x, p, b, v) = derivative(x -> linear_model(x, p, b), x, v, 1)
+    ε = cbrt(eps(Float64))
+    loss_finite(x, p, b, v) = (linear_model(x + ε * v, p, b) -
+                               linear_model(x - ε * v, p, b)) / (2 * ε)
+    let some_x = [0.58, 0.36], some_v = [0.23, 0.11], some_p = [0.49 0.96], some_b = [0.88]
+        @test DI.gradient(p -> loss_taylor(some_x, p, some_b, some_v), backend, some_p) ≈
+              DI.gradient(p -> loss_finite(some_x, p, some_b, some_v), backend, some_p)
+    end
+end
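The commented-out `backend` line near the top of this file is the hook for Enzyme. As a hedged illustration (mirroring that comment rather than adding any new API), switching the whole suite to Enzyme is a one-line change:

# Run the same downstream tests through Enzyme instead of Zygote by
# swapping the DifferentiationInterface backend, as the comment in
# downstream.jl indicates:
import Enzyme
using DifferentiationInterface: AutoEnzyme

backend = AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const)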

test/runtests.jl (+1 -1)

@@ -3,5 +3,5 @@ using Test

 include("primitive.jl")
 include("derivative.jl")
-include("zygote.jl")
+include("downstream.jl")
 # include("lux.jl")

test/zygote.jl (-43)

This file was deleted.
