[](https://github.com/cmccomb/automl/actions)
[](https://crates.io/crates/automl)
[](https://docs.rs/automl)

# `automl` with `smartcore`

`AutoML` (_Automated Machine Learning_) streamlines machine learning workflows, making them more accessible and efficient for users of all experience levels. This crate extends the [`smartcore`](https://docs.rs/smartcore/) machine learning framework, providing utilities to quickly train, compare, and deploy regression, classification, and clustering models.

## Quickstart
Install the stable release from [crates.io](https://crates.io/crates/automl), or use the GitHub repository for the latest changes:

```toml
# Cargo.toml
[dependencies]
automl = "0.2.9"
```

```toml
# Cargo.toml
[dependencies]
automl = { git = "https://github.com/cmccomb/rust-automl" }
```

```rust
use automl::{RegressionModel, RegressionSettings};
use smartcore::linalg::basic::matrix::DenseMatrix;

// Three observations with three features each.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![1.0_f64, 2.0, 3.0],
    vec![2.0, 3.0, 4.0],
    vec![3.0, 4.0, 5.0],
]).unwrap();
let y = vec![1.0_f64, 2.0, 3.0];
// Compare the default regression algorithms on this dataset.
let _model = RegressionModel::new(x, y, RegressionSettings::default());
```
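
From there, training and inference mirror the clustering workflow shown later in this README. A minimal sketch, assuming `RegressionModel` exposes the same `train`/`predict` pair as `ClusteringModel`:

```rust
use automl::{RegressionModel, RegressionSettings};
use smartcore::linalg::basic::matrix::DenseMatrix;

let x = DenseMatrix::from_2d_vec(&vec![
    vec![1.0_f64, 2.0, 3.0],
    vec![2.0, 3.0, 4.0],
    vec![3.0, 4.0, 5.0],
]).unwrap();
let y = vec![1.0_f64, 2.0, 3.0];
let mut model = RegressionModel::new(x.clone(), y, RegressionSettings::default());
// Cross-validate the candidate algorithms, then predict with the best one.
model.train();
let predictions = model.predict(&x).unwrap();
```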

Support Vector Regression can be enabled alongside the default algorithms and tuned with a kernel-specific configuration:

```rust
use automl::settings::{Kernel, SVRParameters};
use automl::RegressionSettings;
use smartcore::linalg::basic::matrix::DenseMatrix;

// Epsilon-SVR with an RBF kernel (gamma = 0.4).
let settings: RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    RegressionSettings::default().with_svr_settings(
        SVRParameters::default()
            .with_eps(0.2)
            .with_tol(1e-4)
            .with_c(2.0)
            .with_kernel(Kernel::RBF(0.4)),
    );
```

Gradient boosting via Smartcore's `XGBoost` implementation is also available, giving access to learning-rate, depth, and subsampling knobs:

```rust
use automl::settings::XGRegressorParameters;
use automl::{DenseMatrix, RegressionSettings};

// 75 boosting rounds with a modest learning rate and shallow trees.
let settings: RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    RegressionSettings::default().with_xgboost_settings(
        XGRegressorParameters::default()
            .with_n_estimators(75)
            .with_learning_rate(0.15)
            .with_max_depth(4)
            .with_subsample(0.9),
    );
```

Extremely randomized trees offer another ensemble option that leans into randomness to produce lower-variance models:

```rust
use automl::settings::ExtraTreesRegressorParameters;
use automl::{DenseMatrix, RegressionSettings};

// 50 fully grown trees with a fixed seed for reproducibility.
let settings: RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    RegressionSettings::default().with_extra_trees_settings(
        ExtraTreesRegressorParameters::default()
            .with_n_trees(50)
            .with_min_samples_leaf(2)
            .with_keep_samples(true)
            .with_seed(7),
    );
```

Unlike the random forest regressor, the Extra Trees variant grows each tree on the full training set and samples split thresholds uniformly at random rather than optimizing them. The parameter `with_keep_samples(true)` is particularly useful here: because there is no bootstrapping, enabling it stores the original observations so that out-of-bag-style diagnostics remain possible. You can also adjust `with_m(...)` to change how many random features are considered at each split, which directly controls how much randomness the split selection introduces compared with the random forest estimator.
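
For example, a minimal sketch of tuning `with_m` (assuming it takes the per-split feature count, as its name suggests; the value `2` is illustrative, not a recommendation):

```rust
use automl::settings::ExtraTreesRegressorParameters;
use automl::{DenseMatrix, RegressionSettings};

// Consider only 2 randomly chosen features at each split;
// smaller values inject more randomness into split selection.
let settings: RegressionSettings<f64, f64, DenseMatrix<f64>, Vec<f64>> =
    RegressionSettings::default().with_extra_trees_settings(
        ExtraTreesRegressorParameters::default()
            .with_n_trees(50)
            .with_m(2),
    );
```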

### Loading data from CSV

Use `load_labeled_csv` to read a dataset and separate the target column:

```rust
use automl::{RegressionModel, RegressionSettings};
use automl::utils::load_labeled_csv;

// Column index 2 holds the target; the remaining columns become features.
let (x, y) = load_labeled_csv("tests/fixtures/supervised_sample.csv", 2).unwrap();
let mut model = RegressionModel::new(x, y, RegressionSettings::default());
```

Use `load_csv_features` to read unlabeled data for clustering:

```rust
use automl::ClusteringModel;
use automl::settings::ClusteringSettings;
use automl::utils::load_csv_features;

// Load feature rows only; clustering needs no labels.
let x = load_csv_features("tests/fixtures/clustering_points.csv").unwrap();
let mut model = ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
model.train();
let clusters: Vec<u8> = model.predict(&x).unwrap();
```

## Examples
### Classification
```rust
use automl::ClassificationModel;
use automl::settings::{ClassificationSettings, RandomForestClassifierParameters};
use smartcore::linalg::basic::matrix::DenseMatrix;

// A tiny dataset where the label matches the first feature.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![0.0_f64, 0.0],
    vec![1.0, 1.0],
    vec![1.0, 0.0],
    vec![0.0, 1.0],
]).unwrap();
let y = vec![0_u32, 1, 1, 0];
let settings = ClassificationSettings::default()
    .with_random_forest_classifier_settings(
        RandomForestClassifierParameters::default().with_n_trees(10),
    );
let _model = ClassificationModel::new(x, y, settings);
```

Multinomial Naive Bayes is available for datasets where every feature represents a non-negative integer count. You can opt into it alongside the other classifiers when your data meets that requirement:

```rust
use automl::settings::{ClassificationSettings, MultinomialNBParameters};

let settings = ClassificationSettings::default()
    .with_multinomial_nb_settings(MultinomialNBParameters::default());
```

If the feature matrix includes fractional or negative values, the Multinomial NB variant emits a descriptive error explaining the constraint.
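
For illustration, a minimal sketch of count-style features that satisfy this constraint (the data values here are made up):

```rust
use automl::ClassificationModel;
use automl::settings::{ClassificationSettings, MultinomialNBParameters};
use smartcore::linalg::basic::matrix::DenseMatrix;

// Word-count style features: every entry is a non-negative integer.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![2.0_f64, 0.0, 1.0],
    vec![0.0, 3.0, 1.0],
    vec![1.0, 1.0, 4.0],
    vec![3.0, 0.0, 0.0],
]).unwrap();
let y = vec![0_u32, 1, 1, 0];
// A value such as 0.5 or -1.0 anywhere in `x` would trigger the descriptive error instead.
let settings = ClassificationSettings::default()
    .with_multinomial_nb_settings(MultinomialNBParameters::default());
let _model = ClassificationModel::new(x, y, settings);
```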

Bernoulli Naive Bayes supports binary features and can also binarize continuous inputs when you provide a threshold. Set `binarize` to `None` to require pre-binarized inputs, or configure the threshold to map values above it to `1` and the rest to `0` during training and prediction:

```rust
use automl::settings::{BernoulliNBParameters, ClassificationSettings};

let mut params = BernoulliNBParameters::default();
params.binarize = None; // ensure features are already 0/1 encoded
let settings = ClassificationSettings::default().with_bernoulli_nb_settings(params);

// Alternatively, binarize values greater than 0.5.
let thresholded = ClassificationSettings::default().with_bernoulli_nb_settings(
    BernoulliNBParameters::default().with_binarize(0.5),
);
```

### Clustering
```rust
use automl::ClusteringModel;
use automl::settings::ClusteringSettings;
use smartcore::linalg::basic::matrix::DenseMatrix;

// Two well-separated groups of points.
let x = DenseMatrix::from_2d_vec(&vec![
    vec![1.0_f64, 1.0],
    vec![1.2, 0.8],
    vec![8.0, 8.0],
    vec![8.2, 8.2],
]).unwrap();
let mut model = ClusteringModel::new(x.clone(), ClusteringSettings::default().with_k(2));
model.train();
// Score the fitted algorithms against known assignments and print a summary.
let truth = vec![1_u8, 1, 2, 2];
model.evaluate(&truth);
println!("{model}");

// Each trained algorithm can be queried individually.
for algorithm in model.trained_algorithm_names() {
    let clusters: Vec<u8> = model.predict_with(algorithm, &x).expect("prediction");
    println!("{algorithm}: {clusters:?}");
}
```

Additional runnable examples are available in the [examples/ directory](https://github.com/cmccomb/rust-automl/tree/main/examples),
including [minimal_classification.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/minimal_classification.rs),
[maximal_classification.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/maximal_classification.rs),
[minimal_regression.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/minimal_regression.rs),
[maximal_regression.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/maximal_regression.rs),
[minimal_clustering.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/minimal_clustering.rs), and
[maximal_clustering.rs](https://github.com/cmccomb/rust-automl/blob/main/examples/maximal_clustering.rs).

Training performs a cross-validated comparison of the candidate models; printing the trained model yields a table like:

```text
┌────────────────────────────────┬─────────────────────┬───────────────────┬──────────────────┐
│ Model                          │ Time                │ Training Accuracy │ Testing Accuracy │
╞════════════════════════════════╪═════════════════════╪═══════════════════╪══════════════════╡
│ Random Forest Classifier       │ 835ms 393us 583ns   │ 1.00              │ 0.96             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ Logistic Regression Classifier │ 620ms 714us 583ns   │ 0.97              │ 0.95             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ Gaussian Naive Bayes           │ 6ms 529us           │ 0.94              │ 0.93             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ Categorical Naive Bayes        │ 2ms 922us 250ns     │ 0.96              │ 0.93             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ Decision Tree Classifier       │ 15ms 404us 750ns    │ 1.00              │ 0.93             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ KNN Classifier                 │ 28ms 874us 208ns    │ 0.96              │ 0.92             │
├────────────────────────────────┼─────────────────────┼───────────────────┼──────────────────┤
│ Support Vector Classifier      │ 4s 187ms 61us 708ns │ 0.57              │ 0.57             │
└────────────────────────────────┴─────────────────────┴───────────────────┴──────────────────┘
```

You can then perform inference using the best model with the `predict` method.

## Capabilities
- Feature Engineering: PCA, SVD, interaction terms, polynomial terms
- Regression: Decision Tree, KNN, Random Forest, Extra Trees, Linear, Ridge, LASSO, Elastic Net, Support Vector Regression, `XGBoost` gradient boosting
- Classification: Random Forest, Decision Tree, KNN, Logistic Regression, Support Vector Classifier, Gaussian Naive Bayes, Categorical Naive Bayes, Bernoulli Naive Bayes (binary features or configurable thresholding), Multinomial Naive Bayes (non-negative integer features)
- Clustering: K-Means, Agglomerative, DBSCAN
- Meta-learning: Blending (experimental)
- Persistence: Save/load settings and models

## Development
Before submitting changes, run:

```sh
cargo fmt --all -- --check
cargo clippy --all-targets -- -D warnings
cargo test
cargo test --doc
cargo audit
```

Security audits run weekly via a scheduled workflow, but running `cargo audit` locally before submitting changes helps catch issues earlier.

Pull requests are welcome!

## License
Licensed under the MIT OR Apache-2.0 license.