Skip to content

Commit 448e4c3

Browse files
authored
Merge pull request #61 from JuliaAI/dev
For a 0.4.3 release
2 parents 7e39bac + 277bcff commit 448e4c3

File tree

4 files changed

+10
-14
lines changed

4 files changed

+10
-14
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@ jobs:
1717
fail-fast: false
1818
matrix:
1919
version:
20-
- '1.6'
20+
- '1.10'
2121
- '1'
22-
- 'nightly'
2322
os:
2423
- ubuntu-latest
2524
arch:
@@ -30,7 +29,7 @@ jobs:
3029
with:
3130
version: ${{ matrix.version }}
3231
arch: ${{ matrix.arch }}
33-
- uses: actions/cache@v1
32+
- uses: julia-actions/cache@v2
3433
env:
3534
cache-name: cache-artifacts
3635
with:

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJDecisionTreeInterface"
22
uuid = "c6f25543-311c-4c74-83dc-3ea6d1015661"
33
authors = ["Anthony D. Blaom <[email protected]>"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
@@ -11,11 +11,11 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1212

1313
[compat]
14-
CategoricalArrays = "0.10"
14+
CategoricalArrays = "1"
1515
DecisionTree = "0.12"
1616
MLJModelInterface = "1.5"
1717
Tables = "1.6"
18-
julia = "1.6"
18+
julia = "1.10"
1919

2020
[extras]
2121
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

src/MLJDecisionTreeInterface.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ end
2323
Base.show(stream::IO, c::TreePrinter) =
2424
print(stream, "TreePrinter object (call with display depth)")
2525

26-
function classes(y)
27-
p = CategoricalArrays.pool(y)
28-
[p[i] for i in 1:length(p)]
29-
end
3026

3127
# # DECISION TREE CLASSIFIER
3228

@@ -79,7 +75,7 @@ function MMI.fit(
7975
end
8076

8177
# returns a dictionary of categorical elements keyed on ref integer:
82-
get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
78+
get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in levels(classes_seen))
8379

8480
# given such a dictionary, return printable class labels, ordered by corresponding ref
8581
# integer:
@@ -459,7 +455,7 @@ _columnnames(X, ::Val{false}) = Tables.columnnames(first(Tables.rows(X)))
459455

460456
# for fit:
461457
MMI.reformat(::Classifier, X, y) =
462-
(Tables.matrix(X), MMI.int(y), _columnnames(X), classes(y))
458+
(Tables.matrix(X), MMI.int(y), _columnnames(X), levels(y))
463459
MMI.reformat(::Regressor, X, y) =
464460
(Tables.matrix(X), float(y), _columnnames(X))
465461
MMI.selectrows(::TreeModel, I, Xmatrix, y, meta...) =

test/runtests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
import CategoricalArrays
33
import CategoricalArrays.categorical
4+
import CategoricalArrays.levels
45
using MLJBase
56
using StableRNGs
67
using Random
@@ -48,7 +49,7 @@ stable_rng() = StableRNGs.StableRNG(123)
4849
Xraw, yraw = @load_iris
4950
X = Tables.matrix(Xraw);
5051
y = int(yraw);
51-
_classes = MLJDecisionTreeInterface.classes(yraw)
52+
_classes = levels(yraw)
5253
features = MLJDecisionTreeInterface._columnnames(Xraw)
5354

5455
baretree = DecisionTreeClassifier(rng=stable_rng())
@@ -74,7 +75,7 @@ yhat = MLJBase.predict(baretree, fitresult, X);
7475

7576
# check preservation of levels:
7677
yyhat = predict_mode(baretree, fitresult, X[1:3, :])
77-
@test MLJBase.classes(yyhat[1]) == MLJBase.classes(yraw)
78+
@test MLJBase.levels(yyhat[1]) == MLJBase.levels(yraw)
7879

7980
# check report and fitresult fields:
8081
@test Set([:classes_seen, :print_tree, :features]) == Set(keys(report))

0 commit comments

Comments
 (0)