Skip to content

Commit 7e34598

Browse files
committed
Merge branch 'dev'
2 parents 5bb7c6d + 1e8afa9 commit 7e34598

File tree

4 files changed

+38
-8
lines changed

4 files changed

+38
-8
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLJLinearModels"
22
uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692"
33
authors = ["Thibaut Lienart <[email protected]>"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -25,6 +25,7 @@ julia = "1.5"
2525
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
2626
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
2727
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
28+
MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7"
2829
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
2930
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
3031
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
@@ -34,4 +35,4 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
3435
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3536

3637
[targets]
37-
test = ["DelimitedFiles", "PyCall", "ForwardDiff", "Test", "Random", "RDatasets", "RCall", "MLJBase", "StableRNGs", "DataFrames"]
38+
test = ["DelimitedFiles", "PyCall", "ForwardDiff", "Test", "Random", "RDatasets", "RCall", "MLJ", "MLJBase", "StableRNGs", "DataFrames"]

src/mlj/interface.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
6262
classes = MMI.classes(y[1])
6363
nclasses = length(classes)
6464
if nclasses < 2
65-
throw(DomainError("The target `y` needs to have two or more levels."))
65+
throw(
66+
DomainError(
67+
"The target `y` needs to have two or more levels."
68+
)
69+
)
6670
elseif nclasses == 2 && m isa LogisticClassifier
6771
# recode to ± 1
6872
yplain[yplain .== 1] .= -1
@@ -71,8 +75,11 @@ function MMI.fit(m::Union{CLF_MODELS...}, verb::Int, X, y)
7175
# in the Logistic case and not in the Multinomial {0} case!
7276
nclasses = 0
7377
end
74-
# NOTE: here the number of classes is either 0 or > 2
75-
clf = glr(m, nclasses)
78+
# NOTE: here nclasses is either
79+
# - 0 Logistic, binary
80+
# - 2 Multinomial, binary
81+
# - >2 or > 2 (Multinomial)
82+
clf = glr(m, nclasses)
7683
solver = m.solver === nothing ? _solver(clf, size(Xmatrix)) : m.solver
7784
verb > 0 && @info "Solver: $(solver)"
7885
# get the parameters
@@ -84,13 +91,16 @@ end
8491
function MMI.predict(m::Union{CLF_MODELS...}, (θ, features, classes, c), Xnew)
8592
Xmatrix = MMI.matrix(Xnew)
8693
preds = apply_X(Xmatrix, θ, c)
87-
if c > 2 # multiclass
94+
95+
if c > 2 || m isa MultinomialClassifier # multiclass
8896
preds .= softmax(preds)
89-
else # binary (necessarily c==0)
97+
98+
else # binary logistic (necessarily c==0)
9099
preds .= sigmoid.(preds)
91100
preds = hcat(1.0 .- preds, preds) # scores for -1 and 1
92101
return MMI.UnivariateFinite(classes, preds)
93102
end
103+
94104
return MMI.UnivariateFinite(classes, preds)
95105
end
96106

test/interface/extras.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,22 @@ end
2424
y = MLJBase.coerce(vcat(fill("a", 10), ["b", ]), MLJBase.Multiclass)[1:10]
2525
mach = MLJBase.machine(MultinomialClassifier(), X, y) |> MLJBase.fit!
2626
end
27+
28+
# https://github.com/JuliaAI/MLJLinearModels.jl/issues/129
29+
@testset "Crabs" begin
30+
data = MLJ.load_crabs()
31+
y_, X = MLJ.unpack(data, ==(:sp), col->col in [:FL, :RW]);
32+
y = MLJ.coerce(y_, MLJ.OrderedFactor);
33+
model = MultinomialClassifier()
34+
mach = MLJ.machine(model, X, y) |> MLJ.fit!
35+
yhat = MLJ.predict_mode(mach, X)
36+
37+
# crappy "test" but we're just testing that predict works fine
38+
@test MLJ.misclassification_rate(yhat, y) < 0.4
39+
40+
model = LogisticClassifier()
41+
mach = MLJ.machine(model, X, y) |> MLJ.fit!
42+
yhat = MLJ.predict_mode(mach, X)
43+
44+
@test MLJ.misclassification_rate(yhat, y) < 0.4
45+
end

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MLJLinearModels, Test, LinearAlgebra
22
using Random, StableRNGs, DataFrames, ForwardDiff
3-
import MLJBase # not MLJModelInterface, to mimic the full interface
3+
import MLJ, MLJBase
44

55
DO_COMPARISONS = false; include("testutils.jl")
66

0 commit comments

Comments
 (0)