Skip to content

Commit 1597bcc

Browse files
rrule for stack (#681)
* rrule for stack * bump version * extend rrule to muldim containers * hope you don't mind me committing these * import stack in tests * cleanup * Apply3 suggestions Co-authored-by: Michael Abbott <[email protected]>
1 parent a9a84ba commit 1597bcc

File tree

5 files changed

+62
-2
lines changed

5 files changed

+62
-2
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.44.7"
3+
version = "1.45.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -20,7 +20,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2020
Adapt = "3.4.0"
2121
ChainRulesCore = "1.15.3"
2222
ChainRulesTestUtils = "1.5"
23-
Compat = "3.42.0, 4"
23+
Compat = "3.46, 4.2"
2424
FiniteDifferences = "0.12.20"
2525
GPUArraysCore = "0.1.0"
2626
IrrationalConstants = "0.1.1"

src/ChainRules.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import ChainRulesCore: rrule, frule
2222
# Experimental:
2323
using ChainRulesCore: derivatives_given_output
2424

25+
using Compat: stack
26+
2527
# numbers that we know commute under multiplication
2628
const CommutativeMulNumber = Union{Real,Complex}
2729

src/rulesets/Base/array.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,30 @@ function _extrema_dims(x, dims)
610610
end
611611
return y, extrema_pullback_dims
612612
end
613+
614+
#####
615+
##### `stack`
616+
#####
617+
618+
function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :)
619+
return stack(x; dims), stack(ẋ; dims)
620+
end
621+
622+
# Other iterable X also allowed, maybe this should be wider?
623+
function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = :)
624+
Y = stack(X; dims)
625+
sdims = if dims isa Colon
626+
N = ndims(Y) - ndims(X)
627+
X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X))
628+
else
629+
dims
630+
end
631+
project = ProjectTo(X)
632+
function stack_pullback(Δ)
633+
dY = unthunk(Δ)
634+
dY isa AbstractZero && return (NoTangent(), dY)
635+
dX = collect(eachslice(dY; dims = sdims))
636+
return (NoTangent(), project(reshape(dX, project.axes)))
637+
end
638+
return Y, stack_pullback
639+
end

test/rulesets/Base/array.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,33 @@ end
416416
B = hcat(A[:,:,1], A[:,:,1])
417417
@test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
418418
end
419+
420+
@testset "stack" begin
421+
# vector container
422+
xs = [rand(3, 4), rand(3, 4)]
423+
test_frule(stack, xs)
424+
test_frule(stack, xs; fkwargs=(dims=1,))
425+
426+
test_rrule(stack, xs, check_inferred=false)
427+
test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false)
428+
test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false)
429+
test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false)
430+
431+
# multidimensional container
432+
ms = [rand(2,3) for _ in 1:4, _ in 1:5];
433+
434+
if VERSION > v"1.9-" # this needs new eachslice, not yet in Compat
435+
test_rrule(stack, ms, check_inferred=false)
436+
end
437+
test_rrule(stack, ms, fkwargs=(dims=1,), check_inferred=false)
438+
test_rrule(stack, ms, fkwargs=(dims=3,), check_inferred=false)
439+
440+
# non-array inner objects
441+
ts = [Tuple(rand(3)) for _ in 1:4, _ in 1:2];
442+
443+
if VERSION > v"1.9-"
444+
test_rrule(stack, ts, check_inferred=false)
445+
end
446+
test_rrule(stack, ts, fkwargs=(dims=1,), check_inferred=false)
447+
test_rrule(stack, ts, fkwargs=(dims=2,), check_inferred=false)
448+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils
55
using Adapt
66
using Base.Broadcast: broadcastable
77
using ChainRules
8+
using ChainRules: stack
89
using ChainRulesCore
910
using ChainRulesTestUtils
1011
using ChainRulesTestUtils: rand_tangent, _fdm

0 commit comments

Comments
 (0)