API 2: CFI, PFI, LOCO #372
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff            @@
##             main     #372     +/-   ##
==========================================
+ Coverage   98.10%   98.19%   +0.09%
==========================================
  Files          22       22
  Lines        1159     1222     +63
==========================================
+ Hits         1137     1200     +63
  Misses         22       22
It looks good but the diff seems very large for this small change.
Is there a reason for all the other modifications?
I reorganized the parameters in the init a bit and moved the docstring to the class, because I plan to do this in all the other classes. Looking into it in more detail, I missed adding some parts. I will add them and ask you to review afterwards. Sorry about that.
This PR is definitely an improvement, thx.
        ).pvalue
        return self.importances_

    def fit_importance(
I find it disturbing that fit_importance behaves quite differently from simply calling fit, then importance.
- Could we add a check to the .fit() method to ensure that the estimator is fitted, and if not, fit it?
- Could we allow for passing a list of fitted estimators matching the number of splits? That could typically be relevant for users willing to pass DL models, trained beforehand, through skorch for instance.
- If the models are not fitted, can we store them, in estimators_ for instance? It is useful to check the predictive performance in addition to the importance.
- Could we add a check to the .fit() method to ensure that the estimator is fitted, and if not, fit it.
From my point of view, I don't think so, because the goal is to compute the importance within the cross-validation, as in the example plot_model_agnostic__importance.
- Could we allow for passing a list of fitted estimators matching the number of splits? That could typically be relevant for users willing to pass DL models, trained before, through skorch for instance.
This also requires having the indices of the cross-validation. At this point, it's better for users to write the loop themselves.
- If the models are not fitted, can we store them? in estimators_ for instance? It is useful to check the predictive performance in addition to the importance.
Yes, I will add this.
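The two estimator-related suggestions (keeping the per-fold models in estimators_, and accepting a list of already fitted models matching the number of splits) could look roughly like the sketch below. It is an illustration under stated assumptions, not hidimstat code: `fit_per_fold`, `LeastSquares`, and the explicit `splits` argument are hypothetical names, and the estimator is a toy stand-in for any object exposing `fit` and `predict`.

```python
import copy

import numpy as np


class LeastSquares:
    """Toy estimator with a scikit-learn-like fit/predict interface (hypothetical)."""

    def fit(self, X, y):
        self.coef_, *_ = np.linalg.lstsq(X, y, rcond=None)
        return self

    def predict(self, X):
        return X @ self.coef_


def fit_per_fold(estimator, X, y, splits, fitted_estimators=None):
    """Return one fitted estimator per (train, test) split.

    If `fitted_estimators` is given, it is reused without refitting, after
    checking that its length matches the number of splits.
    """
    if fitted_estimators is not None:
        if len(fitted_estimators) != len(splits):
            raise ValueError("need exactly one fitted estimator per split")
        return list(fitted_estimators)
    # Fit an independent copy on each training fold and keep it.
    return [copy.deepcopy(estimator).fit(X[train], y[train]) for train, _ in splits]


rng = np.random.default_rng(0)
X = rng.normal(size=(60, 4))
y = X @ np.array([1.0, 0.0, -2.0, 0.0]) + 0.1 * rng.normal(size=60)

# Three contiguous folds; each split is (train_indices, test_indices).
folds = np.array_split(np.arange(60), 3)
splits = [(np.setdiff1d(np.arange(60), test), test) for test in folds]

estimators_ = fit_per_fold(LeastSquares(), X, y, splits)

# Stored models let the user report held-out error next to the importances.
test_mse = [
    np.mean((est.predict(X[test]) - y[test]) ** 2)
    for est, (_, test) in zip(estimators_, splits)
]
```

In this sketch the caller supplies the splits explicitly, so the pairing between pre-fitted estimators and folds is only checked by length; verifying that each model was really trained on the matching training indices remains the caller's responsibility, which is the difficulty raised in the reply above.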
I made the modification; tell me if it's OK.
IMO, there are too many patches that are not optimal and that would be avoided by creating a dedicated BasePerturbationCV:
- The computation of p-values by taking the mean over folds is not valid.
- A big benefit of model-agnostic approaches (LOCO, CFI...) is to support DL models. However, it is not reasonable to force the training of DL models in the 'fit' of Hidimstat's methods. We should allow passing a list of fitted estimators to support this use case.
What is DL?
I agree that this is not an optimal approach. However, fixing it would require a redesign of how CV is used and how the estimator is managed. @bthirion doesn't want CV to be a parameter of fit_importances, and for the moment the estimator must be fitted before use.
I don't see the point of having another class BasePerturbationCV if it only modifies fit_importances.
Passing a list of fitted estimators together with a CV could work, but it is difficult to verify the link between these two objects.
- The computation of p-values by taking the mean over folds is not valid
Do you have a better solution?
I was thinking of using the function aggregate_pvalue, but I don't know if it is correct in this case.
DL: Deep Learning
Thx for the progress. Please find a few suggestions enclosed.
Co-authored-by: bthirion <[email protected]>
We're almost there.
Attributes
----------
features_groups : dict
    Mapping of feature groups identified during fit.
This is no longer accurate IIUC.
What do you mean?
It's still accurate in this version of the code.
Thank you.
I agree with the goal of the modifications, but I believe that it is not optimal to implement the CV by simply patching the current class; we need a dedicated class BasePerturbationCV.
self.importances_ = np.mean(self.importances_cv_, axis=0)
self.pvalues_ = (
    None if self.pvalues_cv_[0] is None else np.mean(self.pvalues_cv_, axis=0)
That looks problematic:
- The p-value of the CV estimator is computed over the k test statistics, where k is the number of folds. So self.pvalues_cv_ should be 1d. Even if it were 2d, the mean of p-values is not, in general, a p-value.
- I see the problem that leaving self.pvalues_ as None will leave the instance "not fitted". For me, this calls for creating a subclass BasePerturbationCV.
- I see the problem that leaving self.pvalues_ as None will leave the instance "not fitted". For me, this calls for creating a subclass BasePerturbationCV.
This is not a problem, because not all methods can compute p-values (LOCO, for example).
The check is based only on importances_.
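The objection to averaging p-values over folds can be checked numerically. The sketch below is only an illustration: the array name `importances_cv` and its (n_folds, n_features) layout are assumptions, not the attributes used in this PR. It first shows that the average of independent null p-values is not uniformly distributed, then computes a single one-sided t-test across the k fold-level statistics, which gives a 1d array with one p-value per feature. Fold-level statistics share training data and are not independent, so this plain t-test should be read as a baseline rather than a calibrated procedure.

```python
import numpy as np
from scipy.stats import ttest_1samp

rng = np.random.default_rng(0)
n_folds, n_features = 5, 3

# Under the null, a valid p-value is uniform on [0, 1]. Simulate many
# replications of k independent per-fold p-values and average them.
null_pvalues = rng.uniform(size=(20_000, n_folds))
mean_pvalues = null_pvalues.mean(axis=1)

# A uniform variable falls below 0.05 about 5% of the time; the fold
# average almost never does, so it is not uniformly distributed.
frac_single = np.mean(null_pvalues[:, 0] <= 0.05)
frac_mean = np.mean(mean_pvalues <= 0.05)

# Alternative: one test over the k fold-level statistics of each feature.
# `importances_cv` (n_folds x n_features) is a hypothetical name.
importances_cv = rng.normal(
    loc=[0.0, 0.5, 2.0], scale=0.3, size=(n_folds, n_features)
)
pvalues = ttest_1samp(importances_cv, 0.0, axis=0, alternative="greater").pvalue
```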
Co-authored-by: Joseph Paillard <[email protected]>
self.pvalues_ = ttest_1samp(
    test_result, 0.0, axis=1, alternative="greater"
).pvalue
As issue #48 mentions, should I propose a better way to compute the p-value?
If you want the test function to be passed as a parameter, do you have suggestions for its signature?
I suggest something similar to scikit-learn's metrics: support both strings ('ttest', 'wilcoxon', 'corrected-ttest', ...) and functions such as lambda x: ttest_1samp(x, 0.0, axis=1, alternative="greater")[1]
The signature would be more like this:
test(diff_loss) -> pvalue
Do you think that losses - mean_losses as the parameter is enough, or do we need more information?
The signature you describe looks good. However, I think it would be nice to also support passing a string for classical tests. That would save the user from defining a function that follows this signature while fixing the other parameters (e.g. axis=1, alternative="greater", ...).
I don't really like strings because I find them difficult to manage, but in this case they can be useful.
I will try to add a function for it.
The point is that ttest_1samp is originally a scipy function. Keeping a similar API helps users.
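The string-or-callable option discussed in this thread could be resolved roughly as follows. This is a minimal sketch, not hidimstat's API: the helper `resolve_test`, the registry `_NAMED_TESTS`, and the convention that `diff_loss` has shape (n_features, n_samples) are assumptions made for illustration. Only `scipy.stats.ttest_1samp` and `scipy.stats.wilcoxon` are real library functions; 'corrected-ttest' is left out because its definition is not given here.

```python
import numpy as np
from scipy.stats import ttest_1samp, wilcoxon


def _ttest(diff_loss):
    # One-sided one-sample t-test per feature (one row per feature).
    return ttest_1samp(diff_loss, 0.0, axis=1, alternative="greater").pvalue


def _wilcoxon(diff_loss):
    # One-sided Wilcoxon signed-rank test, applied row by row.
    return np.array(
        [wilcoxon(row, alternative="greater").pvalue for row in diff_loss]
    )


# Hypothetical registry of named tests; the keys are illustrative.
_NAMED_TESTS = {"ttest": _ttest, "wilcoxon": _wilcoxon}


def resolve_test(test):
    """Return a callable `test(diff_loss) -> pvalue` from a name or a callable."""
    if callable(test):
        return test
    if isinstance(test, str):
        if test not in _NAMED_TESTS:
            raise ValueError(
                f"unknown test {test!r}; expected one of {sorted(_NAMED_TESTS)}"
            )
        return _NAMED_TESTS[test]
    raise TypeError("test must be a string or a callable")


rng = np.random.default_rng(0)
diff_loss = rng.normal(loc=0.5, size=(4, 30))  # 4 features, 30 samples

pvalues_named = resolve_test("ttest")(diff_loss)
pvalues_custom = resolve_test(
    lambda x: ttest_1samp(x, 0.0, axis=1, alternative="greater")[1]
)(diff_loss)
```

With this layout the named test and the equivalent user-supplied lambda return the same values, so the string form is only a shortcut that fixes axis and alternative for the user.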
Update the model of CFI, PFI and LOCO for API 2.