Skip to content

Commit 08ad0b7

Browse files
improvements to stack (#125)
* improvements to stack * cleanup * use Base definition of stack * v0.4 * use stack in batch * Compat bound
1 parent 1c50c62 commit 08ad0b7

File tree

6 files changed

+36
-63
lines changed

6 files changed

+36
-63
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.3.1"
4+
version = "0.4.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
8+
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
910
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1011
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
@@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
2021

2122
[compat]
2223
ChainRulesCore = "1.0"
24+
Compat = "4.2"
2325
DataAPI = "1.0"
2426
DelimitedFiles = "1.0"
2527
FLoops = "0.2"

src/MLUtils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ import NNlib
2121

2222
@traitdef IsTable{X}
2323
@traitimpl IsTable{X} <- Tables.istable(X)
24-
24+
25+
using Compat: stack
2526

2627
include("observation.jl")
2728
export numobs,
@@ -75,7 +76,7 @@ export batch,
7576
rand_like,
7677
randn_like,
7778
rpad_constant,
78-
stack,
79+
stack, # in Base since julia v1.9
7980
unbatch,
8081
unsqueeze,
8182
unstack,

src/deprecations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Deprecated in v0.2
2-
@deprecate stack(x, dims) stack(x; dims=dims)
32
@deprecate unstack(x, dims) unstack(x; dims=dims)
43
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
54
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)

src/utils.jl

Lines changed: 5 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
Return `x` reshaped into an array one dimensionality higher than `x`,
77
where `dims` indicates in which dimension `x` is extended.
8+
`dims` can be an integer between 1 and `ndims(x)+1`.
89
910
See also [`flatten`](@ref), [`stack`](@ref).
1011
@@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
3334
[1, 2] [3, 4] [5, 6]
3435
```
3536
"""
36-
function unsqueeze(x::AbstractArray; dims::Int)
37-
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1)
37+
function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N}
38+
@assert 1 <= dims <= N + 1
39+
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1)
3840
return reshape(x, sz)
3941
end
4042

@@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)
5557

5658
Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")
5759

58-
"""
59-
stack(xs; dims)
60-
61-
Concatenate the given array of arrays `xs` into a single array along the
62-
given dimension `dims`.
63-
64-
See also [`stack`](@ref) and [`batch`](@ref).
65-
66-
# Examples
67-
68-
```jldoctest
69-
julia> xs = [[1, 2], [3, 4], [5, 6]]
70-
3-element Vector{Vector{Int64}}:
71-
[1, 2]
72-
[3, 4]
73-
[5, 6]
74-
75-
julia> stack(xs, dims=1)
76-
3×2 Matrix{Int64}:
77-
1 2
78-
3 4
79-
5 6
80-
81-
julia> stack(xs, dims=2)
82-
2×3 Matrix{Int64}:
83-
1 3 5
84-
2 4 6
85-
86-
julia> stack(xs, dims=3)
87-
2×1×3 Array{Int64, 3}:
88-
[:, :, 1] =
89-
1
90-
2
91-
92-
[:, :, 2] =
93-
3
94-
4
95-
96-
[:, :, 3] =
97-
5
98-
6
99-
```
100-
"""
101-
stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims)
102-
10360
"""
10461
unstack(xs; dims)
10562
@@ -329,17 +286,7 @@ end
329286

330287
batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)
331288

332-
function batch(xs::AbstractArray{<:AbstractArray})
333-
# Don't use stack(xs, dims=N+1), it is much slower.
334-
# Here we do reduce(vcat, xs) along with some reshapes.
335-
szxs = size(xs)
336-
@assert length(xs) > 0 "Minimum batch size is 1."
337-
szx = size(xs[1])
338-
@assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
339-
vxs = vec(vec.(xs))
340-
y = reduce(vcat, vxs)
341-
return reshape(y, szx..., szxs...)
342-
end
289+
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)
343290

344291
function batch(xs::Vector{<:Tuple})
345292
@assert length(xs) > 0 "Input should be non-empty"

test/test_utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
11

2+
"""
3+
Test gradients through zygote.
4+
5+
# Arguments
6+
7+
- `f`: function to test
8+
- `xs`: inputs to `f`
9+
10+
# Keyword Arguments
11+
Keyword arguments are passed to `rrule`.
12+
13+
- `fkwargs`: keyword arguments to `f`
14+
"""
215
function test_zygote(f, xs...; kws...)
316
config = ZygoteRuleConfig()
417
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)

test/utils.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,30 @@
66
@test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1)
77

88
@test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2)
9+
10+
@test_throws AssertionError unsqueeze(rand(2,2), dims=4)
911
end
1012

1113
@testset "stack and unstack" begin
1214
x = randn(3,3)
1315
stacked = stack([x, x], dims=2)
1416
@test size(stacked) == (3,2,3)
15-
@test_broken @inferred(stack([x, x], dims=2)) == stacked
17+
@test @inferred(stack([x, x], dims=2)) == stacked
1618

1719
stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
1820
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
1921
@test unstack(stacked_array, dims=2) == unstacked_array
2022
@test stack(unstacked_array, dims=2) == stacked_array
2123
@test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array
24+
25+
for d in (1,2,3)
26+
test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false)
27+
end
28+
29+
# Issue #121
30+
a = [[1] for i in 1:10000]
31+
@test size(stack(a, dims=1)) == (10000, 1)
32+
@test size(stack(a, dims=2)) == (1, 10000)
2233
end
2334

2435
@testset "batch and unbatch" begin

0 commit comments

Comments
 (0)