diff --git a/Project.toml b/Project.toml index c61ff68c2..58e84dbf2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.44" +version = "0.6.45" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" @@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "1.37" +ChainRules = "1.44.1" ChainRulesCore = "1.9" ChainRulesTestUtils = "1" DiffRules = "1.4" diff --git a/src/lib/array.jl b/src/lib/array.jl index 293801b21..420e3716f 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -104,27 +104,6 @@ end @adjoint reshape(xs, dims...) = reshape(xs, dims...), Δ -> (reshape(Δ, size(xs)),map(_->nothing,dims)...) -@adjoint function hvcat(rows::Tuple{Vararg{Int}}, xs::Number...) - hvcat(rows, xs...), ȳ -> (nothing, permutedims(ȳ)...) -end - -pull_block_vert(sz, Δ, A::Number) = Δ[sz] -pull_block_vert(sz, Δ, A::AbstractVector) = Δ[sz-length(A)+1:sz] -pull_block_vert(sz, Δ, A::AbstractMatrix) = Δ[sz-size(A, 1)+1:sz, :] -@adjoint function vcat(A::Union{AbstractVector, AbstractMatrix, Number}...) - sz = cumsum([size.(A, 1)...]) - return vcat(A...), Δ->(map(n->pull_block_vert(sz[n], Δ, A[n]), eachindex(A))...,) -end -@adjoint vcat(xs::Number...) = vcat(xs...), Δ -> (Δ...,) - -pull_block_horz(sz, Δ, A::AbstractVector) = Δ[:, sz] -pull_block_horz(sz, Δ, A::AbstractMatrix) = Δ[:, sz-size(A, 2)+1:sz] -@adjoint function hcat(A::Union{AbstractVector, AbstractMatrix}...) - sz = cumsum([size.(A, 2)...]) - return hcat(A...), Δ->(map(n->pull_block_horz(sz[n], Δ, A[n]), eachindex(A))...,) -end -@adjoint hcat(xs::Number...) = hcat(xs...), Δ -> (Δ...,) - @adjoint function repeat(xs; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs))) repeat(xs, inner = inner, outer = outer), function (Δ) Δ′ = zero(xs)