Skip to content

Commit 112184d

Browse files
authored
Merge pull request #230 from JuliaDiff/ox/compeq
fix == on incomplete Composites
2 parents 622195c + d345c4b commit 112184d

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.14"
3+
version = "0.9.15"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/differentials/composite.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,15 @@ function Composite{P}(d::Dict) where {P<:Dict}
4545
return Composite{P, typeof(d)}(d)
4646
end
4747

48-
Base.:(==)(a::Composite, b::Composite) = backing(a) == backing(b)
48+
function Base.:(==)(a::Composite{P, T}, b::Composite{P, T}) where {P, T}
49+
return backing(a) == backing(b)
50+
end
51+
function Base.:(==)(a::Composite{P}, b::Composite{P}) where {P, T}
52+
return canonicalize(a) == canonicalize(b)
53+
end
54+
Base.:(==)(a::Composite{P}, b::Composite{Q}) where {P, Q} = false
55+
56+
Base.hash(a::Composite, h::UInt) = Base.hash(backing(canonicalize(a)), h)
4957

5058
function Base.show(io::IO, comp::Composite{P}) where P
5159
print(io, "Composite{")
@@ -155,6 +163,9 @@ end
155163
# Tuple composites are always in their canonical form
156164
canonicalize(comp::Composite{<:Tuple, <:Tuple}) = comp
157165

166+
# Dict composite are always in their canonical form.
167+
canonicalize(comp::Composite{<:Any, <:AbstractDict}) = comp
168+
158169
"""
159170
_zeroed_backing(P)
160171
@@ -206,7 +217,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an, bn}
206217
# https://github.com/JuliaLang/julia/blob/592748adb25301a45bd6edef3ac0a93eed069852/base/namedtuple.jl#L220-L231
207218
if @generated
208219
names = Base.merge_names(an, bn)
209-
220+
210221
vals = map(names) do field
211222
a_field = :(getproperty(a, $(QuoteNode(field))))
212223
b_field = :(getproperty(b, $(QuoteNode(field))))

test/differentials/composite.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,22 @@ end
2929
@test convert(Dict, Composite{Dict}(Dict(4 => 3))) == Dict(4 => 3)
3030
end
3131

32+
@testset "==" begin
33+
@test Composite{Foo}(x=0.1, y=2.5) == Composite{Foo}(x=0.1, y=2.5)
34+
@test Composite{Foo}(x=0.1, y=2.5) == Composite{Foo}(y=2.5, x=0.1)
35+
@test Composite{Foo}(y=2.5, x=Zero()) == Composite{Foo}(y=2.5)
36+
37+
@test Composite{Tuple{Float64,}}(2.0) == Composite{Tuple{Float64,}}(2.0)
38+
@test Composite{Dict}(Dict(4 => 3)) == Composite{Dict}(Dict(4 => 3))
39+
40+
end
41+
42+
@testset "hash" begin
43+
@test hash(Composite{Foo}(x=0.1, y=2.5)) == hash(Composite{Foo}(y=2.5, x=0.1))
44+
@test hash(Composite{Foo}(y=2.5, x=Zero())) == hash(Composite{Foo}(y=2.5))
45+
end
46+
47+
3248
@testset "indexing, iterating, and properties" begin
3349
@test keys(Composite{Foo}(x=2.5)) == (:x,)
3450
@test propertynames(Composite{Foo}(x=2.5)) == (:x,)
@@ -78,7 +94,15 @@ end
7894

7995
@testset "canonicalize" begin
8096
# Testing iterate via collect
81-
@test collect(Composite{Tuple{Float64,}}(2.0)) == [2.0]
97+
@test ==(
98+
canonicalize(Composite{Tuple{Float64,}}(2.0)),
99+
Composite{Tuple{Float64,}}(2.0)
100+
)
101+
102+
@test ==(
103+
canonicalize(Composite{Dict}(Dict(4 => 3))),
104+
Composite{Dict}(Dict(4 => 3)),
105+
)
82106

83107
# For structure it needs to match order and Zero() fill to match primal
84108
CFoo = Composite{Foo}
@@ -219,7 +243,7 @@ end
219243

220244
@testset "show" begin
221245
@test repr(Composite{Foo}(x=1,)) == "Composite{Foo}(x = 1,)"
222-
# check for exact regex match not occurence( `^...$`)
246+
# check for exact regex match not occurence( `^...$`)
223247
# and allowing optional whitespace (`\s?`)
224248
@test occursin(
225249
r"^Composite{Tuple{Int64,\s?Int64}}\(1,\s?2\)$",

0 commit comments

Comments
 (0)