
Commit 379fce6

jpaillard, bthirion, and lionelkusch authored
D0CRT Accept Scikit-learn models directly instead of parameter dictionaries (#370)
* [skip doc] remove dict and add tests
* [skip doc] remove useless lines
* adapt examples
* [build doc] improve doc explamation
* add test for error
* try to cover missing line
* move screening out of the class so that it can be re-used
* set to 0 non selected coef
* set coef to 0 in classification
* fix docstring
* a decision_function to the list of prediction methods
* add transform to the list
* remove transform
* remove sentence on d1CRT
* fix regression typo
* add line
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: bthirion <[email protected]>
* Update test/test_distilled_conditional_randomization_test.py Co-authored-by: bthirion <[email protected]>
* Update test/test_distilled_conditional_randomization_test.py Co-authored-by: bthirion <[email protected]>
* Update test/test_distilled_conditional_randomization_test.py Co-authored-by: bthirion <[email protected]>
* fix all typos
* set default random_state to None
* formatting
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* underscore after selection_set
* fix test
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* remove is_lasso att
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* set underscore attributes to None in init, update check_fit
* remove coefficient from check_fit
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* format
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* Update src/hidimstat/distilled_conditional_randomization_test.py Co-authored-by: lionel kusch <[email protected]>
* remove lasso_model_ from attributes
* add docstring

---------

Co-authored-by: bthirion <[email protected]>
Co-authored-by: lionel kusch <[email protected]>
1 parent 57d40de commit 379fce6

4 files changed: +286 -300 lines changed


examples/plot_dcrt_example.py

Lines changed: 4 additions & 2 deletions
@@ -58,7 +58,9 @@
 y = np.maximum(0.0, y)

 ## dcrt Lasso ##
-d0crt_lasso = D0CRT(estimator=LassoCV(random_state=42, n_jobs=1), screening=False)
+d0crt_lasso = D0CRT(
+    estimator=LassoCV(random_state=42, n_jobs=1), screening_threshold=None
+)
 d0crt_lasso.fit_importance(X, y)
 pvals_lasso = d0crt_lasso.pvalues_
 results_list.append(
@@ -73,7 +75,7 @@
 ## dcrt Random Forest ##
 d0crt_random_forest = D0CRT(
     estimator=RandomForestRegressor(n_estimators=100, random_state=42, n_jobs=1),
-    screening=False,
+    screening_threshold=None,
 )
 d0crt_random_forest.fit_importance(X, y)
 pvals_forest = d0crt_random_forest.pvalues_

examples/plot_model_agnostic_importance.py

Lines changed: 5 additions & 4 deletions
@@ -35,7 +35,7 @@
 from sklearn.model_selection import KFold
 from sklearn.svm import SVC

-from hidimstat import LOCO, D0CRT
+from hidimstat import D0CRT, LOCO

 #############################################################################
 # Generate data where classes are not linearly separable
@@ -70,11 +70,11 @@
 # test (:math:`H_0: X_j \perp\!\!\!\perp y | X_{-j}`) for each variable. However,
 # this test is based on a linear model (LogisticRegression) and fails to reject the null
 # in the presence of non-linear relationships.
-d0crt_linear = D0CRT(estimator=clone(linear_model), screening=False)
+d0crt_linear = D0CRT(estimator=clone(linear_model), screening_threshold=None)
 d0crt_linear.fit_importance(X, y)
 pval_dcrt_linear = d0crt_linear.pvalues_

-d0crt_non_linear = D0CRT(estimator=clone(non_linear_model), screening=False)
+d0crt_non_linear = D0CRT(estimator=clone(non_linear_model), screening_threshold=None)
 d0crt_non_linear.fit_importance(X, y)
 pval_dcrt_non_linear = d0crt_non_linear.pvalues_

@@ -162,7 +162,8 @@
 # As expected, when using linear models (d0CRT and LOCO-linear) that are misspecified,
 # the varibles are not selected. This highlights the benefit of using model-agnostic
 # methods such as LOCO, which allows for the use of models that are expressive enough
-# to explain the data.
+# to explain the data. While d0CRT can use any estimator, its distillation step
+# restricts it from capturing variable interactions.


 #################################################################################
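
The comment added in the last hunk says that the distillation step keeps d0CRT from capturing variable interactions. The sketch below is one way to see why, under simplifying assumptions that are not taken from the hidimstat source: the two features are independent, the distilled quantities are written out by hand instead of being fitted, and the statistic is a plain scaled correlation between residuals. The names `pearson` and `residual_statistic` are hypothetical helpers introduced only for this illustration.

```python
import math
import random


def pearson(a, b):
    """Plain Pearson correlation between two equal-length sequences."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((u - mean_a) * (v - mean_b) for u, v in zip(a, b))
    var_a = sum((u - mean_a) ** 2 for u in a)
    var_b = sum((v - mean_b) ** 2 for v in b)
    return cov / math.sqrt(var_a * var_b)


def residual_statistic(x_res, y_res):
    """Hypothetical distillation-style statistic (illustration only).

    Both inputs are assumed to be residuals left after removing what the
    other features explain. The statistic is sqrt(n) times their correlation.
    """
    return math.sqrt(len(x_res)) * pearson(x_res, y_res)


rng = random.Random(0)
n = 5000
x1 = [rng.gauss(0.0, 1.0) for _ in range(n)]
x2 = [rng.gauss(0.0, 1.0) for _ in range(n)]

# x1 and x2 are independent, so distilling x1 on x2 leaves x1 unchanged
# (centering is handled inside pearson).
y_additive = [a + b for a, b in zip(x1, x2)]
y_interaction = [a * b for a, b in zip(x1, x2)]

# Distilling y on x2 alone: E[y | x2] is x2 for the additive target and
# x2 * E[x1] = 0 for the pure-interaction target.
res_additive = [t - b for t, b in zip(y_additive, x2)]
res_interaction = y_interaction

stat_additive = residual_statistic(x1, res_additive)
stat_interaction = residual_statistic(x1, res_interaction)

print(f"additive target:    {stat_additive:.1f}")
print(f"interaction target: {stat_interaction:.1f}")
```

In the additive case the residuals of `y` line up with `x1` and the statistic is large. In the pure-interaction case `x1` clearly drives `y`, yet the residual correlation stays near zero, so a statistic of this form has nothing to detect regardless of how expressive the estimator used for distillation is.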
