Skip to content

Commit c5dbe03

Browse files
authored
normalize not just vectors (#602)
* normalise arrays not just vectors * versions
1 parent 2cc27e2 commit c5dbe03

File tree

3 files changed

+12
-6
lines changed

3 files changed

+12
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.28.4"
3+
version = "1.29.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/norm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ end
257257
##### `normalize`
258258
#####
259259

260-
function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
260+
function rrule(::typeof(normalize), x::AbstractArray{<:Number}, p::Real)
261261
nrm, inner_pullback = rrule(norm, x, p)
262262
Ty = typeof(first(x) / nrm)
263263
y = copyto!(similar(x, Ty), x)
@@ -273,7 +273,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
273273
return y, normalize_pullback
274274
end
275275

276-
function rrule(::typeof(normalize), x::AbstractVector{<:Number})
276+
function rrule(::typeof(normalize), x::AbstractArray{<:Number})
277277
nrm = LinearAlgebra.norm2(x)
278278
Ty = typeof(first(x) / nrm)
279279
y = copyto!(similar(x, Ty), x)

test/rulesets/LinearAlgebra/norm.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,15 +182,21 @@ end
182182
# ===================================
183183

184184
@testset "normalize" begin
185-
@testset "x::Vector{$T}" for T in (Float64, ComplexF64)
185+
@testset "x::Array{$T}" for T in (Float64, ComplexF64)
186186
x = randn(T, 3)
187187
test_rrule(normalize, x)
188188
@test rrule(normalize, x)[2](ZeroTangent()) === (NoTangent(), ZeroTangent())
189+
190+
test_rrule(normalize, rand(T, 3, 4))
191+
test_rrule(normalize, adjoint(rand(T, 5)))
189192
end
190-
@testset "x::Vector{$T}, p=$p" for T in (Float64, ComplexF64),
191-
p in (1.0, 2.0, -Inf, Inf, 2.5) # skip p=0, since FD is unstable
193+
@testset "x::Array{$T}, p=$p" for T in (Float64, ComplexF64), p in (1.0, 2.0, -Inf, Inf, 2.5)
194+
# skip p=0, since FD is unstable
192195
x = randn(T, 3)
193196
test_rrule(normalize, x, p)
194197
@test rrule(normalize, x, p)[2](ZeroTangent()) === (NoTangent(), ZeroTangent(), ZeroTangent())
198+
199+
test_rrule(normalize, rand(T, 3, 4), p)
200+
test_rrule(normalize, adjoint(rand(T, 5)), p)
195201
end
196202
end

0 commit comments

Comments
 (0)