@@ -77,12 +77,27 @@ function MMI.fit(
     return fitresult, cache, report
 end
 
+# returns a dictionary of categorical elements keyed on ref integer:
 get_encoding(classes_seen) = Dict(MMI.int(c) => c for c in classes(classes_seen))
 
-MMI.fitted_params(::DecisionTreeClassifier, fitresult) =
-    (tree=fitresult[1],
-     encoding=get_encoding(fitresult[2]),
-     features=fitresult[4])
+# given such a dictionary, return printable class labels, ordered by corresponding ref
+# integer:
+classlabels(encoding) = [string(encoding[i]) for i in sort(keys(encoding) |> collect)]
+
+_node_or_leaf(r::DecisionTree.Root) = r.node
+_node_or_leaf(n::Any) = n
+
+function MMI.fitted_params(::DecisionTreeClassifier, fitresult)
+    raw_tree = fitresult[1]
+    encoding = get_encoding(fitresult[2])
+    features = fitresult[4]
+    classlabels = MLJDecisionTreeInterface.classlabels(encoding)
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (featurenames=features, classlabels),
+    )
+    (; tree, raw_tree, encoding, features)
+end
 
 function MMI.predict(m::DecisionTreeClassifier, fitresult, Xnew)
     tree, classes_seen, integers_seen = fitresult
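For orientation: with this change, `fitted_params` for a classifier returns both the raw tree and a wrapped, display-friendly version. A minimal usage sketch, assuming MLJ and AbstractTrees.jl are installed (illustrative only; it mirrors the docstring examples further down):

```julia
using MLJ
import AbstractTrees

DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
X, y = @load_iris
mach = machine(DecisionTreeClassifier(max_depth=3), X, y) |> fit!

fp = fitted_params(mach)
fp.raw_tree   # raw `Node`/`Leaf`/`Root` object from DecisionTree.jl
fp.encoding   # Dict mapping ref integers to class values (from `get_encoding`)
fp.tree       # wrapped tree; displays feature names and class labels

# the wrapped tree implements the AbstractTrees.jl interface, so generic
# tree utilities apply (assuming DecisionTree.jl's AbstractTrees support):
AbstractTrees.print_tree(fp.tree)
```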
@@ -285,13 +300,22 @@ function MMI.fit(m::DecisionTreeRegressor, verbosity::Int, Xmatrix, y, features)
     cache = nothing
 
     report = (features=features,)
+    fitresult = (tree, features)
 
-    return tree, cache, report
+    return fitresult, cache, report
 end
 
-MMI.fitted_params(::DecisionTreeRegressor, tree) = (tree=tree,)
+function MMI.fitted_params(::DecisionTreeRegressor, fitresult)
+    raw_tree = fitresult[1]
+    features = fitresult[2]
+    tree = DecisionTree.wrap(
+        _node_or_leaf(raw_tree),
+        (; featurenames=features),
+    )
+    (; tree, raw_tree)
+end
 
-MMI.predict(::DecisionTreeRegressor, tree, Xnew) = DT.apply_tree(tree, Xnew)
+MMI.predict(::DecisionTreeRegressor, fitresult, Xnew) = DT.apply_tree(fitresult[1], Xnew)
 
 MMI.reports_feature_importances(::Type{<:DecisionTreeRegressor}) = true
 
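The regressor's `fitresult` now carries the feature names alongside the tree so that `fitted_params` can build a readable wrapped tree. A hedged sketch of what `DecisionTree.wrap` does here, using DecisionTree.jl directly (data and feature names are illustrative):

```julia
using DecisionTree

X = rand(100, 3)
y = rand(100)
raw = DecisionTree.build_tree(y, X)   # may return a `Root` wrapper

# unwrap a `Root` first, as `_node_or_leaf` does above:
node = raw isa DecisionTree.Root ? raw.node : raw

# attach feature names for display:
wrapped = DecisionTree.wrap(node, (; featurenames = [:a, :b, :c]))
wrapped   # prints splits as e.g. `b < 0.47` rather than `Feature 2 < 0.47`
```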
@@ -446,11 +470,11 @@ MMI.selectrows(::TreeModel, I, Xmatrix) = (view(Xmatrix, I, :),)
 
 # get actual arguments needed for importance calculation from various fitresults.
 get_fitresult(
-    m::Union{DecisionTreeClassifier, RandomForestClassifier},
+    m::Union{DecisionTreeClassifier, RandomForestClassifier, DecisionTreeRegressor},
     fitresult,
 ) = (fitresult[1],)
 get_fitresult(
-    m::Union{DecisionTreeRegressor, RandomForestRegressor},
+    m::RandomForestRegressor,
     fitresult,
 ) = (fitresult,)
 get_fitresult(m::AdaBoostStumpClassifier, fitresult) = (fitresult[1], fitresult[2])
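`DecisionTreeRegressor` moves to the first method because its fitresult is now the tuple `(tree, features)`, so the tree must be pulled out with `fitresult[1]`; a bare `RandomForestRegressor` fitresult is still passed through whole. The returned tuple is splatted into DecisionTree.jl's importance routines, roughly as below (a hedged sketch of the call pattern only; the exact function used depends on the `feature_importance` hyperparameter):

```julia
# illustrative only: how the extracted arguments feed importance computation
args = get_fitresult(model, fitresult)
importances = DT.impurity_importance(args...)   # or DT.split_importance
```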
@@ -561,7 +585,7 @@ where
 Train the machine using `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -600,12 +624,14 @@ Train the machine using `fit!(mach, rows=...)`.
 
 The fields of `fitted_params(mach)` are:
 
-- `tree`: the tree or stump object returned by the core DecisionTree.jl algorithm
+- `raw_tree`: the raw `Node`, `Leaf` or `Root` object returned by the core DecisionTree.jl
+  algorithm
+
+- `tree`: a visualizable, wrapped version of `raw_tree` implementing the AbstractTrees.jl
+  interface; see "Examples" below
 
 - `encoding`: dictionary of target classes keyed on integers used
-  internally by DecisionTree.jl; needed to interpret pretty printing
-  of tree (obtained by calling `fit!(mach, verbosity=2)` or from
-  report - see below)
+  internally by DecisionTree.jl
 
 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
@@ -617,23 +643,28 @@ The fields of `report(mach)` are:
 
 - `classes_seen`: list of target classes actually observed in training
 
-- `print_tree`: method to print a pretty representation of the fitted
+- `print_tree`: alternative method to print the fitted
   tree, with single argument the tree depth; interpretation requires
   internal integer-class encoding (see "Fitted parameters" above).
 
 - `features`: the names of the features encountered in training, in an
   order consistent with the output of `print_tree` (see below)
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
 
 # Examples
 
 ```
 using MLJ
-Tree = @load DecisionTreeClassifier pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeClassifier = @load DecisionTreeClassifier pkg=DecisionTree
+model = DecisionTreeClassifier(max_depth=3, min_samples_split=3)
 
 X, y = @load_iris
-mach = machine(tree, X, y) |> fit!
+mach = machine(model, X, y) |> fit!
 
 Xnew = (sepal_length = [6.4, 7.2, 7.4],
         sepal_width = [2.8, 3.0, 2.8],
@@ -643,33 +674,26 @@ yhat = predict(mach, Xnew) # probabilistic predictions
 predict_mode(mach, Xnew)   # point predictions
 pdf.(yhat, "virginica")    # probabilities for the "virginica" class
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTrees.jl
-
-julia> report(mach).print_tree(3)
-Feature 4, Threshold 0.8
-L-> 1 : 50/50
-R-> Feature 4, Threshold 1.75
-    L-> Feature 3, Threshold 4.95
-        L->
-        R->
-    R-> Feature 3, Threshold 4.85
-        L->
-        R-> 3 : 43/43
-```
-
-To interpret the internal class labelling:
-
-```
-julia> fitted_params(mach).encoding
-Dict{CategoricalArrays.CategoricalValue{String, UInt32}, UInt32} with 3 entries:
-  0x00000003 => "virginica"
-  0x00000001 => "setosa"
-  0x00000002 => "versicolor"
+julia> tree = fitted_params(mach).tree
+petal_length < 2.45
+├─ setosa (50/50)
+└─ petal_width < 1.75
+   ├─ petal_length < 4.95
+   │  ├─ versicolor (47/48)
+   │  └─ virginica (4/6)
+   └─ petal_length < 4.85
+      ├─ virginica (2/3)
+      └─ virginica (43/43)
+
+using Plots, TreeRecipe
+plot(tree) # for a graphical representation of the tree
+
+feature_importances(mach)
 ```
 
-See also
-[DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and
-the unwrapped model type [`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
+See also [DecisionTree.jl](https://github.com/bensadeghi/DecisionTree.jl) and the
+unwrapped model type
+[`MLJDecisionTreeInterface.DecisionTree.DecisionTreeClassifier`](@ref).
 
 """
 DecisionTreeClassifier
@@ -699,7 +723,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -744,6 +768,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
@@ -800,7 +831,7 @@ where:
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `n_iter=10`: number of iterations of AdaBoost
 
@@ -834,6 +865,15 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
+# Examples
+
 ```
 using MLJ
 Booster = @load AdaBoostStumpClassifier pkg=DecisionTree
@@ -852,6 +892,7 @@ pdf.(yhat, "virginica") # probabilities for the "virginica" class
 
 fitted_params(mach).stumps # raw `Ensemble` object from DecisionTree.jl
 fitted_params(mach).coefs  # coefficient associated with each stump
+feature_importances(mach)
 ```
 
 See also
@@ -886,7 +927,7 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
 - `max_depth=-1`: max depth of the decision tree (-1=any)
 
@@ -903,7 +944,8 @@ Train the machine with `fit!(mach, rows=...)`.
 - `merge_purity_threshold=1.0`: (post-pruning) merge leaves having
   combined purity `>= merge_purity_threshold`
 
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -921,26 +963,50 @@ The fields of `fitted_params(mach)` are:
 - `tree`: the tree or stump object returned by the core
   DecisionTree.jl algorithm
 
+- `features`: the names of the features encountered in training
+
 
 # Report
 
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
 using MLJ
-Tree = @load DecisionTreeRegressor pkg=DecisionTree
-tree = Tree(max_depth=4, min_samples_split=3)
+DecisionTreeRegressor = @load DecisionTreeRegressor pkg=DecisionTree
+model = DecisionTreeRegressor(max_depth=3, min_samples_split=3)
 
-X, y = make_regression(100, 2) # synthetic data
-mach = machine(tree, X, y) |> fit!
+X, y = make_regression(100, 4; rng=123) # synthetic data
+mach = machine(model, X, y) |> fit!
 
-Xnew, _ = make_regression(3, 2)
+Xnew, _ = make_regression(3, 4; rng=123)
 yhat = predict(mach, Xnew) # new predictions
 
-fitted_params(mach).tree # raw tree or stump object from DecisionTree.jl
+julia> fitted_params(mach).tree
+x1 < 0.2758
+├─ x2 < 0.9137
+│  ├─ x1 < -0.9582
+│  │  ├─ 0.9189256882087312 (0/12)
+│  │  └─ -0.23180616021065256 (0/38)
+│  └─ -1.6461153800037722 (0/9)
+└─ x1 < 1.062
+   ├─ x2 < -0.4969
+   │  ├─ -0.9330755147107384 (0/5)
+   │  └─ -2.3287967825015548 (0/17)
+   └─ x2 < 0.4598
+      ├─ -2.931299926506291 (0/11)
+      └─ -4.726518740473489 (0/8)
+
+feature_importances(mach) # get feature importances
 ```
 
 See also
@@ -975,24 +1041,25 @@ where
 Train the machine with `fit!(mach, rows=...)`.
 
 
-# Hyper-parameters
+# Hyperparameters
 
-- `max_depth=-1`: max depth of the decision tree (-1=any)
+- `max_depth=-1`: max depth of the decision tree (-1=any)
 
-- `min_samples_leaf=1`: min number of samples each leaf needs to have
+- `min_samples_leaf=1`: min number of samples each leaf needs to have
 
-- `min_samples_split=2`: min number of samples needed for a split
+- `min_samples_split=2`: min number of samples needed for a split
 
 - `min_purity_increase=0`: min purity needed for a split
 
 - `n_subfeatures=-1`: number of features to select at random (0 for all,
   -1 for square root of number of features)
 
-- `n_trees=10`: number of trees to train
+- `n_trees=10`: number of trees to train
 
 - `sampling_fraction=0.7` fraction of samples to train each tree on
 
-- `feature_importance`: method to use for computing feature importances. One of `(:impurity, :split)`
+- `feature_importance`: method to use for computing feature importances. One of
+  `(:impurity, :split)`
 
 - `rng=Random.GLOBAL_RNG`: random number generator or seed
 
@@ -1015,6 +1082,13 @@ The fields of `fitted_params(mach)` are:
 - `features`: the names of the features encountered in training
 
 
+# Accessor functions
+
+- `feature_importances(mach)` returns a vector of `(feature::Symbol => importance)` pairs;
+  the type of importance is determined by the hyperparameter `feature_importance` (see
+  above)
+
+
 # Examples
 
 ```
@@ -1029,6 +1103,7 @@ Xnew, _ = make_regression(3, 2)
 yhat = predict(mach, Xnew) # new predictions
 
 fitted_params(mach).forest # raw `Ensemble` object from DecisionTree.jl
+feature_importances(mach)
 ```
 
 See also