SVC multiclass #306
Had to close the previous PR because I sent it on the wrong branch.
The issue with Array2 is that the type within the array is inferred and static. If a Vec<u32> is passed in, then the application won't be able to transform the labels to {1, -1} in order for them to be used for binary classification. Originally, you had an assert statement to validate that the data passed in had the labels -1 and 1, but now the SVC, being multiclass, can accept a wide variety of labels.
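One way to picture the label problem: the original labels can stay in their own type while a separate `f64` vector carries the {1, -1} encoding for a single binary problem. The sketch below uses plain slices rather than smartcore's `Array1`/`Array2` traits, and the helper names `to_binary_labels` and `from_binary_label` are hypothetical, made up for illustration.

```rust
/// Map arbitrary class labels to the {+1, -1} encoding a binary SVM expects.
/// `positive` is the label treated as +1; every other label becomes -1.
/// Hypothetical helper for illustration, not part of smartcore.
fn to_binary_labels<T: PartialEq>(labels: &[T], positive: &T) -> Vec<f64> {
    labels
        .iter()
        .map(|label| if label == positive { 1.0 } else { -1.0 })
        .collect()
}

/// Translate a binary prediction back into the original label type.
/// Hypothetical helper for illustration, not part of smartcore.
fn from_binary_label<T: Clone>(prediction: f64, positive: &T, negative: &T) -> T {
    if prediction > 0.0 {
        positive.clone()
    } else {
        negative.clone()
    }
}
```

Because the encoded vector is a new allocation with its own element type, the input container never needs to change its static type.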
Array2 is an abstraction for a 2D vector. It can be used for any type supported by Vec; it just needs to be implemented. Test-Driven Development (TDD) should be followed. Every time you implement something new, you have to be sure to support existing behaviour. If there is no test for the operations you are changing, you should add one. Also, please add one or more tests when you implement something new. When I try to run the tests in the module I get:
This is a reference implementation generated by my LLM. You can start from it, but it is not fully implemented (i.e. it is an example and needs to be rewritten using generic types like TX and TY). Please note that you should not modify the existing struct but instead create new structs to handle the multiclass case. This implementation suggests using a 1D Vec, but you should check whether that is correct (if it is, there is no need to use a 2D Vec):

To implement a multiclass Support Vector Classification (SVC) in Rust using smartcore, we can adopt the one-vs-one (OvO) strategy, which trains binary classifiers for each pair of classes. Here's a complete implementation:

use smartcore::svm::svc::{SVC, SVCParameters};
use smartcore::linalg::{BaseVector, Matrix, MatrixTrait};
use smartcore::metrics::accuracy;
use smartcore::dataset::iris::load_dataset;
// Multiclass SVC using One-vs-One strategy
struct MulticlassSVC {
classifiers: Vec<SVC<f64, DenseMatrix<f64>, Vec<f64>>>,
classes: Vec<u32>,
}
impl MulticlassSVC {
pub fn fit(
x: &DenseMatrix<f64>,
y: &Vec<u32>,
params: &SVCParameters<f64>,
) -> Result<Self, Failed> {
let classes = y.iter().unique().sorted().collect::<Vec<_>>();
let mut classifiers = Vec::new();
// Generate all class pairs
for (i, &class1) in classes.iter().enumerate() {
for &class2 in classes.iter().skip(i + 1) {
// Filter samples for current class pair
let indices: Vec<usize> = y.iter()
.enumerate()
.filter(|(_, &c)| c == class1 || c == class2)
.map(|(i, _)| i)
.collect();
let x_filtered = x.select_rows(&indices);
let y_filtered: Vec<f64> = indices.iter()
.map(|&i| if y[i] == class1 { 1.0 } else { -1.0 })
.collect();
// Train binary classifier
let mut clf = SVC::fit(&x_filtered, &y_filtered, params.clone())?;
classifiers.push((class1, class2, clf));
}
}
Ok(Self { classifiers, classes })
}
pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<u32> {
let mut votes = vec![HashMap::new(); x.shape().0];
for (class1, class2, clf) in &self.classifiers {
let preds = clf.predict(x).unwrap();
for (i, &p) in preds.iter().enumerate() {
let vote = if p > 0.0 { *class1 } else { *class2 };
*votes[i].entry(vote).or_insert(0) += 1;
}
}
votes.iter()
.map(|v| *v.iter().max_by_key(|(_, &count)| count).unwrap().0)
.collect()
}
}
// Example usage with Iris dataset
fn main() -> Result<(), Failed> {
let iris = load_dataset();
let (x_train, x_test, y_train, y_test) = train_test_split(
&iris.data,
&iris.target,
0.8,
true,
Some(42),
);
let params = SVCParameters::default()
.with_c(200.0)
.with_kernel(smartcore::svm::Kernel::linear());
let clf = MulticlassSVC::fit(&x_train, &y_train, &params)?;
let preds = clf.predict(&x_test);
println!("Accuracy: {}", accuracy(&y_test, &preds));
Ok(())
}

Key implementation details:
Advantages over naive implementation:
This implementation follows smartcore's design principles by:
For production use, you'd want to add:
Appreciate the guide.
It doesn't look like DenseMatrix has the select_rows method.
I assume you want the multiclass struct to take in any generic Array2 object as x.
it has the
it depends what you are going to do with the
Alright. That helps a ton.
I have to clone the matrix in order to create a version that filters out the rows that don't have one of the two labels. The issue is that this creates a local copy of the matrix, but the SVC struct only accepts references. This leads to the lifetime of the matrix being shorter than that of the SVC, which is not allowed.
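The lifetime problem described here can be reproduced without smartcore. In the sketch below, `BorrowingModel` is a hypothetical stand-in for a model that keeps a reference to its training data, and `filter_pair`, `build_subsets`, and `fit_all` are made-up helper names. One possible way around the error, not necessarily the one this PR ended up using, is to build all filtered copies first and let the caller own them, so they outlive every model that borrows from them.

```rust
/// Hypothetical stand-in for a model that, like a borrowing SVC,
/// keeps a reference to its training rows instead of owning them.
struct BorrowingModel<'a> {
    rows: &'a [Vec<f64>],
}

impl<'a> BorrowingModel<'a> {
    fn fit(rows: &'a [Vec<f64>]) -> Self {
        Self { rows }
    }

    fn n_samples(&self) -> usize {
        self.rows.len()
    }
}

/// Copy out only the rows whose label belongs to the class pair (a, b).
fn filter_pair(x: &[Vec<f64>], y: &[u32], a: u32, b: u32) -> Vec<Vec<f64>> {
    let mut out = Vec::new();
    for (row, &label) in x.iter().zip(y.iter()) {
        if label == a || label == b {
            out.push(row.clone());
        }
    }
    out
}

/// Build every filtered subset up front. The returned Vec is owned by the
/// caller, so it can outlive the models created from it.
fn build_subsets(x: &[Vec<f64>], y: &[u32], pairs: &[(u32, u32)]) -> Vec<Vec<Vec<f64>>> {
    pairs.iter().map(|&(a, b)| filter_pair(x, y, a, b)).collect()
}

/// The models borrow from `subsets`; the shared lifetime 'a makes the
/// compiler enforce that `subsets` lives at least as long as the models.
fn fit_all<'a>(subsets: &'a [Vec<Vec<f64>>]) -> Vec<BorrowingModel<'a>> {
    subsets
        .iter()
        .map(|subset| BorrowingModel::fit(subset.as_slice()))
        .collect()
}
```

Creating the filtered copy inside the same function that returns the model is what triggers the error; separating "own the data" from "borrow the data" into two steps avoids it at the cost of keeping the copies alive.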
I think I found a solution.
f292ff8 to 3f5daa4
The code doesn't work, by the way. I'm facing an issue where the shapes of the vector-type objects aren't the same. I did learn a lot by playing around with the code. I guess I can say I'm one step ahead.
good progress 👍 keep going
I noticed there are a lot of parts of the code where you loop through each row of X and convert it to some type. I think that could be the reason why the SVC is so slow.
I know it's not relevant to this PR but it's really important.
Where in the repo would I store my dataset-related tests?
I wrote one unit test that tests the whole process.
Do you want the benchmark to take in the Iris dataset as input? Sorry, I haven't benchmarked before.
tests go at the bottom of the module in the
you can follow the current practice in the
just ask the LLM to write tests for this module and see what it returns; it is usually good with basic usages.
I sent a PR for the smartcore-benches repo.
Aren't those for unit tests? What about integration tests? I assume any test involving a dataset would be an integration test.
PR is ready for review.
I don't know if you can technically call the ones you mentioned integration tests. By the way, the ones involving datasets need to be flagged with the
I created a unit test, so everything is good.
Thanks for your contribution. The code is returning the correct results as compared to As specified in Also please double check that you have checked and possibly reused the traits defined in base library:
src/svm/svc.rs
Outdated
/// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error.
///
/// # Panics
/// Panics if the model has not been fitted (`self.classifiers` is `None`).
This should not panic but return an error (see the error harness in `smartcore/src/error/mod.rs`). Panics should be limited to unknowns, so that we know that something unexpected happened. Everything that is expected should return an error.
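As a rough illustration of the point, an unfitted model is an expected condition, so `predict` can report it through `Result` instead of unwrapping an `Option`. The sketch below is standalone: `PredictError` is a hypothetical stand-in for the library's own error type (which is not reproduced here), and `MultiClassModel` with its threshold-based "classifiers" is invented purely to keep the example small.

```rust
use std::fmt;

/// Hypothetical stand-in for the library's error type.
#[derive(Debug, PartialEq)]
struct PredictError(String);

impl fmt::Display for PredictError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "predict failed: {}", self.0)
    }
}

impl std::error::Error for PredictError {}

/// Toy model: each "classifier" is just a threshold on a 1-D input.
struct MultiClassModel {
    /// `None` until the model has been fitted.
    classifiers: Option<Vec<f64>>,
}

impl MultiClassModel {
    /// Returns an error for an unfitted model instead of panicking.
    fn predict(&self, x: &[f64]) -> Result<Vec<u32>, PredictError> {
        let thresholds = self
            .classifiers
            .as_ref()
            .ok_or_else(|| PredictError("model has not been fitted".to_string()))?;
        // Toy decision rule: the predicted class is the number of
        // thresholds the value exceeds.
        Ok(x.iter()
            .map(|&value| thresholds.iter().filter(|&&t| value > t).count() as u32)
            .collect())
    }
}
```

With this shape, the `# Panics` section of the doc comment can be replaced by an `# Errors` section, and callers decide how to handle the unfitted case.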
@@ -226,7 +425,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>
         y: &'a Y,
         parameters: &'a SVCParameters<TX, TY, X, Y>,
     ) -> Result<Self, Failed> {
-        SVC::fit(x, y, parameters)
+        SVC::fit(x, y, parameters, None)
Why did the signature for `SVC::fit` change? This breaks the API and all code already developed against previous versions. Never change the existing API; this is a breaking change that requires a major release (we had a major release with v0.4.0 when we changed the API to the new stable one).
The signature for `SVC::fit` is:
fn fit(
    x: &'a X,
    y: &'a Y,
    parameters: &'a SVCParameters<TX, TY, X, Y>,
)
and it must remain so, or all existing applications written with the library will break, causing a lot of problems for everyone already using SVC.
Please be careful to keep the existing API; one of the main points of smartcore is stability.
In this case it is better to have some code duplication than to break the API.
```rust
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX>, Y: Array1<TY>>
    SupervisedEstimatorBorrow<'a, X, Y, SVCParameters<TX, TY, X, Y>> for SVC<'a, TX, TY, X, Y>
{
    fn new() -> Self {
        Self {
            classes: Option::None,
            instances: Option::None,
            parameters: Option::None,
            w: Option::None,
            b: Option::None,
            phantomdata: PhantomData,
        }
    }
    fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
    ) -> Result<Self, Failed> {
        SVC::fit(x, y, parameters)
    }
}
```
Isn't this the API?
Your code works well and you are on the right path to a good implementation.
However, the original API for SVC is this one
The PR changes it to:
impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2<TX> + 'a, Y: Array1<TY> + 'a>
    SVC<'a, TX, TY, X, Y>
{
    /// Fits SVC to your data.
    /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation.
    /// * `y` - class labels
    /// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values.
    pub fn fit(
        x: &'a X,
        y: &'a Y,
        parameters: &'a SVCParameters<TX, TY, X, Y>,
        multiclass_config: Option<MultiClassConfig<TY>>,
    )
In your code, `multiclass_config` is not part of the original `SVC` type, so it will break all existing code.
My idea was to keep the current `SVC` implementation unchanged, so that we don't break existing code, and to add an `SVCMultiClass` that implements the multiclass classifier without touching the existing `SVC`.
So you should move your implementation into the new `SVCMultiClass` without changing anything in `SVC`.
You can try asking an LLM: "how do I change this code in a way that it doesn't change the existing SVC code"
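The shape being asked for can be sketched independently of smartcore: the binary classifier and its `fit` signature stay exactly as they are, and a new wrapper type handles pair generation, label encoding, and voting. Everything below is hypothetical and simplified. `BinarySketch` is a toy nearest-class-mean classifier on 1-D data standing in for the existing `SVC`, and `MultiClassSketch` stands in for the proposed `SVCMultiClass`. It uses a one-vs-one scheme as in the reference implementation earlier in the thread, which is an assumption about the final design.

```rust
use std::collections::BTreeMap;

/// Toy stand-in for the existing binary classifier. Its `fit` signature is
/// the part that stays untouched. Labels must be +1.0 or -1.0.
struct BinarySketch {
    pos_mean: f64,
    neg_mean: f64,
}

/// Mean of the values whose label equals `sign`, or None if there are none.
fn class_mean(x: &[f64], y: &[f64], sign: f64) -> Option<f64> {
    let (mut sum, mut count) = (0.0, 0usize);
    for (&value, &label) in x.iter().zip(y.iter()) {
        if label == sign {
            sum += value;
            count += 1;
        }
    }
    if count == 0 { None } else { Some(sum / count as f64) }
}

impl BinarySketch {
    fn fit(x: &[f64], y: &[f64]) -> Result<Self, String> {
        let pos_mean = class_mean(x, y, 1.0).ok_or_else(|| "no +1 samples".to_string())?;
        let neg_mean = class_mean(x, y, -1.0).ok_or_else(|| "no -1 samples".to_string())?;
        Ok(Self { pos_mean, neg_mean })
    }

    fn predict(&self, x: f64) -> f64 {
        if (x - self.pos_mean).abs() <= (x - self.neg_mean).abs() { 1.0 } else { -1.0 }
    }
}

/// New wrapper type: all multiclass logic lives here, not in BinarySketch.
struct MultiClassSketch {
    /// (label encoded as +1, label encoded as -1, fitted binary model)
    models: Vec<(u32, u32, BinarySketch)>,
}

impl MultiClassSketch {
    fn fit(x: &[f64], y: &[u32]) -> Result<Self, String> {
        let mut classes = y.to_vec();
        classes.sort_unstable();
        classes.dedup();
        let mut models = Vec::new();
        for (i, &a) in classes.iter().enumerate() {
            for &b in &classes[i + 1..] {
                // Keep only samples of the pair and encode them as +1 / -1.
                let (mut xs, mut ys) = (Vec::new(), Vec::new());
                for (&value, &label) in x.iter().zip(y.iter()) {
                    if label == a {
                        xs.push(value);
                        ys.push(1.0);
                    } else if label == b {
                        xs.push(value);
                        ys.push(-1.0);
                    }
                }
                models.push((a, b, BinarySketch::fit(&xs, &ys)?));
            }
        }
        Ok(Self { models })
    }

    /// Majority vote over all pairwise models; None if nothing was fitted.
    fn predict(&self, x: f64) -> Option<u32> {
        let mut votes: BTreeMap<u32, usize> = BTreeMap::new();
        for (a, b, model) in &self.models {
            let winner = if model.predict(x) > 0.0 { *a } else { *b };
            *votes.entry(winner).or_insert(0) += 1;
        }
        votes.into_iter().max_by_key(|&(_, count)| count).map(|(label, _)| label)
    }
}
```

Existing callers of the binary type compile unchanged, since nothing about it was edited; the cost is some duplicated plumbing in the wrapper, which is the trade-off the review comment accepts.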
check the new pr
I think I fixed everything.
} else {
    let classes = y.unique();
    if classes.len() != 2 {
        return Err(Failed::fit(&format!(
good. this is good error reporting.
05c4fd7 to c3753c7
Implemented the multiclass feature for SVC using a one for all approach.