Skip to content

Commit c69589a

Browse files
authored
feat: add ConstructionBaseExt to allow Setfield and Functors support (#94)
* feat: add ConstructionBaseExt to allow Setfield and Functors support * refactor: bypass incorrect original type ordering in constructor
1 parent c5c2b8c commit c69589a

File tree

4 files changed

+70
-4
lines changed

4 files changed

+70
-4
lines changed

Project.toml

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
11
name = "ADTypes"
22
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = ["Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors"]
4-
version = "1.9.1"
4+
version = "1.10.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
89
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
910

1011
[weakdeps]
1112
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
13+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1214
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1315

1416
[extensions]
1517
ADTypesChainRulesCoreExt = "ChainRulesCore"
18+
ADTypesConstructionBaseExt = "ConstructionBase"
1619
ADTypesEnzymeCoreExt = "EnzymeCore"
1720

1821
[compat]
1922
ChainRulesCore = "1.0.2"
23+
ConstructionBase = "1.5"
2024
EnzymeCore = "0.5.3,0.6,0.7,0.8"
2125
julia = "1.6"
2226

@@ -25,7 +29,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2529
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2630
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2731
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
32+
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2833
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2934

3035
[targets]
31-
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Test"]
36+
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Setfield", "Test"]

ext/ADTypesConstructionBaseExt.jl

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module ADTypesConstructionBaseExt
2+
3+
using ADTypes: AutoEnzyme, AutoForwardDiff, AutoPolyesterForwardDiff
4+
using ConstructionBase: ConstructionBase
5+
6+
struct InternalAutoEnzymeReconstructor{A} end
7+
8+
InternalAutoEnzymeReconstructor{A}(mode::M) where {M, A} = AutoEnzyme{M, A}(mode)
9+
10+
function ConstructionBase.constructorof(::Type{<:AutoEnzyme{M, A}}) where {M, A}
11+
return InternalAutoEnzymeReconstructor{A}
12+
end
13+
14+
function ConstructionBase.constructorof(::Type{<:AutoForwardDiff{chunksize}}) where {chunksize}
15+
return AutoForwardDiff{chunksize}
16+
end
17+
18+
function ConstructionBase.constructorof(::Type{<:AutoPolyesterForwardDiff{chunksize}}) where {chunksize}
19+
return AutoPolyesterForwardDiff{chunksize}
20+
end
21+
22+
end

src/dense.jl

+8-2
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,10 @@ struct AutoForwardDiff{chunksize, T} <: AbstractADType
181181
tag::T
182182
end
183183

184+
AutoForwardDiff{chunksize}(tag::T) where {chunksize, T} = AutoForwardDiff{chunksize, T}(tag)
185+
184186
function AutoForwardDiff(; chunksize = nothing, tag = nothing)
185-
AutoForwardDiff{chunksize, typeof(tag)}(tag)
187+
return AutoForwardDiff{chunksize}(tag)
186188
end
187189

188190
mode(::AutoForwardDiff) = ForwardMode()
@@ -271,8 +273,12 @@ struct AutoPolyesterForwardDiff{chunksize, T} <: AbstractADType
271273
tag::T
272274
end
273275

276+
function AutoPolyesterForwardDiff{chunksize}(tag::T) where {chunksize, T}
277+
return AutoPolyesterForwardDiff{chunksize, T}(tag)
278+
end
279+
274280
function AutoPolyesterForwardDiff(; chunksize = nothing, tag = nothing)
275-
AutoPolyesterForwardDiff{chunksize, typeof(tag)}(tag)
281+
return AutoPolyesterForwardDiff{chunksize}(tag)
276282
end
277283

278284
mode(::AutoPolyesterForwardDiff) = ForwardMode()

test/misc.jl

+33
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,36 @@ for backend in [
6161
]
6262
println(backend)
6363
end
64+
65+
using Setfield
66+
67+
@testset "Setfield compatibility" begin
68+
ad = AutoEnzyme()
69+
@test ad.mode === nothing
70+
@set! ad.mode = EnzymeCore.Reverse
71+
@test ad.mode isa EnzymeCore.ReverseMode
72+
73+
struct CustomTestTag end
74+
75+
ad = AutoForwardDiff()
76+
@test ad.tag === nothing
77+
@set! ad.tag = CustomTestTag()
78+
@test ad.tag isa CustomTestTag
79+
80+
ad = AutoForwardDiff(; chunksize = 10)
81+
@test ad.tag === nothing
82+
@set! ad.tag = CustomTestTag()
83+
@test ad.tag isa CustomTestTag
84+
@test ad isa AutoForwardDiff{10}
85+
86+
ad = AutoPolyesterForwardDiff()
87+
@test ad.tag === nothing
88+
@set! ad.tag = CustomTestTag()
89+
@test ad.tag isa CustomTestTag
90+
91+
ad = AutoPolyesterForwardDiff(; chunksize = 10)
92+
@test ad.tag === nothing
93+
@set! ad.tag = CustomTestTag()
94+
@test ad.tag isa CustomTestTag
95+
@test ad isa AutoPolyesterForwardDiff{10}
96+
end

0 commit comments

Comments
 (0)