From ce767257427704e3ab036fa271d06a31ac1d5081 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Mar 2022 21:00:29 +0100 Subject: [PATCH 01/15] Make zeros hard in scalar rules --- src/rule_definition_tools.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a77b16059..34416dea9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -295,9 +295,11 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) # Apply `muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. (∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs)) - init_expr = :($∂s_1 * $Δs_1) + # zero gradients are treated as hard zeros. This avoids propagation of NaNs when + # partials are non-finite + init_expr = :(ifelse(iszero($Δs_1), zero($∂s_1), $∂s_1) * $Δs_1) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd($∂s_i, $Δs_i, $ex)) + :(muladd(ifelse(iszero($Δs_i), zero($∂s_i), $∂s_i), $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) end From d83965c88a0dcaa75256814b491ae56eb22a5a39 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Mar 2022 21:00:47 +0100 Subject: [PATCH 02/15] Test hard zeros in scalar rules --- test/rule_definition_tools.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 5a177566d..259e8b00d 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -256,6 +256,29 @@ end @test (NoTangent(), 0.0 - 1.0im) === rrule(make_imaginary, 2.0im)[2](1.0) end + @testset "@scalar_rule strong zero (co)tangents" begin + suminv(x, y) = inv(x) + inv(y) + @scalar_rule suminv(x, y) (-(inv(x)^2), -(inv(y)^2)) + + @test frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0) === (Inf, -Inf) + @test frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) + @test frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) + + @test frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0) === (Inf, -Inf) + @test frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0) === (Inf, -1.0) + @test frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0) === (Inf, -1.0) + + @test rrule(suminv, 0.0, 1.0)[2](1.0) === (NoTangent(), -Inf, -1.0) + @test rrule(suminv, 0.0, 1.0)[2](ZeroTangent()) === + (NoTangent(), ZeroTangent(), ZeroTangent()) + @test rrule(suminv, 0.0, 1.0)[2](0.0) === (NoTangent(), 0.0, 0.0) + + @test rrule(suminv, 1.0, 0.0)[2](1.0) === (NoTangent(), -1.0, -Inf) + @test rrule(suminv, 1.0, 0.0)[2](ZeroTangent()) === + (NoTangent(), ZeroTangent(), ZeroTangent()) + @test rrule(suminv, 1.0, 0.0)[2](0.0) === (NoTangent(), 0.0, 0.0) + end + @testset "Regression tests against #276 and #265" begin # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/265 From fcc151ffedc852963ddfb6af6787b8be53d8041e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 13 Mar 2022 21:01:12 +0100 Subject: [PATCH 03/15] Increment minor version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6cc189db0..9ba191f10 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.13.0" +version = "1.14.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 6eb7457eaad37c5538ce8b3f5386d6cd426f3d38 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 13 Oct 2022 16:29:31 +0200 Subject: [PATCH 04/15] Apply suggestions from code review Co-authored-by: David Widmann Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rule_definition_tools.jl | 4 ++-- test/rule_definition_tools.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 34416dea9..949953c62 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -297,9 +297,9 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) (∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs)) # zero gradients are treated as hard zeros. This avoids propagation of NaNs when # partials are non-finite - init_expr = :(ifelse(iszero($Δs_1), zero($∂s_1), $∂s_1) * $Δs_1) + init_expr = :((iszero($Δs_1) ? zero($∂s_1) : $∂s_1) * $Δs_1) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd(ifelse(iszero($Δs_i), zero($∂s_i), $∂s_i), $Δs_i, $ex)) + :(muladd((iszero($Δs_i) ? zero($∂s_i) : $∂s_i), $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 259e8b00d..0bd26e19a 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -262,7 +262,7 @@ end @test frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0) === (Inf, -Inf) @test frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) - @test frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) + @test frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) @test frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0) === (Inf, -Inf) @test frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0) === (Inf, -1.0) From 721877f3586a638c94ef5428e78c088fe4e59180 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 11:01:46 +0200 Subject: [PATCH 05/15] Make zero for NoTangent return NoTangent --- src/tangent_types/abstract_zero.jl | 3 +++ test/tangent_types/abstract_zero.jl | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 986fc9854..f32c94809 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -87,3 +87,6 @@ arguments. ``` """ struct NoTangent <: AbstractZero end + +Base.zero(::NoTangent) = NoTangent() +Base.zero(::Type{NoTangent}) = NoTangent() diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index e3d8642e4..6a1af49e5 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -42,9 +42,7 @@ end @test broadcastable(z) isa Ref{ZeroTangent} @test zero(@thunk(3)) === z - @test zero(NoTangent()) === z @test zero(ZeroTangent) === z - @test zero(NoTangent) === z @test zero(Tangent{Tuple{Int,Int}}((1, 2))) === z for f in (transpose, adjoint, conj) @test f(z) === z @@ -94,6 +92,8 @@ @testset "NoTangent" begin dne = NoTangent() + @test zero(dne) === NoTangent() + @test zero(NoTangent) === NoTangent() @test dne + dne == dne @test dne + 1 == 1 @test 1 + dne == 1 From 47c4009ada5ed0158996b1f6819a813697ef58a1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 11:02:06 +0200 Subject: [PATCH 06/15] Implement zero for NotImplemented --- src/tangent_types/notimplemented.jl | 12 ++---------- test/tangent_types/notimplemented.jl | 7 +++++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 7016acd60..1620e96b1 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -48,16 +48,8 @@ Base.:/(x::AbstractZero, ::NotImplemented) = x Base.:/(x::NotImplemented, ::AbstractThunk) = throw(NotImplementedException(x)) Base.:/(::AbstractThunk, x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) -function Base.zero(::Type{<:NotImplemented}) - return throw( - NotImplementedException( - @not_implemented( - "`zero` is not defined for missing tangents of type `NotImplemented`" - ) - ), - ) -end +Base.zero(::NotImplemented) = ZeroTangent() +Base.zero(::Type{<:NotImplemented}) = ZeroTangent() Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index e113475c1..670d564e4 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -9,6 +9,11 @@ x = rand() thunk = @thunk(x^2) + # zero + @test @inferred(zero(ni)) === ZeroTangent() + @test @inferred(zero(typeof(ni2))) === ZeroTangent() + @test !iszero(ni) + # conjugate @test conj(ni) === ni @@ -58,8 +63,6 @@ @test_throws E a / ni end @test_throws E ni / ni2 - @test_throws E zero(ni) - @test_throws E zero(typeof(ni)) @test_throws E iterate(ni) @test_throws E iterate(ni, nothing) @test_throws E adjoint(ni) From b54eba3ddc415cd6a7175616f2739b50eb0ef595 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 11:20:33 +0200 Subject: [PATCH 07/15] Revert "Implement zero for NotImplemented" This reverts commit 47c4009ada5ed0158996b1f6819a813697ef58a1. --- src/tangent_types/notimplemented.jl | 12 ++++++++++-- test/tangent_types/notimplemented.jl | 7 ++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/tangent_types/notimplemented.jl b/src/tangent_types/notimplemented.jl index 1620e96b1..7016acd60 100644 --- a/src/tangent_types/notimplemented.jl +++ b/src/tangent_types/notimplemented.jl @@ -48,8 +48,16 @@ Base.:/(x::AbstractZero, ::NotImplemented) = x Base.:/(x::NotImplemented, ::AbstractThunk) = throw(NotImplementedException(x)) Base.:/(::AbstractThunk, x::NotImplemented) = throw(NotImplementedException(x)) -Base.zero(::NotImplemented) = ZeroTangent() -Base.zero(::Type{<:NotImplemented}) = ZeroTangent() +Base.zero(x::NotImplemented) = throw(NotImplementedException(x)) +function Base.zero(::Type{<:NotImplemented}) + return throw( + NotImplementedException( + @not_implemented( + "`zero` is not defined for missing tangents of type `NotImplemented`" + ) + ), + ) +end Base.iterate(x::NotImplemented) = throw(NotImplementedException(x)) Base.iterate(x::NotImplemented, ::Any) = throw(NotImplementedException(x)) diff --git a/test/tangent_types/notimplemented.jl b/test/tangent_types/notimplemented.jl index 670d564e4..e113475c1 100644 --- a/test/tangent_types/notimplemented.jl +++ b/test/tangent_types/notimplemented.jl @@ -9,11 +9,6 @@ x = rand() thunk = @thunk(x^2) - # zero - @test @inferred(zero(ni)) === ZeroTangent() - @test @inferred(zero(typeof(ni2))) === ZeroTangent() - @test !iszero(ni) - # conjugate @test conj(ni) === ni @@ -63,6 +58,8 @@ @test_throws E a / ni end @test_throws E ni / ni2 + @test_throws E zero(ni) + @test_throws E zero(typeof(ni)) @test_throws E iterate(ni) @test_throws E iterate(ni, nothing) @test_throws E adjoint(ni) From dc51058e16c84defd5ae7e154f5e1d223e732053 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 12:02:13 +0200 Subject: [PATCH 08/15] Increment minor version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 91a7b526e..20984db45 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.6" +version = "1.16.0" [deps] Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" From 443140123610644b563fcbf538303f0d19579427 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 22:34:50 +0200 Subject: [PATCH 09/15] Apply suggestions from code review Co-authored-by: David Widmann --- test/rule_definition_tools.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 0bd26e19a..c23b62db5 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -260,23 +260,29 @@ end suminv(x, y) = inv(x) + inv(y) @scalar_rule suminv(x, y) (-(inv(x)^2), -(inv(y)^2)) - @test frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0) === (Inf, -Inf) - @test frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) - @test frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0) === (Inf, -1.0) - - @test frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0) === (Inf, -Inf) - @test frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0) === (Inf, -1.0) - @test frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0) === (Inf, -1.0) - - @test rrule(suminv, 0.0, 1.0)[2](1.0) === (NoTangent(), -Inf, -1.0) - @test rrule(suminv, 0.0, 1.0)[2](ZeroTangent()) === + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0)) === (Inf, -Inf) + @test @inferred(frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) + @test @inferred(frule((NoTangent(), NoTangent(), 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) + @test @inferred(frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) + + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0)) === (Inf, -Inf) + @test @inferred(frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0)) === (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, NoTangent()), suminv, 1.0, 0.0)) === (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0)) === (Inf, -1.0) + + @test @inferred(rrule(suminv, 0.0, 1.0)[2](1.0)) === (NoTangent(), -Inf, -1.0) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](ZeroTangent())) === (NoTangent(), ZeroTangent(), ZeroTangent()) - @test rrule(suminv, 0.0, 1.0)[2](0.0) === (NoTangent(), 0.0, 0.0) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](NoTangent())) === + (NoTangent(), NoTangent(), NoTangent()) + @test @inferred(rrule(suminv, 0.0, 1.0)[2](0.0)) === (NoTangent(), 0.0, 0.0) - @test rrule(suminv, 1.0, 0.0)[2](1.0) === (NoTangent(), -1.0, -Inf) - @test rrule(suminv, 1.0, 0.0)[2](ZeroTangent()) === + @test @inferred(rrule(suminv, 1.0, 0.0)[2](1.0)) === (NoTangent(), -1.0, -Inf) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](ZeroTangent())) === (NoTangent(), ZeroTangent(), ZeroTangent()) - @test rrule(suminv, 1.0, 0.0)[2](0.0) === (NoTangent(), 0.0, 0.0) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](NoTangent())) === + (NoTangent(), NoTangent(), NoTangent()) + @test @inferred(rrule(suminv, 1.0, 0.0)[2](0.0)) === (NoTangent(), 0.0, 0.0) end @testset "Regression tests against #276 and #265" begin From 9983b71f5d44419f115b03cd56972307e29b32f3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 15 Oct 2022 22:58:45 +0200 Subject: [PATCH 10/15] Run formatter --- test/rule_definition_tools.jl | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index c23b62db5..01dca3ded 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -260,15 +260,23 @@ end suminv(x, y) = inv(x) + inv(y) @scalar_rule suminv(x, y) (-(inv(x)^2), -(inv(y)^2)) - @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0)) === (Inf, -Inf) - @test @inferred(frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) - @test @inferred(frule((NoTangent(), NoTangent(), 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) - @test @inferred(frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0)) === (Inf, -1.0) - - @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0)) === (Inf, -Inf) - @test @inferred(frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0)) === (Inf, -1.0) - @test @inferred(frule((NoTangent(), 1.0, NoTangent()), suminv, 1.0, 0.0)) === (Inf, -1.0) - @test @inferred(frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0)) === (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 0.0, 1.0)) === + (Inf, -Inf) + @test @inferred(frule((NoTangent(), ZeroTangent(), 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), NoTangent(), 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 0.0, 1.0), suminv, 0.0, 1.0)) === + (Inf, -1.0) + + @test @inferred(frule((NoTangent(), 1.0, 1.0), suminv, 1.0, 0.0)) === + (Inf, -Inf) + @test @inferred(frule((NoTangent(), 1.0, ZeroTangent()), suminv, 1.0, 0.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, NoTangent()), suminv, 1.0, 0.0)) === + (Inf, -1.0) + @test @inferred(frule((NoTangent(), 1.0, 0.0), suminv, 1.0, 0.0)) === + (Inf, -1.0) @test @inferred(rrule(suminv, 0.0, 1.0)[2](1.0)) === (NoTangent(), -Inf, -1.0) @test @inferred(rrule(suminv, 0.0, 1.0)[2](ZeroTangent())) === From 8852aa6ffa84554739638bb2607de1e68c88f3c5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 29 Dec 2022 12:30:23 +0100 Subject: [PATCH 11/15] Add tests for not covered cases --- test/rule_definition_tools.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 01dca3ded..e85a6277a 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -291,6 +291,23 @@ end @test @inferred(rrule(suminv, 1.0, 0.0)[2](NoTangent())) === (NoTangent(), NoTangent(), NoTangent()) @test @inferred(rrule(suminv, 1.0, 0.0)[2](0.0)) === (NoTangent(), 0.0, 0.0) + + # cases not covered + t = @thunk(0.0) + @inferred(frule((NoTangent(), t, 1.0), suminv, 0.0, 1.0)) + @inferred(frule((NoTangent(), 1.0, t), suminv, 1.0, 0.0)) + @inferred(rrule(suminv, 0.0, 1.0)[2](t)) + @inferred(rrule(suminv, 1.0, 0.0)[2](t)) + @test_broken rrule(suminv, 0.0, 1.0)[2](t) == (NoTangent(), 0.0, 0.0) + @test_broken rrule(suminv, 1.0, 0.0)[2](t) == (NoTangent(), 0.0, 0.0) + @test_broken frule((NoTangent(), t, 1.0), suminv, 0.0, 1.0) == (Inf, -1.0) + @test_broken frule((NoTangent(), 1.0, t), suminv, 1.0, 0.0) == (Inf, -1.0) + + ni = @not_implemented("not implemented!") + @test_broken rrule(suminv, 0.0, 1.0)[2](ni) == (NoTangent(), 0.0, 0.0) + @test_broken rrule(suminv, 1.0, 0.0)[2](ni) == (NoTangent(), 0.0, 0.0) + @test_broken frule((NoTangent(), ni, 1.0), suminv, 0.0, 1.0) == (Inf, -1.0) + @test_broken frule((NoTangent(), 1.0, ni), suminv, 1.0, 0.0) == (Inf, -1.0) end @testset "Regression tests against #276 and #265" begin From 06f8b67d2c092ae0d5455cb960eea35db025848d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 4 Feb 2023 20:49:53 +0100 Subject: [PATCH 12/15] Add strong_mul and strong_muladd --- src/rule_definition_tools.jl | 25 ++++++++++++++++++++++--- test/rule_definition_tools.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 5ecf90faa..22c82eb47 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -292,14 +292,14 @@ function propagation_expr(Δs, ∂s, _conj=false, proj=identity) end end - # Apply `muladd` iteratively. + # Apply `strong_muladd` iteratively. # Explicit multiplication is only performed for the first pair of partial and gradient. (∂s_1, Δs_1), _∂s_Δs_tail = Iterators.peel(zip(_∂s, Δs)) # zero gradients are treated as hard zeros. This avoids propagation of NaNs when # partials are non-finite - init_expr = :((iszero($Δs_1) ? zero($∂s_1) : $∂s_1) * $Δs_1) + init_expr = :(strong_mul($∂s_1, $Δs_1)) summed_∂_mul_Δs = foldl(_∂s_Δs_tail; init=init_expr) do ex, (∂s_i, Δs_i) - :(muladd((iszero($Δs_i) ? zero($∂s_i) : $∂s_i), $Δs_i, $ex)) + :(strong_muladd($∂s_i, $Δs_i, $ex)) end return :($proj($summed_∂_mul_Δs)) end @@ -617,3 +617,22 @@ function _constrain_and_name(arg::Expr, _) return error("malformed arguments: $arg") end _constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type + +""" + strong_mul(x, y) + +Multiply `x` and `y`. If `iszero(y)`, treat `y` as a hard zero even for non-finite `x`. +""" +strong_mul(x, y) = ifelse(iszero(y), zero(x), x) * y + +""" + strong_muladd(x, y, z) + +Multiply `x` and `y` and add to `z`. If `iszero(y)`, treat `y` as a hard zero even for +non-finite `x`. +""" +strong_muladd(x, y, z) = muladd(ifelse(iszero(y), zero(x), x), y, z) + +# slightly faster for BigFloats +strong_mul(x::BigFloat, y::BigFloat) = (iszero(y) ? zero(x) : x) * y +strong_muladd(x::BigFloat, y::BigFloat, z) = muladd((iszero(y) ? zero(x) : x), y, z) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index e85a6277a..8bdb9439d 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -224,6 +224,34 @@ end end end + @testset "strong_mul" begin + @testset for T in (Float32, Float64, BigFloat), S in (Float32, Float64, BigFloat) + x = randn(T) + y = randn(S) + @test ChainRulesCore.strong_mul(x, y) == x * y + @test ChainRulesCore.strong_mul(x, zero(y)) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(Inf, x), zero(y)) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(-Inf, x), zero(y)) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(NaN, x), zero(y)) == zero(x * y) + end + end + + @testset "strong_muladd" begin + @testset for T in (Float32, Float64, BigFloat), + S in (Float32, Float64, BigFloat), + R in (Float32, Float64, BigFloat) + + x = randn(T) + y = randn(S) + z = randn(R) + @test ChainRulesCore.strong_muladd(x, y, z) == muladd(x, y, z) + @test ChainRulesCore.strong_muladd(x, zero(y), z) == z + @test ChainRulesCore.strong_muladd(oftype(Inf, x), zero(y), z) == z + @test ChainRulesCore.strong_muladd(oftype(-Inf, x), zero(y), z) == z + @test ChainRulesCore.strong_muladd(oftype(NaN, x), zero(y), z) == z + end + end + @testset "@scalar_rule" begin @testset "@scalar_rule with multiple output" begin simo(x) = (x, 2x) From e66a0972e6e9a9f945f51ced24bb27eaa7310a30 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 4 Feb 2023 21:03:35 +0100 Subject: [PATCH 13/15] Fix random generation pre-v1.9 --- test/rule_definition_tools.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 8bdb9439d..3ec75e361 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -226,8 +226,8 @@ end @testset "strong_mul" begin @testset for T in (Float32, Float64, BigFloat), S in (Float32, Float64, BigFloat) - x = randn(T) - y = randn(S) + x = T === BigFloat ? big(randn()) : randn(T) + y = S === BigFloat ? big(randn()) : randn(S) @test ChainRulesCore.strong_mul(x, y) == x * y @test ChainRulesCore.strong_mul(x, zero(y)) == zero(x * y) @test ChainRulesCore.strong_mul(oftype(Inf, x), zero(y)) == zero(x * y) @@ -241,9 +241,9 @@ end S in (Float32, Float64, BigFloat), R in (Float32, Float64, BigFloat) - x = randn(T) - y = randn(S) - z = randn(R) + x = T === BigFloat ? big(randn()) : randn(T) + y = S === BigFloat ? big(randn()) : randn(S) + z = R === BigFloat ? big(randn()) : randn(R) @test ChainRulesCore.strong_muladd(x, y, z) == muladd(x, y, z) @test ChainRulesCore.strong_muladd(x, zero(y), z) == z @test ChainRulesCore.strong_muladd(oftype(Inf, x), zero(y), z) == z From b7bdee1fa6261821801685e23e441062abee42e7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 5 Feb 2023 11:14:52 +0100 Subject: [PATCH 14/15] Add methods for NotImplemented --- src/rule_definition_tools.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 22c82eb47..14d083b18 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -636,3 +636,7 @@ strong_muladd(x, y, z) = muladd(ifelse(iszero(y), zero(x), x), y, z) # slightly faster for BigFloats strong_mul(x::BigFloat, y::BigFloat) = (iszero(y) ? zero(x) : x) * y strong_muladd(x::BigFloat, y::BigFloat, z) = muladd((iszero(y) ? zero(x) : x), y, z) + +# avoid raising errors for NotImplemented +strong_mul(x::NotImplemented, y) = (iszero(y) ? zero(x) : x) * y +strong_muladd(x::NotImplemented, y, z) = muladd((iszero(y) ? zero(x) : x), y, z) From 3cdd1087bd44cada817746c77ef136ce7a200bf0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 5 Feb 2023 11:15:00 +0100 Subject: [PATCH 15/15] Update tests --- test/rule_definition_tools.jl | 67 ++++++++++++++++++++++++----------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 3ec75e361..764e65115 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -225,30 +225,57 @@ end end @testset "strong_mul" begin - @testset for T in (Float32, Float64, BigFloat), S in (Float32, Float64, BigFloat) - x = T === BigFloat ? big(randn()) : randn(T) - y = S === BigFloat ? big(randn()) : randn(S) - @test ChainRulesCore.strong_mul(x, y) == x * y - @test ChainRulesCore.strong_mul(x, zero(y)) == zero(x * y) - @test ChainRulesCore.strong_mul(oftype(Inf, x), zero(y)) == zero(x * y) - @test ChainRulesCore.strong_mul(oftype(-Inf, x), zero(y)) == zero(x * y) - @test ChainRulesCore.strong_mul(oftype(NaN, x), zero(y)) == zero(x * y) + ni = @not_implemented("not implemented!") + xvals = ( + 5, + randn(Float32), + randn(Float64), + randn(ComplexF64), + big(randn()), + ZeroTangent(), + ) + yvals = (3, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + yzerovals = (0, 0.0f0, 0.0, 0.0im, big(0.0), ZeroTangent(), NoTangent()) + @testset for x in xvals + x === ni || @testset for y in yvals + @test @inferred(ChainRulesCore.strong_mul(x, y)) == x * y + end + @testset for y in yzerovals + @test @inferred(ChainRulesCore.strong_mul(x, y)) == zero(x * y) + if x isa AbstractFloat + @test ChainRulesCore.strong_mul(oftype(x, Inf), y) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(x, -Inf), y) == zero(x * y) + @test ChainRulesCore.strong_mul(oftype(x, NaN), y) == zero(x * y) + end + end end end @testset "strong_muladd" begin - @testset for T in (Float32, Float64, BigFloat), - S in (Float32, Float64, BigFloat), - R in (Float32, Float64, BigFloat) - - x = T === BigFloat ? big(randn()) : randn(T) - y = S === BigFloat ? big(randn()) : randn(S) - z = R === BigFloat ? big(randn()) : randn(R) - @test ChainRulesCore.strong_muladd(x, y, z) == muladd(x, y, z) - @test ChainRulesCore.strong_muladd(x, zero(y), z) == z - @test ChainRulesCore.strong_muladd(oftype(Inf, x), zero(y), z) == z - @test ChainRulesCore.strong_muladd(oftype(-Inf, x), zero(y), z) == z - @test ChainRulesCore.strong_muladd(oftype(NaN, x), zero(y), z) == z + ni = @not_implemented("not implemented!") + xvals = ( + 5, + randn(Float32), + randn(Float64), + randn(ComplexF64), + big(randn()), + ZeroTangent(), + ) + yvals = (3, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + zvals = (7, randn(Float32), randn(Float64), randn(ComplexF64), big(randn())) + yzerovals = (0, 0.0f0, 0.0, 0.0 * im, big(0.0), ZeroTangent(), NoTangent()) + @testset for x in xvals, z in zvals + x === ni || @testset for y in yvals + @test @inferred(ChainRulesCore.strong_muladd(x, y, z)) == muladd(x, y, z) + end + @testset for y in yzerovals + @test @inferred(ChainRulesCore.strong_muladd(x, y, z)) == z + if x isa AbstractFloat + @test ChainRulesCore.strong_muladd(oftype(x, Inf), y, z) == z + @test ChainRulesCore.strong_muladd(oftype(x, -Inf), y, z) == z + @test ChainRulesCore.strong_muladd(oftype(x, NaN), y, z) == z + end + end end end