Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,40 @@ will perform a comparison of classifier models using cross-validation. Printing

You can then perform inference using the best model with the `predict` method.

## Preprocessing pipelines

`automl` now supports composable preprocessing pipelines so you can build
feature engineering recipes similar to `AutoGluon` or `caret`. Pipelines are
defined with the [`PreprocessingStep`](https://docs.rs/automl/latest/automl/settings/enum.PreprocessingStep.html)
enum and attached via either the `add_step` builder or by passing a full
[`PreprocessingPipeline`](https://docs.rs/automl/latest/automl/settings/struct.PreprocessingPipeline.html).

```rust
use automl::settings::{
ClassificationSettings, PreprocessingPipeline, PreprocessingStep, RegressionSettings,
StandardizeParams,
};
use automl::DenseMatrix;

let regression = RegressionSettings::<f64, f64, DenseMatrix<f64>, Vec<f64>>::default()
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
.add_step(PreprocessingStep::ReplaceWithPCA {
number_of_components: 5,
});

let classification = ClassificationSettings::default().with_preprocessing(
PreprocessingPipeline::new()
.add_step(PreprocessingStep::AddInteractions)
.add_step(PreprocessingStep::ReplaceWithSVD {
number_of_components: 4,
}),
);
```

Pipelines preserve the order of steps. Stateful steps such as PCA, SVD, or
standardization automatically fit during training and reuse the same fitted
state when you call `predict`.

## Features

This crate has several features that add some additional methods.
Expand Down
6 changes: 3 additions & 3 deletions examples/maximal_regression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ use automl::{
settings::{
DecisionTreeRegressorParameters, Distance, ElasticNetParameters, FinalAlgorithm,
KNNAlgorithmName, KNNParameters, KNNWeightFunction, Kernel, LassoParameters,
LinearRegressionParameters, LinearRegressionSolverName, Metric,
LinearRegressionParameters, LinearRegressionSolverName, Metric, PreprocessingStep,
RandomForestRegressorParameters, RidgeRegressionParameters, RidgeRegressionSolverName,
SVRParameters, XGRegressorParameters,
SVRParameters, StandardizeParams, XGRegressorParameters,
},
};
use regression_data::regression_testing_data;
Expand All @@ -41,7 +41,7 @@ fn main() -> Result<(), Failed> {
.with_final_model(FinalAlgorithm::Best)
.skip(RegressionAlgorithm::default_random_forest())
.sorted_by(Metric::RSquared)
// .with_preprocessing(PreProcessing::AddInteractions)
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
.with_linear_settings(
LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR),
)
Expand Down
16 changes: 11 additions & 5 deletions examples/print_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ use automl::settings::{
DecisionTreeRegressorParameters, Distance, ElasticNetParameters, ExtraTreesRegressorParameters,
FinalAlgorithm, GaussianNBParameters, KNNAlgorithmName, KNNParameters, KNNWeightFunction,
Kernel, LassoParameters, LinearRegressionParameters, LinearRegressionSolverName,
LogisticRegressionParameters, Metric, MultinomialNBParameters, Objective, PreProcessing,
RandomForestClassifierParameters, RandomForestRegressorParameters, RegressionSettings,
RidgeRegressionParameters, RidgeRegressionSolverName, SVCParameters, SVRParameters,
LogisticRegressionParameters, Metric, MultinomialNBParameters, Objective,
PreprocessingPipeline, PreprocessingStep, RandomForestClassifierParameters,
RandomForestRegressorParameters, RegressionSettings, RidgeRegressionParameters,
RidgeRegressionSolverName, SVCParameters, SVRParameters, StandardizeParams,
XGRegressorParameters,
};
use serde_json::to_string_pretty;
Expand All @@ -20,7 +21,8 @@ fn build_regression_settings() -> RegressionConfig {
.shuffle_data(true)
.verbose(true)
.sorted_by(Metric::RSquared)
.with_preprocessing(PreProcessing::AddInteractions)
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
.add_step(PreprocessingStep::AddInteractions)
.with_linear_settings(
LinearRegressionParameters::default().with_solver(LinearRegressionSolverName::QR),
)
Expand Down Expand Up @@ -99,12 +101,16 @@ fn build_regression_settings() -> RegressionConfig {
}

fn build_classification_settings() -> ClassificationSettings {
let pipeline = PreprocessingPipeline::new()
.add_step(PreprocessingStep::Standardize(StandardizeParams::default()))
.add_step(PreprocessingStep::AddInteractions);

ClassificationSettings::default()
.with_number_of_folds(6)
.shuffle_data(true)
.verbose(true)
.sorted_by(Metric::Accuracy)
.with_preprocessing(PreProcessing::AddInteractions)
.with_preprocessing(pipeline)
.with_final_model(FinalAlgorithm::Best)
.with_knn_classifier_settings(
KNNParameters::default()
Expand Down
Loading