Skip to content

Commit 125205b

Browse files
oversample and undersample always return classes as well (#116)
1 parent 855f95b commit 125205b

File tree

3 files changed

+34
-26
lines changed

3 files changed

+34
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.2.12"
4+
version = "0.3.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/resample.jl

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ resulting data will be shuffled after its creation; if it is not
2121
shuffled then all the repeated samples will be together at the
2222
end, sorted by class. Defaults to `true`.
2323
24+
The output will contain both the resampled data and classes.
25+
2426
```julia
2527
# 6 observations with 3 features each
2628
X = rand(3, 6)
@@ -40,14 +42,7 @@ X_bal, Y_bal = oversample(X, Y)
4042
```
4143
4244
For this function to work, the type of `data` must implement
43-
[`numobs`](@ref) and [`getobs`](@ref). For example, the following
44-
code allows `oversample` to work on a `DataFrame`.
45-
46-
```julia
47-
# Make DataFrames.jl work
48-
MLUtils.getobs(data::DataFrame, i) = data[i,:]
49-
MLUtils.numobs(data::DataFrame) = nrow(data)
50-
```
45+
[`numobs`](@ref) and [`getobs`](@ref).
5146
5247
Note that if `data` is a tuple and `classes` is not given,
5348
then it will be assumed that the last element of the tuple contains the classes.
@@ -98,16 +93,22 @@ function oversample(data, classes; fraction=1, shuffle::Bool=true)
9893
append!(inds, inds_for_lbl)
9994
end
10095
if num_extra_needed > 0
101-
append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false))
96+
if shuffle
97+
append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false))
98+
else
99+
append!(inds, inds_for_lbl[1:num_extra_needed])
100+
end
102101
end
103102
end
104103

105104
shuffle && shuffle!(inds)
106-
return obsview(data, inds)
105+
return obsview(data, inds), obsview(classes, inds)
107106
end
108107

109-
oversample(data::Tuple; kws...) = oversample(data, data[end]; kws...)
110-
108+
function oversample(data::Tuple; kws...)
109+
d, c = oversample(data[1:end-1], data[end]; kws...)
110+
return (d..., c)
111+
end
111112

112113
"""
113114
undersample(data, classes; shuffle=true)
@@ -123,6 +124,8 @@ resulting data will be shuffled after its creation; if it is not
123124
shuffled then all the observations will be in their original
124125
order. Defaults to `false`.
125126
127+
The output will contain both the resampled data and classes.
128+
126129
```julia
127130
# 6 observations with 3 features each
128131
X = rand(3, 6)
@@ -142,14 +145,8 @@ X_bal, Y_bal = undersample(X, Y)
142145
```
143146
144147
For this function to work, the type of `data` must implement
145-
[`numobs`](@ref) and [`getobs`](@ref). For example, the following
146-
code allows `undersample` to work on a `DataFrame`.
148+
[`numobs`](@ref) and [`getobs`](@ref).
147149
148-
```julia
149-
# Make DataFrames.jl work
150-
MLUtils.getobs(data::DataFrame, i) = data[i,:]
151-
MLUtils.numobs(data::DataFrame) = nrow(data)
152-
```
153150
Note that if `data` is a tuple, then it will be assumed that the
154151
last element of the tuple contains the targets.
155152
@@ -186,11 +183,18 @@ function undersample(data, classes; shuffle::Bool=true)
186183
inds = Int[]
187184

188185
for (lbl, inds_for_lbl) in lm
189-
append!(inds, sample(inds_for_lbl, mincount; replace=false))
186+
if shuffle
187+
append!(inds, sample(inds_for_lbl, mincount; replace=false))
188+
else
189+
append!(inds, inds_for_lbl[1:mincount])
190+
end
190191
end
191192

192193
shuffle ? shuffle!(inds) : sort!(inds)
193-
return obsview(data, inds)
194+
return obsview(data, inds), obsview(classes, inds)
194195
end
195196

196-
undersample(data::Tuple; kws...) = undersample(data, data[end]; kws...)
197+
function undersample(data::Tuple; kws...)
198+
d, c = undersample(data[1:end-1], data[end]; kws...)
199+
return (d..., c)
200+
end

test/resample.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
y2 = ["c", "c", "c", "a", "b"]
55

66
o = oversample((x, ya), fraction=1, shuffle=false)
7-
@test o == oversample((x, ya), ya, shuffle=false)
7+
@test o == oversample((x, ya), ya, shuffle=false)[1]
8+
xo, yo = oversample(x, ya, shuffle=false)
9+
@test (xo, yo) == o
810
ox, oy = getobs(o)
911
@test ox isa Matrix
1012
@test oy isa Vector
@@ -15,7 +17,7 @@
1517
@test oy[1:5] == ya
1618
@test oy[6] == ya[5]
1719

18-
o = oversample((x, ya), y2, shuffle=false)
20+
o = oversample((x, ya), y2, shuffle=false)[1]
1921
ox, oy = getobs(o)
2022
@test ox isa Matrix
2123
@test oy isa Vector
@@ -35,14 +37,16 @@ end
3537
y2 = ["c", "c", "c", "a", "b"]
3638

3739
o = undersample((x, ya), shuffle=false)
40+
xo, yo = undersample(x, ya, shuffle=false)
41+
@test (xo, yo) == o
3842
ox, oy = getobs(o)
3943
@test ox isa Matrix
4044
@test oy isa Vector
4145
@test size(ox) == (2, 3)
4246
@test size(oy) == (3,)
4347
@test ox[:,3] == x[:,5]
4448

45-
o = undersample((x, ya), y2, shuffle=false)
49+
o = undersample((x, ya), y2, shuffle=false)[1]
4650
ox, oy = getobs(o)
4751
@test ox isa Matrix
4852
@test oy isa Vector

0 commit comments

Comments
 (0)