import numpy as np
from scipy.linalg import lstsq
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.impute import SimpleImputer


def find_subset_indices(X_full, X_subset, method="hash"):
    """
    Find row indices in X_full that correspond to rows in X_subset.
    Supports 'hash' (fast, byte-level) and 'precise' (element-wise,
    NaN-aware) matching.
    """
    if X_full.shape[1] != X_subset.shape[1]:
        raise ValueError(
            f"Feature dimensions don't match: {X_full.shape[1]} vs {X_subset.shape[1]}"
        )
    indices = []
    if method == "precise":
        for i, subset_row in enumerate(X_subset):
            matches = [
                j
                for j, full_row in enumerate(X_full)
                if np.array_equal(full_row, subset_row, equal_nan=True)
            ]
            if not matches:
                raise ValueError(f"No matching row found for subset row {i}")
            indices.append(matches[0])
    elif method == "hash":
        # Map each row's byte-hash to its first occurrence for O(1) lookups
        full_hashes = {}
        for j, full_row in enumerate(X_full):
            full_hashes.setdefault(hash(full_row.tobytes()), j)
        for i, subset_row in enumerate(X_subset):
            subset_hash = hash(subset_row.tobytes())
            if subset_hash not in full_hashes:
                raise ValueError(f"No matching row found for subset row {i}")
            indices.append(full_hashes[subset_hash])
    else:
        raise ValueError(f"Unknown method '{method}'. Use 'hash' or 'precise'.")
    return np.array(indices)
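

# A minimal sketch of how the two matching modes behave; the toy arrays below
# are made up for illustration. 'hash' compares whole rows by the hash of
# their raw bytes, while 'precise' compares rows element-wise and treats NaNs
# in the same position as equal.
def _demo_find_subset_indices():
    X_full = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    X_subset = X_full[[2, 0]]  # rows taken from X_full, in a different order
    print(find_subset_indices(X_full, X_subset, method="hash"))     # [2 0]
    print(find_subset_indices(X_full, X_subset, method="precise"))  # [2 0]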


class IdentityTransformer(BaseEstimator, TransformerMixin):
    """A transformer that returns the input unchanged."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X
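

# A minimal sketch (assumed usage): IdentityTransformer is a no-op stand-in
# wherever a transformer is required, e.g. as the `pipeline` argument of
# CovariateRegressor below when no preprocessing is wanted.
def _demo_identity_transformer():
    X = np.arange(6.0).reshape(3, 2)
    assert np.array_equal(IdentityTransformer().fit_transform(X), X)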


class CovariateRegressor(BaseEstimator, TransformerMixin):
    """
    Regresses covariate(s) out of each feature in X and returns the
    residuals.
    """

    def __init__(
        self,
        covariate,
        X,
        pipeline=None,
        cross_validate=True,
        precise=False,
        unique_id_col_index=None,
        stack_intercept=True,
    ):
        """Regresses out a variable (covariate) from each feature in X.

        Parameters
        ----------
        covariate : numpy array
            Array of shape (n_samples, n_covariates) to regress out of each
            feature; may have multiple columns for multiple covariates.
        X : numpy array
            Array of shape (n_samples, n_features) from which the covariate
            will be regressed. The full dataset is stored so that the subset
            of samples passed to fit/transform can be matched back to the
            corresponding covariate rows, which is necessary for use in
            scikit-learn Pipelines.
        pipeline : scikit-learn transformer or Pipeline, optional
            Transformation applied to X before the regression; it is fitted
            and applied in fit and re-applied in transform. If None, X is
            used as-is.
        cross_validate : bool
            Whether to apply the covariate-parameters (y~covariate) estimated
            on the train-set to the test-set (cross_validate=True) or to
            refit the covariate regressor separately on the test-set
            (cross_validate=False). Setting this parameter to True is
            equivalent to "foldwise covariate regression" (FwCR) as described
            in our paper
            (https://www.biorxiv.org/content/early/2018/03/28/290684).
            Setting this parameter to False, however, is NOT equivalent to
            "whole dataset covariate regression" (WDCR), as it does not apply
            covariate regression to the *full* dataset, but simply refits the
            covariate model on the test-set. We recommend setting this
            parameter to True.
        precise : bool
            Transformer objects in scikit-learn only allow passing the data
            (X) and optionally the target (y) to the fit and transform
            methods. However, we need to index the covariate accordingly as
            well. To do so, we compare the X given at initialization (self.X)
            with the X passed to fit/transform; from this we infer which
            samples were passed and index the covariate accordingly. When
            precise is True, the arrays are compared element-wise, which is
            accurate but relatively slow. When precise is False, the index is
            inferred from a hash of each row's features, which is much
            faster.
        unique_id_col_index : int, optional
            Index of a column in X holding a unique per-sample identifier.
            If given, this column is removed from X after the subset matching
            and before the regression.
        stack_intercept : bool
            Whether to stack an intercept onto the covariate (default is
            True).

        Attributes
        ----------
        weights_ : numpy array
            Array with weights for the covariate(s).
        """
        self.covariate = covariate.astype(np.float64)
        self.cross_validate = cross_validate
        self.X = X
        self.precise = precise
        self.stack_intercept = stack_intercept
        self.weights_ = None
        self.pipeline = pipeline
        self.imputer = SimpleImputer(strategy="median")
        self.X_imputer = SimpleImputer(strategy="median")
        self.unique_id_col_index = unique_id_col_index

    def _prepare_covariate(self, covariate):
        """Prepare the covariate matrix (stacks an intercept if needed)."""
        if self.stack_intercept:
            return np.c_[np.ones((covariate.shape[0], 1)), covariate]
        return covariate

    def fit(self, X, y=None):
        """Fits the covariate-regressor to X.

        Parameters
        ----------
        X : numpy array
            An array of shape (n_samples, n_features), which should
            correspond to your train-set only!
        y : None
            Included for compatibility; does nothing.
        """

        # Prepare covariate matrix (adds intercept if needed)
        covariate = self._prepare_covariate(self.covariate)

        # Find indices of X subset in the original X
        method = "precise" if self.precise else "hash"
        fit_idx = find_subset_indices(self.X, X, method=method)

        # Remove unique ID column if specified
        if self.unique_id_col_index is not None:
            X = np.delete(X, self.unique_id_col_index, axis=1)

        # Extract covariate data for the fitting subset
        covariate_fit = covariate[fit_idx, :]

        # Conditional imputation for covariate data
        if np.isnan(covariate_fit).any():
            covariate_fit = self.imputer.fit_transform(covariate_fit)
        else:
            # Still fit the imputer for consistency in transform
            self.imputer.fit(covariate_fit)

        # Apply pipeline transformation if specified
        if self.pipeline is not None:
            X = self.pipeline.fit_transform(X)

        # Conditional imputation for X
        if np.isnan(X).any():
            X = self.X_imputer.fit_transform(X)
        else:
            # Still fit the imputer for consistency in transform
            self.X_imputer.fit(X)

        # Fit linear regression: X = covariate @ weights + residuals,
        # using scipy's lstsq for numerical stability
        self.weights_ = lstsq(covariate_fit, X)[0]

        return self
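
    # Note on the least-squares step above: lstsq(C, X) returns the OLS
    # solution W minimizing ||C @ W - X||**2 for all feature columns at once,
    # i.e. an independent regression of each feature on the covariate(s);
    # transform() then subtracts C @ W, leaving residuals that are orthogonal
    # to the covariates on the data the weights were fit on.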

    def transform(self, X):
        """Regresses out the covariate from X.

        Parameters
        ----------
        X : numpy array
            An array of shape (n_samples, n_features) holding the samples to
            transform (e.g., your test-set).

        Returns
        -------
        X_new : ndarray
            ndarray with covariate-regressed features
        """

        # Refit on the incoming data if not cross-validating the weights
        if not self.cross_validate:
            self.fit(X)

        # Prepare covariate matrix (adds intercept if needed)
        covariate = self._prepare_covariate(self.covariate)

        # Find indices of X subset in the original X
        method = "precise" if self.precise else "hash"
        transform_idx = find_subset_indices(self.X, X, method=method)

        # Remove unique ID column if specified
        if self.unique_id_col_index is not None:
            X = np.delete(X, self.unique_id_col_index, axis=1)

        # Extract covariate data for the transform subset
        covariate_transform = covariate[transform_idx, :]

        # Conditional imputation for covariate data (use fitted imputer)
        if np.isnan(covariate_transform).any():
            covariate_transform = self.imputer.transform(covariate_transform)

        # Apply pipeline transformation if specified
        if self.pipeline is not None:
            X = self.pipeline.transform(X)

        # Conditional imputation for X (use fitted imputer)
        if np.isnan(X).any():
            X = self.X_imputer.transform(X)

        # Compute residuals
        X_new = X - covariate_transform.dot(self.weights_)

        # Ensure no NaNs in output
        X_new = np.nan_to_num(X_new)

        return X_new
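

# A minimal end-to-end sketch (assumed usage, not taken from the original
# source): fit on a train split, then transform both splits; the toy data
# and split are made up for illustration. On the train split the residuals
# are orthogonal to the intercept-stacked covariate.
def _demo_covariate_regressor():
    rng = np.random.default_rng(42)
    X = rng.normal(size=(20, 5))
    covariate = rng.normal(size=(20, 1))
    train, test = np.arange(15), np.arange(15, 20)

    cr = CovariateRegressor(covariate=covariate, X=X)
    cr.fit(X[train])
    X_train_res = cr.transform(X[train])
    X_test_res = cr.transform(X[test])

    C_train = np.c_[np.ones((15, 1)), covariate[train]]
    print(np.allclose(C_train.T @ X_train_res, 0.0, atol=1e-8))  # True
    print(X_train_res.shape, X_test_res.shape)  # (15, 5) (5, 5)


if __name__ == "__main__":
    _demo_find_subset_indices()
    _demo_identity_transformer()
    _demo_covariate_regressor()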