diff --git a/docs/sources/CHANGELOG.md b/docs/sources/CHANGELOG.md index c1d5ddffa..4c90e5dae 100755 --- a/docs/sources/CHANGELOG.md +++ b/docs/sources/CHANGELOG.md @@ -28,6 +28,8 @@ The CHANGELOG for the current development version is available at - `np.float_` update to support for NumPy 2.0 ([#1119](https://github.com/rasbt/mlxtend/issues/1119) via [Bot-wxt1221](https://github.com/Bot-wxt1221)) +- Added optional `method` parameter to `create_counterfactual()` in `mlxtend.evaluate.counterfactual`, enabling alternative optimization strategies. ([#1029](https://github.com/rasbt/mlxtend/issues/1029) via [dhruvi003](https://github.com/dhruvi003)) + --- diff --git a/mlxtend/evaluate/counterfactual.py b/mlxtend/evaluate/counterfactual.py index 966ab2b9b..9527608b5 100644 --- a/mlxtend/evaluate/counterfactual.py +++ b/mlxtend/evaluate/counterfactual.py @@ -19,6 +19,7 @@ def create_counterfactual( y_desired_proba=None, lammbda=0.1, random_seed=None, + method="Nelder-Mead", ): """ Implementation of the counterfactual method by Wachter et al. 2017 @@ -68,6 +69,13 @@ class probability for `y_desired`. the random number generator for selecting the inital counterfactual from `X_dataset`. + method : str (default: 'Nelder-Mead') + The optimization method to be used in `scipy.optimize.minimize`. + Examples include 'Nelder-Mead', 'BFGS', 'Powell', etc. + Refer to SciPy documentation for supported methods: + https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html + + """ if y_desired_proba is not None: use_proba = True @@ -110,7 +118,7 @@ def loss(x_counterfact, lammbda): return diff + dist(x_reference, x_counterfact) - res = minimize(loss, x_counterfact, args=(lammbda), method="Nelder-Mead") + res = minimize(loss, x_counterfact, args=(lammbda), method=method) if not res["success"]: warnings.warn(res["message"])