Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/sources/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The CHANGELOG for the current development version is available at

- `np.float_` update to support for NumPy 2.0 ([#1119](https://github.com/rasbt/mlxtend/issues/1119) via [Bot-wxt1221](https://github.com/Bot-wxt1221))

- Added optional `method` parameter to `create_counterfactual()` in `mlxtend.evaluate.counterfactual`, enabling alternative optimization strategies. ([#1029](https://github.com/rasbt/mlxtend/issues/1029) via [dhruvi003](https://github.com/dhruvi003))


---

Expand Down
10 changes: 9 additions & 1 deletion mlxtend/evaluate/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def create_counterfactual(
y_desired_proba=None,
lammbda=0.1,
random_seed=None,
method="Nelder-Mead",
):
"""
Implementation of the counterfactual method by Wachter et al. 2017
Expand Down Expand Up @@ -68,6 +69,13 @@ class probability for `y_desired`.
the random number generator for selecting the inital counterfactual
from `X_dataset`.

method : str (default: 'Nelder-Mead')
The optimization method to be used in `scipy.optimize.minimize`.
Examples include 'Nelder-Mead', 'BFGS', 'Powell', etc.
Refer to SciPy documentation for supported methods:
https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html


"""
if y_desired_proba is not None:
use_proba = True
Expand Down Expand Up @@ -110,7 +118,7 @@ def loss(x_counterfact, lammbda):

return diff + dist(x_reference, x_counterfact)

res = minimize(loss, x_counterfact, args=(lammbda), method="Nelder-Mead")
res = minimize(loss, x_counterfact, args=(lammbda), method=method)

if not res["success"]:
warnings.warn(res["message"])
Expand Down
Loading