1 | | -#!/usr/bin/python |
2 | | -# -*- coding: utf-8 -*- |
| 1 | +"""Assignment - making a sklearn estimator and cv splitter. |
| 2 | +
| 3 | +The goal of this assignment is to implement by yourself: |
| 4 | +
| 5 | +- a scikit-learn KNearestNeighbors estimator for classification tasks, |
| 6 | +  and check that it is working properly. |
| 7 | +- a scikit-learn CV splitter where the splits are based on a pandas |
| 8 | +  `DatetimeIndex`. |
| 9 | +
| 10 | +Detailed instructions for question 1: |
| 11 | +The nearest neighbor classifier predicts for a point X_i the target y_k of |
| 12 | +the training sample X_k which is the closest to X_i. We measure proximity with |
| 13 | +the Euclidean distance. The model will be evaluated with the accuracy |
| 14 | +(fraction of samples correctly classified). You need to implement the `fit`, |
| 15 | +`predict` and `score` methods for this class. The code you write should pass |
| 16 | +the tests we implemented. You can run the tests by calling |
| 17 | +`pytest test_sklearn_questions.py` at the root of the repo. Note that to |
| 18 | +be fully valid, a scikit-learn estimator needs to check that the inputs |
| 19 | +given to `fit` and `predict` are correct using the `check_*` functions |
| 20 | +imported in the file. |
| 20 | +You can find more information on how they should be used in the following doc: |
| 21 | +https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator. |
| 22 | +Make sure to use them to pass `test_nearest_neighbor_check_estimator`. |
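| | + |
| | +As an illustration only (toy data; the exact values are indicative, not |
| | +part of the tests), the estimator should be usable like any scikit-learn |
| | +classifier:: |
| | + |
| | +    import numpy as np |
| | +    X = np.array([[0.], [1.], [10.]]) |
| | +    y = np.array([0, 0, 1]) |
| | +    knn = KNearestNeighbors(n_neighbors=1).fit(X, y) |
| | +    knn.predict(np.array([[9.]]))  # should give array([1]) |
| | +    knn.score(X, y)  # should give 1.0 |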
| 23 | +
| 24 | +
| 25 | +Detailed instructions for question 2: |
| 26 | +The data to split should have an index or one column in datetime format. |
| 27 | +The aim is then to split the data into train and test sets so that, for |
| 28 | +each pair of successive months, we learn on the first month and predict |
| 29 | +on the following one. For example, if your data range from November 2020 |
| 30 | +to March 2021, you have 4 splits. The first split learns on November data |
| 31 | +and predicts on December data, the second split learns on December and |
| 32 | +predicts on January, etc. |
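| | + |
| | +As an illustration only (hypothetical toy data), the splitter should be |
| | +usable as follows:: |
| | + |
| | +    dates = pd.date_range("2020-11-01", "2021-03-31", freq="D") |
| | +    X = pd.DataFrame({"val": range(len(dates))}, index=dates) |
| | +    cv = MonthlySplit()  # splits on the DatetimeIndex |
| | +    cv.get_n_splits(X)  # should give 4 (Nov->Dec, ..., Feb->Mar) |
| | +    for idx_train, idx_test in cv.split(X): |
| | +        pass  # fit on one month, evaluate on the next one |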
| 33 | +
| 34 | +We also ask you to respect the PEP8 convention: https://pep8.org. This will |
| 35 | +be enforced with `flake8`. You can check that there are no flake8 errors by |
| 36 | +calling `flake8` at the root of the repo. |
| 37 | +
| 38 | +Finally, you need to write docstrings for the methods you code and for the |
| 39 | +classes. The docstrings will be checked using `pydocstyle`, which you can |
| 40 | +also run at the root of the repo. |
| 41 | +
| 42 | +Hints |
| 43 | +----- |
| 44 | +- You can use the function: |
| 45 | +
| 46 | +from sklearn.metrics.pairwise import pairwise_distances |
| 47 | +
| 48 | +to compute distances between 2 sets of samples. |
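| | + |
| | +For instance, a sketch with hypothetical arrays X_test and X_train:: |
| | + |
| | +    D = pairwise_distances(X_test, X_train)  # shape (n_test, n_train) |
| | +    # D[i, j]: Euclidean distance between X_test[i] and X_train[j] |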
| 49 | +""" |
3 | 50 | import numpy as np |
4 | 51 | import pandas as pd |
5 | 52 |
6 | | -import pandas.api.types as pdtypes |
7 | | - |
8 | 53 | from sklearn.base import BaseEstimator |
9 | 54 | from sklearn.base import ClassifierMixin |
10 | 55 |
11 | 56 | from sklearn.model_selection import BaseCrossValidator |
12 | | -from sklearn.utils.multiclass import unique_labels |
13 | | -from sklearn.utils.validation import validate_data, check_is_fitted |
14 | 57 |
15 | | -from collections import Counter |
| 58 | +from sklearn.utils.validation import (check_X_y, check_is_fitted, |
| 59 | + validate_data) |
| 60 | +from sklearn.utils.multiclass import check_classification_targets |
| 61 | +from sklearn.metrics.pairwise import pairwise_distances |
16 | 62 |
17 | 63 |
18 | 64 | class KNearestNeighbors(ClassifierMixin, BaseEstimator): |
| 65 | + """KNearestNeighbors classifier. |
19 | 66 |
20 | | - """KNearestNeighbors classifier.""" |
| 67 | + This class implements a K-Nearest Neighbors classifier for classification |
| 68 | + tasks. The classifier predicts the label of a test point based on the |
| 69 | + majority class of its nearest neighbors in the training dataset. |
21 | 70 |
22 | | - def __init__(self, num_neighbors=1): # noqa: D107 |
23 | | - self.num_neighbors = num_neighbors |
| 71 | + Parameters |
| 72 | + ---------- |
| 73 | + n_neighbors : int, default=1 |
| 74 | + Number of neighbors to use for classification. |
| 75 | + """ |
24 | 76 |
25 | | - def fit(self, features, labels): |
| 77 | +    def __init__(self, n_neighbors=1): |
| 78 | +        """Initialize the classifier with the given number of neighbors.""" |
| 79 | + self.n_neighbors = n_neighbors |
| 80 | + |
| 81 | + def fit(self, X, y): |
26 | 82 | """Fitting function. |
27 | 83 |
| 84 | + This method stores the training data and labels for later use |
| 85 | + during prediction. |
| 86 | +
28 | 87 | Parameters |
29 | 88 | ---------- |
30 | | - features : ndarray, shape (n_samples, n_features) |
| 89 | + X : ndarray, shape (n_samples, n_features) |
31 | 90 | Data to train the model. |
32 | | - labels : ndarray, shape (n_samples,) |
| 91 | + y : ndarray, shape (n_samples,) |
33 | 92 | Labels associated with the training data. |
34 | 93 |
35 | 94 | Returns |
36 | 95 | ---------- |
37 | 96 | self : instance of KNearestNeighbors |
38 | | - The current instance of the classifier |
| 97 | +            The fitted instance of the classifier. |
39 | 98 | """ |
40 | | - |
41 | | - (features, labels) = validate_data(self, features, labels) |
42 | | - self.classes_ = unique_labels(labels) |
43 | | - self.training_features_ = features |
44 | | - self.training_labels_ = labels |
| 99 | +        X, y = check_X_y(X, y) |
| 100 | +        check_classification_targets(y) |
| 101 | +        # KNN is a lazy learner: fitting only validates and stores the |
| 102 | +        # training set, which is then queried at prediction time. |
| 103 | +        self.X_train_ = X |
| 104 | +        self.y_train_ = y |
| 105 | +        self.classes_ = np.unique(y) |
| 106 | +        self.n_features_in_ = X.shape[1] |
45 | 105 | return self |
46 | 106 |
47 | | - def predict(self, features): |
| 107 | + def predict(self, X): |
48 | 108 | """Predict function. |
49 | 109 |
50 | 110 | Parameters |
51 | 111 | ---------- |
52 | | - features : ndarray, shape (n_test_samples, n_features) |
| 112 | + X : ndarray, shape (n_test_samples, n_features) |
53 | 113 | Data to predict on. |
54 | 114 |
55 | 115 | Returns |
56 | 116 | ---------- |
57 | | - predictions : ndarray, shape (n_test_samples,) |
| 117 | +        y_pred : ndarray, shape (n_test_samples,) |
58 | 118 | Predicted class labels for each test data sample. |
59 | 119 | """ |
60 | | - |
61 | 120 | check_is_fitted(self) |
62 | | - features = validate_data(self, features, reset=False) |
63 | | - |
64 | | - predictions = np.full(features.shape[0], self.training_labels_[0]) |
65 | | - for idx in range(features.shape[0]): |
66 | | - feature = features[idx] |
67 | | - neighbor_labels = [] |
68 | | - |
69 | | - distances = np.sum( |
70 | | - (self.training_features_ - feature) ** 2, axis=1 |
71 | | - ) |
72 | | - nearest_indices = np.argpartition( |
73 | | - distances, self.num_neighbors |
74 | | - )[: self.num_neighbors] |
75 | | - for neighbor_idx in nearest_indices: |
76 | | - neighbor_labels += [self.training_labels_[neighbor_idx]] |
77 | | - |
78 | | - predictions[idx] = Counter(neighbor_labels).most_common(1)[0][0] |
79 | | - return predictions |
80 | | - |
81 | | - def score(self, features, labels): |
| 121 | + X = validate_data(self, X, reset=False) |
| 122 | + distances = pairwise_distances(X, self.X_train_) |
| 123 | + nearest_neighbors = np.argsort(distances, axis=1)[:, :self.n_neighbors] |
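| | +        # Majority vote among the k nearest neighbors: encode labels as |
| | +        # integer codes, count codes per row with np.bincount and keep |
| | +        # the most frequent one for each test sample. |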
| 124 | + unique_classes, y_indices = np.unique(self.y_train_, |
| 125 | + return_inverse=True) |
| 126 | + neighbor_labels = y_indices[nearest_neighbors] |
| 127 | + y_pred = np.array([unique_classes[np.bincount(labels).argmax()] |
| 128 | + for labels in neighbor_labels]) |
| 129 | + return y_pred |
| 130 | + |
| 131 | + def score(self, X, y): |
82 | 132 | """Calculate the score of the prediction. |
83 | 133 |
84 | 134 | Parameters |
85 | 135 | ---------- |
86 | | - features : ndarray, shape (n_samples, n_features) |
| 136 | + X : ndarray, shape (n_samples, n_features) |
87 | 137 | Data to score on. |
88 | | - labels : ndarray, shape (n_samples,) |
89 | | - Target values. |
| 138 | + y : ndarray, shape (n_samples,) |
| 139 | +            Target values. |
90 | 140 |
91 | 141 | Returns |
92 | 142 | ---------- |
93 | | - accuracy : float |
94 | | - Accuracy of the model computed for the (features, labels) pairs. |
| 143 | + score : float |
| 144 | +            Accuracy of the model, computed as the fraction of |
| 145 | +            correctly predicted labels on the (X, y) pairs. |
95 | 146 | """ |
96 | | - |
97 | | - predictions = self.predict(features) |
98 | | - correct_predictions = 0 |
99 | | - for idx in range(features.shape[0]): |
100 | | - if labels[idx] == predictions[idx]: |
101 | | - correct_predictions += 1 |
102 | | - return correct_predictions / features.shape[0] |
| 147 | + y_pred = self.predict(X) |
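| | +        # Accuracy: fraction of predictions matching the true labels |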
| 148 | + return np.mean(y_pred == y) |
103 | 149 |
104 | 150 |
105 | 151 | class MonthlySplit(BaseCrossValidator): |
106 | | - |
107 | 152 | """CrossValidator based on monthly split. |
108 | 153 |
109 | | - Split data based on the given `time_column` (or default to index). |
110 | | - Each split corresponds to one month of data for the training |
111 | | - and the next month of data for the test. |
| 154 | + Split data based on the given `time_col` (or default to index). Each split |
| 155 | + corresponds to one month of data for the training and the next month of |
| 156 | + data for the test. |
112 | 157 |
113 | 158 | Parameters |
114 | 159 | ---------- |
115 | | - time_column : str, defaults to 'index' |
| 160 | + time_col : str, defaults to 'index' |
116 | 161 | Column of the input DataFrame that will be used to split the data. This |
117 | 162 | column should be of type datetime. If split is called with a DataFrame |
118 | 163 | for which this column is not a datetime, it will raise a ValueError. |
119 | | - To use the index as column just set `time_column` to `'index'`. |
| 164 | + To use the index as column just set `time_col` to `'index'`. |
120 | 165 | """ |
121 | 166 |
122 | | - def __init__(self, time_column="index"): # noqa: D107 |
123 | | - self.time_column = time_column |
| 167 | + def __init__(self, time_col='index'): # noqa: D107 |
| 168 | + self.time_col = time_col |
124 | 169 |
125 | | - def get_n_splits( |
126 | | - self, |
127 | | - data, |
128 | | - labels=None, |
129 | | - groups=None, |
130 | | - ): |
| 170 | + def get_n_splits(self, X, y=None, groups=None): |
131 | 171 | """Return the number of splitting iterations in the cross-validator. |
132 | 172 |
133 | 173 | Parameters |
134 | 174 | ---------- |
135 | | - data : array-like of shape (n_samples, n_features) |
| 175 | + X : array-like of shape (n_samples, n_features) |
136 | 176 | Training data, where `n_samples` is the number of samples |
137 | 177 | and `n_features` is the number of features. |
138 | | - labels : array-like of shape (n_samples,) |
| 178 | + y : array-like of shape (n_samples,) |
139 | 179 | Always ignored, exists for compatibility. |
140 | 180 | groups : array-like of shape (n_samples,) |
141 | 181 | Always ignored, exists for compatibility. |
142 | 182 |
143 | 183 | Returns |
144 | 184 | ------- |
145 | | - num_splits : int |
146 | | - The number of splits. |
| 185 | + n_splits : int |
| 186 | + The number of splits based on unique months in the data. |
147 | 187 | """ |
148 | | - |
149 | | - if self.time_column == "index": |
150 | | - if not isinstance(data.index, pd.DatetimeIndex): |
151 | | - raise ValueError("datetime") |
152 | | - sorted_data = data.sort_index() |
153 | | - months = sorted_data.index.month |
| 188 | + if isinstance(X, pd.Series): |
| 189 | + times = X.index |
| 190 | + elif isinstance(X, pd.DataFrame): |
| 191 | + times = X.index if self.time_col == 'index' else X[self.time_col] |
154 | 192 | else: |
| 193 | + raise ValueError("X should be a pandas DataFrame or Series.") |
155 | 194 |
156 | | - if not pdtypes.is_datetime64_dtype(data[self.time_column]): |
157 | | - raise ValueError("datetime") |
158 | | - sorted_data = data.sort_values(by=self.time_column) |
159 | | - sorted_data.index = sorted_data[self.time_column] |
160 | | - months = sorted_data.index.month |
161 | | - |
162 | | - num_splits = 0 |
163 | | - for idx in range(1, len(months)): |
164 | | - if months[idx] != months[idx - 1]: |
165 | | - num_splits += 1 |
166 | | - return num_splits |
167 | | - |
168 | | - def split( |
169 | | - self, |
170 | | - data, |
171 | | - labels, |
172 | | - groups=None, |
173 | | - ): |
| 195 | + if not pd.api.types.is_datetime64_any_dtype(times): |
| 196 | + raise ValueError("time_col must be a datetime column.") |
| 197 | + periods = pd.Series(times).dt.to_period("M") |
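| | +        # One split per pair of consecutive months: n_months - 1 splits |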
| 198 | + return len(periods.unique()) - 1 |
| 199 | + |
| 200 | +    def split(self, X, y=None, groups=None): |
174 | 201 | """Generate indices to split data into training and test set. |
175 | 202 |
176 | 203 | Parameters |
177 | 204 | ---------- |
178 | | - data : array-like of shape (n_samples, n_features) |
| 205 | + X : array-like of shape (n_samples, n_features) |
179 | 206 | Training data, where `n_samples` is the number of samples |
180 | 207 | and `n_features` is the number of features. |
181 | | - labels : array-like of shape (n_samples,) |
| 208 | + y : array-like of shape (n_samples,) |
182 | 209 | Always ignored, exists for compatibility. |
183 | 210 | groups : array-like of shape (n_samples,) |
184 | 211 | Always ignored, exists for compatibility. |
185 | 212 |
186 | 213 | Yields |
187 | 214 | ------ |
188 | | - train_indices : ndarray |
| 215 | + idx_train : ndarray |
189 | 216 | The training set indices for that split. |
190 | | - test_indices : ndarray |
| 217 | + idx_test : ndarray |
191 | 218 | The testing set indices for that split. |
192 | 219 | """ |
| 220 | + # Determine time column |
| 221 | + if isinstance(X, pd.DataFrame): |
| 222 | + if self.time_col == 'index': |
| 223 | + times = X.index |
| 224 | + else: |
| 225 | + times = X[self.time_col] |
| 226 | + elif isinstance(X, pd.Series): |
| 227 | + times = X.index |
| 228 | + else: |
| 229 | + raise ValueError("X should be a pandas DataFrame or Series.") |
| 230 | + |
| 231 | + # Ensure time column is datetime |
| 232 | + if not pd.api.types.is_datetime64_any_dtype(times): |
| 233 | + raise ValueError("time_col must be a datetime column.") |
193 | 234 |
194 | | - num_splits = self.get_n_splits(data, labels, groups) |
195 | 238 |
196 | | -        if self.time_column == "index": |
197 | | -            months_list = [sorted(data.index)[0]] |
198 | | -        else: |
200 | | -            months_list = [sorted(data["date"])[0]] |
202 | | -        for _ in range(num_splits): |
203 | | -            months_list += [months_list[-1] + pd.DateOffset(months=1)] |
| 239 | +        # Month period of every sample, kept in the original row order |
| 240 | +        # so that `np.where` below yields positional indices into X. |
| 241 | +        periods = pd.Series(times).dt.to_period("M") |
204 | 250 |
205 | | -        for split_idx in range(num_splits): |
206 | | -            train_month = months_list[split_idx] |
207 | | -            test_month = months_list[split_idx + 1] |
208 | | -            train_indices = [] |
209 | | -            test_indices = [] |
211 | | -            for data_idx in range(len(data)): |
212 | | -                if self.time_column == "index": |
213 | | -                    current_date = data.index[data_idx] |
214 | | -                else: |
215 | | -                    current_date = data.iloc[data_idx]["date"] |
| 255 | +        # Unique months present in the data, in chronological order |
| 256 | +        unique_periods = sorted(periods.unique()) |
216 | 258 |
217 | | - if ( |
218 | | - current_date.month == train_month.month |
219 | | - and current_date.year == train_month.year |
220 | | - ): |
221 | | - train_indices.append(data_idx) |
222 | | - elif ( |
223 | | - current_date.month == test_month.month |
224 | | - and current_date.year == test_month.year |
225 | | - ): |
| 259 | +        n_splits = self.get_n_splits(X, y, groups) |
226 | 260 |
227 | | - test_indices.append(data_idx) |
| 261 | + for i in range(n_splits): |
| 262 | + idx_train = np.where(periods == unique_periods[i])[0] |
| 263 | + idx_test = np.where(periods == unique_periods[i + 1])[0] |
228 | 264 |
229 | | - yield (train_indices, test_indices) |
| 265 | + yield idx_train, idx_test |