Skip to content

Commit fdd2aa7

Browse files
committed
force_all_finite support added
1 parent 309ef4f commit fdd2aa7

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

modAL/models/base.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class BaseLearner(ABC, BaseEstimator):
3030
for instance, modAL.uncertainty.uncertainty_sampling.
3131
X_training: Initial training samples, if available.
3232
y_training: Initial training labels corresponding to initial training samples.
33+
force_all_finite: When True, forces all values of the data to be finite.
34+
When False, accepts np.nan and np.inf values.
3335
bootstrap_init: If initial training data is available, bootstrapping can be done during the first training.
3436
Useful when building Committee models with bagging.
3537
**fit_kwargs: keyword arguments.
@@ -47,6 +49,7 @@ def __init__(self,
4749
X_training: Optional[modALinput] = None,
4850
y_training: Optional[modALinput] = None,
4951
bootstrap_init: bool = False,
52+
force_all_finite: bool = True,
5053
**fit_kwargs
5154
) -> None:
5255
assert callable(query_strategy), 'query_strategy must be callable'
@@ -59,6 +62,9 @@ def __init__(self,
5962
if X_training is not None:
6063
self._fit_to_known(bootstrap=bootstrap_init, **fit_kwargs)
6164

65+
assert isinstance(force_all_finite, bool), 'force_all_finite must be a bool'
66+
self.force_all_finite = force_all_finite
67+
6268
def _add_training_data(self, X: modALinput, y: modALinput) -> None:
6369
"""
6470
Adds the new data and label to the known data, but does not retrain the model.
@@ -71,7 +77,8 @@ def _add_training_data(self, X: modALinput, y: modALinput) -> None:
7177
If the classifier has been fitted, the features in X have to agree with the training samples which the
7278
classifier has seen.
7379
"""
74-
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None)
80+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None,
81+
force_all_finite=self.force_all_finite)
7582

7683
if self.X_training is None:
7784
self.X_training = X
@@ -117,7 +124,8 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
117124
Returns:
118125
self
119126
"""
120-
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None)
127+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None,
128+
force_all_finite=self.force_all_finite)
121129

122130
if not bootstrap:
123131
self.estimator.fit(X, y, **fit_kwargs)
@@ -146,7 +154,8 @@ def fit(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwarg
146154
Returns:
147155
self
148156
"""
149-
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None)
157+
check_X_y(X, y, accept_sparse=True, ensure_2d=False, allow_nd=True, multi_output=True, dtype=None,
158+
force_all_finite=self.force_all_finite)
150159
self.X_training, self.y_training = X, y
151160
return self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
152161

0 commit comments

Comments
 (0)