Skip to content

Commit a9a84ba

Browse files
authored
Merge pull request #679 from cossio/main
fix cat rrule
2 parents 5e416b7 + 2b66f58 commit a9a84ba

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.44.6"
3+
version = "1.44.7"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/array.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ end
347347
function rrule(::typeof(cat), Xs...; dims)
348348
Y = cat(Xs...; dims=dims)
349349
Base.require_one_based_indexing(Y)
350-
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
350+
_cdims = dims isa Val ? _val(dims) : dims
351+
cdims = _cdims isa Integer ? Int(_cdims) : Tuple(_cdims)
351352
ndimsY = Val(ndims(Y))
352353
sizes = map(_catsize, Xs)
353354
project_Xs = map(ProjectTo, Xs)

test/rulesets/Base/array.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ end
237237
@gpu test_rrule(cat, rand(2, 4), rand(2); fkwargs=(dims=Val(2),))
238238
test_rrule(cat, rand(), rand(2, 3); fkwargs=(dims=[1,2],))
239239
test_rrule(cat, rand(1), rand(3, 2, 1); fkwargs=(dims=(1,2),), check_inferred=false) # infers Tuple{Zero, Vector{Float64}, Any}
240+
241+
if VERSION v"1.8" # Val(tuple) dims support was added in v1.8
242+
test_rrule(cat, randn(3,2,4), randn(3,2,4); fkwargs=(dims=Val((1,2)),)) #678
243+
end
240244

241245
test_rrule(cat, rand(2, 2), rand(2, 2)'; fkwargs=(dims=1,))
242246
# inference on exotic array types

0 commit comments

Comments
 (0)