SVC multiclass #306

Closed

Conversation

DanielLacina
Contributor

Implemented the multiclass feature for SVC using a one for all approach.

@DanielLacina DanielLacina requested a review from Mec-iS as a code owner June 3, 2025 21:19
@DanielLacina
Contributor Author

Had to close the previous pr because I sent it on the wrong branch.

@Mec-iS
Collaborator

Mec-iS commented Jun 4, 2025

#305 (comment)

@DanielLacina
Contributor Author

DanielLacina commented Jun 4, 2025

The issue with Array2 is that the element type of the array is inferred and fixed at compile time. If a Vec<u32> is passed in, the application cannot transform the labels to {1, -1} so that they can be used for binary classification. Originally, you had an assert statement validating that the data passed in had the labels -1 and 1, but now that the SVC is multiclass, it can accept a wide variety of labels.
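
The relabeling step described here can be sketched without any smartcore types. This is a minimal sketch, assuming labels of any comparable type; the helper names `distinct_classes` and `to_binary_labels` are hypothetical and are not part of smartcore or of this PR:

```rust
/// Returns the distinct labels in ascending order.
/// Hypothetical helper, not a smartcore function.
pub fn distinct_classes<T: Ord + Copy>(y: &[T]) -> Vec<T> {
    let mut classes = y.to_vec();
    classes.sort_unstable();
    classes.dedup();
    classes
}

/// Maps arbitrary labels to the {1, -1} encoding used for binary classification:
/// `positive` becomes 1, every other label becomes -1.
pub fn to_binary_labels<T: PartialEq>(y: &[T], positive: &T) -> Vec<i32> {
    y.iter()
        .map(|label| if label == positive { 1 } else { -1 })
        .collect()
}
```

Because the output type is fixed (`i32`) while the input type stays generic, the caller's label type never has to be able to represent -1.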

@Mec-iS
Collaborator

Mec-iS commented Jun 9, 2025

> The issue with Array2 is that the type within the Array is inferred and static. If a Vec<u32> is passed in, then the application won't be able to transform the labels to: {1, -1} in order for them to be used for binary classification.

Array2 is an abstraction for a 2D vector; it can be used for any type supported by Vec, it just needs to be implemented.

Test Driven Development (TDD) should be followed. Every time you implement something new you have to be sure to support existing behaviour. If there is no test for the operations you are changing, you should add it. Also please add one or more tests when you implement something new.

When I try to run the tests in the module I get:

running 3 tests
test svm::svc::tests::svc_fit_predict ... ok
test svm::svc::tests::svc_fit_predict_rbf ... ok
test svm::svc::tests::svc_fit_decision_function ... FAILED

successes:

successes:
    svm::svc::tests::svc_fit_predict
    svm::svc::tests::svc_fit_predict_rbf

failures:

---- svm::svc::tests::svc_fit_decision_function stdout ----
thread 'svm::svc::tests::svc_fit_decision_function' panicked at src/svm/svc.rs:1033:9:
assertion failed: y_hat[1] < y_hat[2]

This is a reference implementation generated by my LLM. You can start from it, but it is not fully implemented (i.e. it is an example and needs to be implemented using generic types like TX and TY). Please note that you should not modify the existing struct but instead create new structs to handle the multiclass case. This implementation suggests using a 1D Vec, but you should check whether that is correct (if it is, there is no need to use a 2D Vec):


To implement a multiclass Support Vector Classification (SVC) in Rust using smartcore, we can adopt the one-vs-one (OvO) strategy, which trains binary classifiers for each pair of classes. Here's a complete implementation:

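use std::collections::HashMap;

use itertools::Itertools; // provides `unique()` and `sorted()` on iterators
use smartcore::dataset::iris::load_dataset;
use smartcore::error::Failed;
use smartcore::linalg::basic::matrix::DenseMatrix;
use smartcore::metrics::accuracy;
use smartcore::model_selection::train_test_split;
use smartcore::svm::svc::{SVC, SVCParameters};

// Multiclass SVC using One-vs-One strategy
struct MulticlassSVC {
    // One binary classifier per class pair, stored with the two labels it separates
    classifiers: Vec<(u32, u32, SVC<f64, DenseMatrix<f64>, Vec<f64>>)>,
    classes: Vec<u32>,
}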

impl MulticlassSVC {
    pub fn fit(
        x: &DenseMatrix<f64>,
        y: &Vec<u32>,
        params: &SVCParameters<f64>,
    ) -> Result<Self, Failed> {
        // Copy the labels out of the iterator so `classes` is a Vec<u32>, not a Vec<&u32>
        let classes: Vec<u32> = y.iter().copied().unique().sorted().collect();
        let mut classifiers = Vec::new();

        // Generate all class pairs
        for (i, &class1) in classes.iter().enumerate() {
            for &class2 in classes.iter().skip(i + 1) {
                // Filter samples for current class pair
                let indices: Vec<usize> = y.iter()
                    .enumerate()
                    .filter(|(_, &c)| c == class1 || c == class2)
                    .map(|(i, _)| i)
                    .collect();

                let x_filtered = x.select_rows(&indices);
                let y_filtered: Vec<f64> = indices.iter()
                    .map(|&i| if y[i] == class1 { 1.0 } else { -1.0 })
                    .collect();

                // Train binary classifier
                let mut clf = SVC::fit(&x_filtered, &y_filtered, params.clone())?;
                classifiers.push((class1, class2, clf));
            }
        }

        Ok(Self { classifiers, classes })
    }

    pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<u32> {
        let mut votes = vec![HashMap::new(); x.shape().0];
        
        for (class1, class2, clf) in &self.classifiers {
            let preds = clf.predict(x).unwrap();
            
            for (i, &p) in preds.iter().enumerate() {
                let vote = if p > 0.0 { *class1 } else { *class2 };
                *votes[i].entry(vote).or_insert(0) += 1;
            }
        }

        votes.iter()
            .map(|v| *v.iter().max_by_key(|(_, &count)| count).unwrap().0)
            .collect()
    }
}

// Example usage with Iris dataset
fn main() -> Result<(), Failed> {
    let iris = load_dataset();
    let (x_train, x_test, y_train, y_test) = train_test_split(
        &iris.data,
        &iris.target,
        0.8,
        true,
        Some(42),
    );

    let params = SVCParameters::default()
        .with_c(200.0)
        .with_kernel(smartcore::svm::Kernels::linear());

    let clf = MulticlassSVC::fit(&x_train, &y_train, &params)?;
    let preds = clf.predict(&x_test);
    
    println!("Accuracy: {}", accuracy(&y_test, &preds));
    Ok(())
}

Key implementation details:

  1. OvO Strategy:
    • Creates n_classes * (n_classes - 1) / 2 binary classifiers[1][2]
    • Uses filtered subsets of data for each class pair[3]
  2. smartcore Integration:
    • Uses SVC with configurable parameters (C, kernel)[4]
    • Handles DenseMatrix input as per smartcore's data requirements[^1]
  3. Prediction Aggregation:
    • Implements voting system across all binary classifiers[2][3]
    • Does not break ties deterministically: max_by_key returns the last maximum in the HashMap's unspecified iteration order

Advantages over naive implementation:

  • Maintains smartcore's API conventions[^1]
  • Avoids unsafe code and complex lifetimes
  • Compatible with smartcore's linear algebra traits
  • Supports all kernel types available in smartcore[4]

This implementation follows smartcore's design principles by:

  • Using pure Rust without external dependencies[^1]
  • Maintaining a Pythonic/scikit-learn-like API[^1]
  • Supporting TDD through clear input/output contracts
  • Working efficiently with smartcore's matrix types[^1]

For production use, you'd want to add:

  • Model persistence (serialization/deserialization)
  • Class weighting support
  • Parallel training of binary classifiers (using rayon)
  • More sophisticated tie-breaking strategies

Footnotes

  1. https://scikit-learn.org/stable/modules/svm.html

  2. https://www.nb-data.com/p/one-vs-all-vs-one-vs-one-which-multi

  3. https://www.baeldung.com/cs/svm-multiclass-classification

  4. http://smartcorelib.org/user_guide/supervised.html
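
The pair enumeration and vote aggregation described in the list above can be sketched in isolation from any classifier. This is a minimal sketch with hypothetical helper names (`class_pairs`, `majority_vote`); the tie-break rule (smallest label wins) is an assumption of this sketch, not behaviour of the reference code or of smartcore:

```rust
use std::collections::BTreeMap;

/// Enumerates every unordered class pair.
/// For n classes this yields n * (n - 1) / 2 pairs.
pub fn class_pairs(classes: &[u32]) -> Vec<(u32, u32)> {
    let mut pairs = Vec::new();
    for (i, &a) in classes.iter().enumerate() {
        for &b in &classes[i + 1..] {
            pairs.push((a, b));
        }
    }
    pairs
}

/// Returns the label with the most votes, or `None` if there are no votes.
/// Ties go to the smallest label.
pub fn majority_vote(votes: &[u32]) -> Option<u32> {
    let mut counts: BTreeMap<u32, usize> = BTreeMap::new();
    for &v in votes {
        *counts.entry(v).or_insert(0) += 1;
    }
    let mut best: Option<(u32, usize)> = None;
    // BTreeMap iterates in ascending key order, so the strict `>` below
    // keeps the smallest label whenever two labels have the same count.
    for (&label, &count) in &counts {
        if best.map_or(true, |(_, c)| count > c) {
            best = Some((label, count));
        }
    }
    best.map(|(label, _)| label)
}
```

Using an ordered map makes the result independent of hash iteration order, so the same votes always produce the same prediction.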

@DanielLacina
Contributor Author

Appreciate the guide.

@DanielLacina
Contributor Author

It doesn't look like the DenseMatrix has the select_rows method.

@DanielLacina
Contributor Author

I assume you want the multiclass struct to take in any generic Array2 object as x.

@Mec-iS
Collaborator

Mec-iS commented Jun 9, 2025

> It doesn't look like the DenseMatrix has the select_rows method.

It has the get_row method, see src/linalg/basic/matrix.rs. Please read the API for DenseMatrix.

> I assume you want the multiclass struct to take in any generic Array2 object as x.

It depends on what you are going to do with the y parameter; in this case using &Vec is OK, I guess. Just try and see.
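
For reference, the row filtering that the example expected from `select_rows` can be written by hand. This is a minimal sketch that uses `Vec<Vec<f64>>` as a stand-in for a real matrix type; both function names are hypothetical, and how the same idea maps onto `DenseMatrix` and `get_row` is not shown here:

```rust
/// Copies the rows at `indices` into a new row-major matrix.
/// `Vec<Vec<f64>>` is a stand-in for a real matrix type.
pub fn select_rows(x: &[Vec<f64>], indices: &[usize]) -> Vec<Vec<f64>> {
    indices.iter().map(|&i| x[i].clone()).collect()
}

/// Keeps only the samples labelled `a` or `b`.
/// Returns the filtered rows and their binary labels (`a` -> 1, `b` -> -1).
pub fn pair_subset(
    x: &[Vec<f64>],
    y: &[u32],
    a: u32,
    b: u32,
) -> (Vec<Vec<f64>>, Vec<i32>) {
    let indices: Vec<usize> = y
        .iter()
        .enumerate()
        .filter(|(_, label)| **label == a || **label == b)
        .map(|(i, _)| i)
        .collect();
    let labels = indices
        .iter()
        .map(|&i| if y[i] == a { 1 } else { -1 })
        .collect();
    (select_rows(x, &indices), labels)
}
```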

@DanielLacina
Contributor Author

Alright. That helps a ton.

@DanielLacina
Contributor Author

DanielLacina commented Jun 9, 2025

I have to clone the matrix in order to create a version of the matrix that filters out the rows that don't have one of the two labels. The issue is that it creates a local copy of the matrix, but the SVC struct only accepts references. This leads to the lifetime of the matrix being shorter than the SVC which is not allowed.
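
One common way around this kind of lifetime problem is to build all the filtered copies first, store them in a value that outlives the models, and only then fit the models against borrowed data. This is a minimal sketch of that pattern, not necessarily what this PR ended up doing; `BorrowingModel` and `SubsetStore` are hypothetical stand-ins and `Vec<Vec<f64>>` replaces the real matrix type:

```rust
/// Hypothetical stand-in for an estimator that only borrows its training data.
pub struct BorrowingModel<'a> {
    rows: &'a [Vec<f64>],
}

impl<'a> BorrowingModel<'a> {
    pub fn fit(rows: &'a [Vec<f64>]) -> Self {
        Self { rows }
    }

    pub fn n_samples(&self) -> usize {
        self.rows.len()
    }
}

/// Owns one filtered copy of the data per class pair, so models that borrow
/// from it can live as long as the store does.
pub struct SubsetStore {
    subsets: Vec<Vec<Vec<f64>>>,
}

impl SubsetStore {
    /// Copies, for each pair, the rows whose label is one of the two classes.
    pub fn build(x: &[Vec<f64>], y: &[u32], pairs: &[(u32, u32)]) -> Self {
        let subsets: Vec<Vec<Vec<f64>>> = pairs
            .iter()
            .map(|&(a, b)| {
                x.iter()
                    .zip(y)
                    .filter(|(_, label)| **label == a || **label == b)
                    .map(|(row, _)| row.clone())
                    .collect::<Vec<_>>()
            })
            .collect();
        Self { subsets }
    }

    /// Fits one model per stored subset; each model borrows from `self`.
    pub fn fit_all(&self) -> Vec<BorrowingModel<'_>> {
        self.subsets
            .iter()
            .map(|s| BorrowingModel::fit(s))
            .collect()
    }
}
```

The returned models are tied to the lifetime of the store, so the compiler rejects any attempt to drop the copies while a model still refers to them.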

@DanielLacina
Contributor Author

I think I found a solution.

@DanielLacina
Contributor Author

The code doesn't work, btw. I'm facing an issue where the shapes of the vector-type objects don't match. I did learn a lot by playing around with the code, so I guess I can say I'm one step ahead.

@Mec-iS
Collaborator

Mec-iS commented Jun 14, 2025

good progress 👍 keep going

@DanielLacina
Contributor Author

I noticed there are a lot of places in the code where you loop through each row of X and convert it to some type. I think that could be the reason why the SVC is so slow.

@DanielLacina
Contributor Author

I know it's not relevant to this PR but it's really important.

@DanielLacina
Contributor Author

Where in the repo would I store my dataset-related tests?

@DanielLacina
Contributor Author

I wrote one unit test that tests the whole process.

@DanielLacina
Contributor Author

Do you want the benchmark to take in the Iris dataset as input? Sorry, I haven't benchmarked before.

@Mec-iS
Collaborator

Mec-iS commented Jun 14, 2025

> Where in the repo would I store my dataset-related tests?

tests go at the bottom of the module in the mod tests {...} submodule, as suggested by Rust best practices and implemented in the rest of the library

@Mec-iS
Collaborator

Mec-iS commented Jun 14, 2025

> Do you want the benchmark to take in the Iris dataset as input? Sorry, I haven't benchmarked before.

you can follow the current practice in the smartcore-benches repository where the other benches are. Please open a PR there for details.

@Mec-iS
Collaborator

Mec-iS commented Jun 14, 2025

just ask the LLM to write tests for this module and see what it returns; it is usually good with basic usages.

@DanielLacina
Contributor Author

I sent a pr for the smartcore-benches repo.

@DanielLacina
Contributor Author

> > Where in the repo would I store my dataset-related tests?
>
> tests go at the bottom of the module in the mod tests {...} submodule, as suggested by Rust best practices and implemented in the rest of the library

Aren't those for unit tests? What about integration tests? I assume any test involving a dataset would be an integration test.

@DanielLacina
Contributor Author

PR is ready for review.

@Mec-iS
Collaborator

Mec-iS commented Jun 14, 2025

> Aren't those for unit tests? What about integration tests? I assume any test involving a dataset would be an integration test.

i don't know if you can technically call those you mentioned integration tests. btw the ones involving datasets need to be flagged with the datasets feature, see other modules and cargo.toml
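
The layout being described (tests in a `mod tests` submodule at the bottom of the module, with dataset-dependent tests behind a feature flag) might look roughly like this. It is a sketch under assumptions: `n_classes` is a hypothetical function under test, the feature name `datasets` is taken from the comment above, and the exact attributes smartcore uses should be checked against the other modules:

```rust
/// Hypothetical function under test: counts the distinct labels.
pub fn n_classes(labels: &[u32]) -> usize {
    let mut seen = labels.to_vec();
    seen.sort_unstable();
    seen.dedup();
    seen.len()
}

#[cfg(test)]
mod tests {
    use super::*;

    // Always compiled under `cargo test`.
    #[test]
    fn counts_distinct_labels() {
        assert_eq!(n_classes(&[0, 1, 1, 2]), 3);
    }

    // Compiled only when the crate's `datasets` feature is enabled,
    // e.g. `cargo test --features datasets`.
    #[cfg(feature = "datasets")]
    #[test]
    fn counts_classes_in_bundled_dataset() {
        // Stand-in labels; a real test would load a bundled dataset here.
        let labels = [0, 0, 1, 1, 2, 2];
        assert_eq!(n_classes(&labels), 3);
    }
}
```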

@DanielLacina
Contributor Author

I created a unit test so everything is good

@Mec-iS
Collaborator

Mec-iS commented Jun 15, 2025

Thanks for your contribution. The code is returning the correct results as compared to scikit-learn. Nice one 👍

As specified in CONTRIBUTING.md, please run `rustfmt src/*.rs` and `cargo clippy --all-features -- -Drust-2018-idioms -Dwarnings`; otherwise the CI/CD tests will fail.

Also please double check that you have checked and possibly reused the traits defined in base library:

#### linalg/traits
The traits in `src/linalg/traits` are closely linked to Linear Algebra's theoretical framework. These traits are used to specify characteristics and constraints for types accepted by various algorithms. For example these allow to define if a matrix is `QRDecomposable` and/or `SVDDecomposable`. See docstring for references to theoretical framework.

As above these are all traits and by definition they do not allow instantiation. They are mostly used to provide constraints for implementations. For example, the implementation for Linear Regression requires the input data `X` to be in `smartcore`'s trait system `Array2<FloatNumber> + QRDecomposable<TX> + SVDDecomposable<TX>`, a 2-D matrix that is both QR and SVD decomposable; that is what the provided structure `linalg::arrays::matrix::DenseMatrix` happens to be: `impl<T: FloatNumber> QRDecomposable<T> for DenseMatrix<T> {};impl<T: FloatNumber> SVDDecomposable<T> for DenseMatrix<T> {}`.
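
The "traits as constraints" idea in the excerpt above can be illustrated with a tiny self-contained sketch. Everything in it is hypothetical (`Array2Like`, `QrCapable`, `Dense`, `check_fit_input`); these names only mirror the shape of smartcore's `Array2` and `QRDecomposable` bounds and are not the library's actual definitions:

```rust
/// Hypothetical stand-in for a 2-D array abstraction.
pub trait Array2Like {
    fn shape(&self) -> (usize, usize);
}

/// Marker trait: implementing it declares that the type supports QR decomposition.
pub trait QrCapable: Array2Like {}

/// Hypothetical concrete matrix type.
pub struct Dense {
    pub rows: usize,
    pub cols: usize,
}

impl Array2Like for Dense {
    fn shape(&self) -> (usize, usize) {
        (self.rows, self.cols)
    }
}

impl QrCapable for Dense {}

/// Accepts only matrix types that satisfy both constraints, which is how an
/// algorithm can state its requirements in its signature.
pub fn check_fit_input<M: Array2Like + QrCapable>(x: &M) -> Result<(usize, usize), String> {
    let (n, m) = x.shape();
    if n == 0 || m == 0 {
        Err("input matrix must not be empty".to_string())
    } else {
        Ok((n, m))
    }
}
```

A type that implements `Array2Like` but not `QrCapable` is rejected at compile time, before any data is touched.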

src/svm/svc.rs Outdated
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
///
/// # Panics
/// Panics if the model has not been fitted (`self.classifiers` is `None`).
Collaborator

this should not panic but return an error (see the error harness in smartcore/src/error/mod.rs). Panics should be limited to unknowns so that we know that something non-expected happened. Everything that is expected should return an error.
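
The "return an error instead of panicking" rule can be sketched as follows. This is a minimal, self-contained illustration: `PredictError` and `Model` are hypothetical, and a real change in smartcore would presumably use its own `Failed` error type (which appears elsewhere in this PR as `Failed::fit(...)`) rather than a local enum:

```rust
/// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
pub enum PredictError {
    /// `predict` was called before `fit`.
    NotFitted,
}

/// Hypothetical model whose fitted state is optional.
pub struct Model {
    /// `None` until the model has been fitted.
    pub thresholds: Option<Vec<f64>>,
}

impl Model {
    /// Returns, for each input, how many thresholds it reaches.
    /// An unfitted model is an expected situation, so it is reported
    /// as an error value rather than a panic.
    pub fn predict(&self, x: &[f64]) -> Result<Vec<usize>, PredictError> {
        let thresholds = self
            .thresholds
            .as_ref()
            .ok_or(PredictError::NotFitted)?;
        Ok(x
            .iter()
            .map(|v| thresholds.iter().filter(|t| v >= *t).count())
            .collect())
    }
}
```

With this shape the `# Panics` section of the doc comment can be replaced by an `# Errors` section describing the unfitted case.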

@@ -226,7 +425,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
         y: &'a Y,
         parameters: &'a SVCParameters<TX, TY, X, Y>,
     ) -> Result<Self, Failed> {
-        SVC::fit(x, y, parameters)
+        SVC::fit(x, y, parameters, None)
Collaborator

@Mec-iS Mec-iS Jun 15, 2025

Why did the signature for SVC::fit change? This breaks the API and all code already written against previous versions. Never change the existing API; this is a breaking change that requires a major release (we had a major release with v0.4.0 when we changed the API to the new stable one).

The signature for SVC::fit is:

fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
    ) 

and it must stay that way, or all existing applications written with the library will break, causing a lot of problems for everyone already using SVC.
Please be careful to keep the existing API; one of the main points of smartcore is stability.
In this case it is better to have some code duplication than to break the API.

Contributor Author

impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            classes: Option::None,
            instances: Option::None,
            parameters: Option::None,
            w: Option::None,
            b: Option::None,
            phantomdata: PhantomData,
        }
    }
    fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
    ) -> Result<Self, Failed> {
        SVC::fit(x, y, parameters)
    }
}

Contributor Author

Isn't this the API?

Collaborator

@Mec-iS Mec-iS Jun 15, 2025

Your code works well and you are on the right path to a good implementation.

Although, the original API for SVC is this one

The PR changes it to:

impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
    SVC<'a, TX, TY, X, Y>
{
    /// Fits SVC to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - class labels
    /// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
        multiclass_config: Option<MultiClassConfig<TY>>,
    )

In your code, multiclass_config is not part of the original SVC type, so it will break all existing code.
My idea was to keep the current SVC implementation unchanged so that we don't break existing code, and to add an SVCMultiClass that implements the multiclass classifier without touching the existing SVC.

So you should change your implementation for the new SVCMultiClass without changing anything in SVC.

You can try asking an LLM: "how do I change this code in a way that it doesn't change the existing SVC code"
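
The approach requested here (leave the binary estimator's `fit` untouched and add a separate multiclass type on top) can be sketched with toy types. Everything below is hypothetical: `BinaryModel` is a one-feature nearest-mean classifier standing in for the existing SVC, `MultiClassModel` stands in for the proposed SVCMultiClass, and the one-vs-rest strategy is just one possible choice:

```rust
/// Stand-in for the existing binary estimator. Its public `fit` signature
/// is treated as frozen and is not changed by the multiclass work.
pub struct BinaryModel {
    pos_mean: f64,
    neg_mean: f64,
}

impl BinaryModel {
    /// Existing API: one feature per sample, labels are 1 or -1.
    pub fn fit(x: &[f64], y: &[i32]) -> Result<Self, String> {
        if x.is_empty() || x.len() != y.len() {
            return Err("x and y must be non-empty and of equal length".to_string());
        }
        let mean_of = |label: i32| -> Option<f64> {
            let vals: Vec<f64> = x
                .iter()
                .zip(y)
                .filter(|pair| *pair.1 == label)
                .map(|(&v, _)| v)
                .collect();
            if vals.is_empty() {
                None
            } else {
                Some(vals.iter().sum::<f64>() / vals.len() as f64)
            }
        };
        match (mean_of(1), mean_of(-1)) {
            (Some(pos_mean), Some(neg_mean)) => Ok(Self { pos_mean, neg_mean }),
            _ => Err("both classes must be present".to_string()),
        }
    }

    /// Larger values mean the sample looks more like the positive class.
    pub fn decision(&self, v: f64) -> f64 {
        (v - self.neg_mean).abs() - (v - self.pos_mean).abs()
    }
}

/// New type layered on top of `BinaryModel` (one-vs-rest).
pub struct MultiClassModel {
    models: Vec<(u32, BinaryModel)>,
}

impl MultiClassModel {
    pub fn fit(x: &[f64], y: &[u32]) -> Result<Self, String> {
        let mut classes = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let mut models = Vec::with_capacity(classes.len());
        for &class in &classes {
            // Relabel: current class -> 1, everything else -> -1.
            let binary: Vec<i32> = y.iter().map(|&l| if l == class { 1 } else { -1 }).collect();
            models.push((class, BinaryModel::fit(x, &binary)?));
        }
        Ok(Self { models })
    }

    /// Picks the class whose binary model gives the highest decision value.
    pub fn predict(&self, v: f64) -> Option<u32> {
        self.models
            .iter()
            .max_by(|a, b| a.1.decision(v).total_cmp(&b.1.decision(v)))
            .map(|(class, _)| *class)
    }
}
```

Callers of `BinaryModel::fit` compile unchanged, and the multiclass behaviour lives entirely in the new type, at the cost of a little duplicated plumbing.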

Contributor Author

check the new pr

Contributor Author

I think I fixed everything.

} else {
let classes = y.unique();
if classes.len() != 2 {
return Err(Failed::fit(&format!(
Collaborator

good. this is good error reporting.
