Skip to content

Commit 855f95b

Browse files
generalize batchseq to sequence of generic arrays (#126)
generalize batchseq to sequence of generic arrays
2 parents a85c098 + bd3bef8 commit 855f95b

File tree

6 files changed

+76
-37
lines changed

6 files changed

+76
-37
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1010
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
1111
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
12+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
1415
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
@@ -23,8 +24,9 @@ DataAPI = "1.0"
2324
DelimitedFiles = "1.0"
2425
FLoops = "0.2"
2526
FoldsThreads = "0.1"
26-
SimpleTraits = "0.9"
27+
NNlib = "0.8"
2728
ShowCases = "0.1"
29+
SimpleTraits = "0.9"
2830
StatsBase = "0.33"
2931
Tables = "1.10"
3032
Transducers = "0.4"

docs/src/api.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ obsview
3636
ObsView
3737
ones_like
3838
oversample
39-
MLUtils.rpad
4039
randobs
41-
rpad(::AbstractVector, ::Integer, ::Any)
40+
rpad_constant
4241
shuffleobs
4342
splitobs
4443
stack

src/MLUtils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
1717
NoTangent, ZeroTangent, ProjectTo
1818

1919
using SimpleTraits
20+
import NNlib
2021

2122
@traitdef IsTable{X}
2223
@traitimpl IsTable{X} <- Tables.istable(X)
@@ -73,12 +74,12 @@ export batch,
7374
ones_like,
7475
rand_like,
7576
randn_like,
77+
rpad_constant,
7678
stack,
7779
unbatch,
7880
unsqueeze,
7981
unstack,
8082
zeros_like
81-
# rpad
8283

8384
include("Datasets/Datasets.jl")
8485
using .Datasets

src/deprecations.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Deprecations v0.1
1+
# Deprecated in v0.2
22
@deprecate stack(x, dims) stack(x; dims=dims)
33
@deprecate unstack(x, dims) unstack(x; dims=dims)
44
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
@@ -7,3 +7,6 @@
77
@deprecate frequencies(x) group_counts(x)
88
@deprecate eachbatch(data, batchsize; kws...) eachobs(data; batchsize, kws...)
99
@deprecate eachbatch(data; size=1, kws...) eachobs(data; batchsize=size, kws...)
10+
11+
# Deprecated in v0.3
12+
@deprecate rpad(v::AbstractVector, n::Integer, p) rpad_constant(v, n, p)

src/utils.jl

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -388,52 +388,78 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
388388
unbatch(x::AbstractVector) = x
389389

390390
"""
391-
rpad(v::AbstractVector, n::Integer, p)
391+
batchseq(seqs, val = 0)
392392
393-
Return the given sequence padded with `p` up to a maximum length of `n`.
393+
Take a list of `N` sequences, and turn them into a single sequence where each
394+
item is a batch of `N`. Short sequences will be padded by `val`.
394395
395396
# Examples
396397
397398
```jldoctest
398-
julia> rpad([1, 2], 4, 0)
399+
julia> batchseq([[1, 2, 3], [4, 5]], 0)
400+
3-element Vector{Vector{Int64}}:
401+
[1, 4]
402+
[2, 5]
403+
[3, 0]
404+
```
405+
"""
406+
function batchseq(xs, val = 0, n = nothing)
407+
n = n === nothing ? maximum(x -> size(x, ndims(x)), xs) : n
408+
xs_ = [rpad_constant(x, n, val; dims=ndims(x)) for x in xs]
409+
[batch([obsview(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
410+
end
411+
412+
"""
413+
rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
414+
415+
Return the given sequence padded with `val` along the dimensions `dims`
416+
up to a maximum length in each direction specified by `n`.
417+
418+
# Examples
419+
```jldoctest
420+
julia> rpad_constant([1, 2], 4, -1) # passing with -1 up to size 4
399421
4-element Vector{Int64}:
400422
1
401423
2
402-
0
403-
0
424+
-1
425+
-1
404426
405-
julia> rpad([1, 2, 3], 2, 0)
427+
julia> rpad_constant([1, 2, 3], 2) # no padding if length is already greater than n
406428
3-element Vector{Int64}:
407429
1
408430
2
409431
3
410-
```
411-
"""
412-
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
413-
# TODO Piracy
414-
415432
416-
"""
417-
batchseq(seqs, pad)
418-
419-
Take a list of `N` sequences, and turn them into a single sequence where each
420-
item is a batch of `N`. Short sequences will be padded by `pad`.
421-
422-
# Examples
433+
julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
434+
4×2 Matrix{Int64}:
435+
1 2
436+
3 4
437+
0 0
438+
0 0
423439
424-
```jldoctest
425-
julia> batchseq([[1, 2, 3], [4, 5]], 0)
426-
3-element Vector{Vector{Int64}}:
427-
[1, 4]
428-
[2, 5]
429-
[3, 0]
440+
julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
441+
4×2 Matrix{Int64}:
442+
1 2
443+
3 4
444+
0 0
445+
0 0
430446
```
431447
"""
432-
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
433-
xs_ = [rpad(x, n, pad) for x in xs]
434-
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
448+
function rpad_constant(x::AbstractArray, n::Union{Integer, Tuple}, val=0; dims=:)
449+
ns = _rpad_pads(x, n, dims)
450+
return NNlib.pad_constant(x, ns, val; dims)
451+
end
452+
453+
function _rpad_pads(x, n, dims)
454+
_dims = dims === Colon() ? (1:ndims(x)) : dims
455+
_n = n isa Integer ? ntuple(i -> n, length(_dims)) : n
456+
@assert length(_dims) == length(_n)
457+
ns = ntuple(i -> isodd(i) ? 0 : max(_n[i÷2] - size(x, _dims[i÷2]), 0), 2*length(_n))
458+
return ns
435459
end
436460

461+
@non_differentiable _rpad_pads(::Any...)
462+
437463
"""
438464
flatten(x::AbstractArray)
439465

test/utils.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,6 @@ end
129129
@test d == Dict('a' => 1, 'b' => 2)
130130
end
131131

132-
@testset "rpad" begin
133-
@test rpad([1, 2], 4, 0) == [1, 2, 0, 0]
134-
@test rpad([1, 2, 3], 2, 0) == [1,2,3]
135-
end
136-
137132
@testset "batchseq" begin
138133
bs = batchseq([[1, 2, 3], [4, 5]], 0)
139134
@test bs[1] == [1, 4]
@@ -144,6 +139,11 @@ end
144139
@test bs[1] == [1, 4]
145140
@test bs[2] == [2, 5]
146141
@test bs[3] == [3, -1]
142+
143+
batchseq([ones(2,4), zeros(2, 3), ones(2,2)]) ==[[1.0 0.0 1.0; 1.0 0.0 1.0]
144+
[1.0 0.0 1.0; 1.0 0.0 1.0]
145+
[1.0 0.0 0.0; 1.0 0.0 0.0]
146+
[1.0 0.0 0.0; 1.0 0.0 0.0]]
147147
end
148148

149149
@testset "ones_like" begin
@@ -188,3 +188,11 @@ end
188188

189189
test_zygote(fill_like, rand(5), rand(), (2, 4, 2))
190190
end
191+
192+
@testset "rpad_constant" begin
193+
@test rpad_constant([1, 2], 4, -1) == [1, 2, -1, -1]
194+
@test rpad_constant([1, 2, 3], 2) == [1, 2, 3]
195+
@test rpad_constant([1 2; 3 4], 4; dims=1) == [1 2; 3 4; 0 0; 0 0]
196+
@test rpad_constant([1 2; 3 4], 4) == [1 2 0 0; 3 4 0 0; 0 0 0 0; 0 0 0 0]
197+
@test rpad_constant([1 2; 3 4], (3, 4)) == [1 2 0 0; 3 4 0 0; 0 0 0 0]
198+
end

0 commit comments

Comments
 (0)