diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 3414939853..9e6bdb53a0 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -14,7 +14,7 @@ end
 function ChainRulesCore.rrule(::typeof(batchnorm), g, b, x, running_mean, running_var, momentum; kw...)
   y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
   function batchnorm_pullback(Δ)
-    grad = ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)
+    grad = ∇batchnorm(g, b, x, unthunk(Δ), running_mean, running_var, momentum; kw...)
     (NoTangent(), grad..., NoTangent(), NoTangent(), NoTangent())
   end
   y, batchnorm_pullback
diff --git a/src/functor.jl b/src/functor.jl
index bfa075a6b8..13adbe13ff 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -119,11 +119,11 @@ adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
 adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x
 
 function ChainRulesCore.rrule(::Type{Array}, x::CUDA.CuArray)
-  Array(x), d -> (NoTangent(), CUDA.cu(d),)
+  Array(x), dx -> (NoTangent(), CUDA.cu(unthunk(dx)),)
 end
 
 function ChainRulesCore.rrule(::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::CUDA.AbstractGPUArray)
-  adapt_storage(to, x), d -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), d),)
+  adapt_storage(to, x), dx -> (NoTangent(), NoTangent(), adapt_storage(FluxCUDAAdaptor(), unthunk(dx)),)
 end
 
 # CPU/GPU movement conveniences
@@ -227,3 +227,4 @@ f64(m) = paramtype(Float64, m)
 # Functors for certain Julia data structures
 @functor Cholesky
 trainable(c::Cholesky) = ()
+
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 760933bb96..c3b89f33a7 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -9,7 +9,7 @@ multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)
 function ChainRulesCore.rrule(::typeof(multigate), x::AbstractArray, h, c)
   function multigate_pullback(dy)
     dx = map!(zero, similar(x, float(eltype(x)), axes(x)), x)
-    foreach(multigate(dx, h, c), dy) do dxᵢ, dyᵢ
+    foreach(multigate(dx, h, c), unthunk(dy)) do dxᵢ, dyᵢ
       dyᵢ isa AbstractZero && return
       @. dxᵢ += dyᵢ
     end