Skip to content

Commit f984a01

Browse files
authored
Merge pull request #2 from clmrie/clmrie
UP my solution
2 parents 775ea14 + fb2223d commit f984a01

File tree

2 files changed

+164
-128
lines changed

2 files changed

+164
-128
lines changed

.DS_Store

6 KB
Binary file not shown.

sklearn_questions.py

Lines changed: 164 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,229 +1,265 @@
1-
#!/usr/bin/python
2-
# -*- coding: utf-8 -*-
1+
"""Assignment - making a sklearn estimator and cv splitter.
2+
3+
The goal of this assignment is to implement by yourself:
4+
5+
- a scikit-learn estimator for the KNearestNeighbors for classification
6+
tasks and check that it is working properly.
7+
- a scikit-learn CV splitter where the splits are based on a Pandas
8+
DateTimeIndex.
9+
10+
Detailed instructions for question 1:
11+
The nearest neighbor classifier predicts for a point X_i the target y_k of
12+
the training sample X_k which is the closest to X_i. We measure proximity with
13+
the Euclidean distance. The model will be evaluated with the accuracy (average
14+
number of samples correctly classified). You need to implement the `fit`,
15+
`predict` and `score` methods for this class. The code you write should pass
16+
the test we implemented. You can run the tests by calling at the root of the
17+
repo `pytest test_sklearn_questions.py`. Note that to be fully valid, a
18+
scikit-learn estimator needs to check that the input given to `fit` and
19+
`predict` are correct using the `check_*` functions imported in the file.
20+
You can find more information on how they should be used in the following doc:
21+
https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator.
22+
Make sure to use them to pass `test_nearest_neighbor_check_estimator`.
23+
24+
25+
Detailed instructions for question 2:
26+
The data to split should contain the index or one column in
27+
datetime format. Then the aim is to split the data between train and test
28+
sets such that, for each pair of successive months, we learn on the first and
29+
predict on the following. For example, if you have data distributed from
30+
November 2020 to March 2021, you have 4 splits. The first split
31+
allows learning on November data and predicting on December data, the
32+
second split learning on December data and predicting on January data, etc.
33+
34+
We also ask you to respect the pep8 convention: https://pep8.org. This will be
35+
enforced with `flake8`. You can check that there are no flake8 errors by
36+
calling `flake8` at the root of the repo.
37+
38+
Finally, you need to write docstrings for the methods you code and for the
39+
class. The docstring will be checked using `pydocstyle` that you can also
40+
call at the root of the repo.
41+
42+
Hints
43+
-----
44+
- You can use the function:
45+
46+
from sklearn.metrics.pairwise import pairwise_distances
47+
48+
to compute distances between 2 sets of samples.
49+
"""
350
import numpy as np
451
import pandas as pd
552

6-
import pandas.api.types as pdtypes
7-
853
from sklearn.base import BaseEstimator
954
from sklearn.base import ClassifierMixin
1055

1156
from sklearn.model_selection import BaseCrossValidator
12-
from sklearn.utils.multiclass import unique_labels
13-
from sklearn.utils.validation import validate_data, check_is_fitted
1457

15-
from collections import Counter
58+
from sklearn.utils.validation import (check_X_y, check_is_fitted,
59+
validate_data)
60+
from sklearn.utils.multiclass import check_classification_targets
61+
from sklearn.metrics.pairwise import pairwise_distances
1662

1763

1864
class KNearestNeighbors(ClassifierMixin, BaseEstimator):
    """K-nearest-neighbors classifier based on the Euclidean distance.

    A test sample is assigned the most frequent label among its
    ``n_neighbors`` closest training samples. When several labels are
    equally frequent, the smallest label (in sorted order) is returned.

    Parameters
    ----------
    n_neighbors : int, default=1
        Number of neighbors voting for each prediction. Must be >= 1.
        If it exceeds the number of training samples, every training
        sample votes.

    Attributes
    ----------
    X_train_ : ndarray of shape (n_samples, n_features)
        Training data stored by `fit`.
    y_train_ : ndarray of shape (n_samples,)
        Training labels stored by `fit`.
    classes_ : ndarray of shape (n_classes,)
        Sorted unique labels seen during `fit`.
    """

    def __init__(self, n_neighbors=1):
        """Store the hyper-parameter; validation is deferred to `fit`."""
        self.n_neighbors = n_neighbors

    def fit(self, X, y):
        """Store the training set used to answer later predictions.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to train the model.
        y : ndarray, shape (n_samples,)
            Labels associated with the training data.

        Returns
        -------
        self : instance of KNearestNeighbors
            The fitted instance of the classifier.

        Raises
        ------
        ValueError
            If `n_neighbors` is not an integer >= 1, or if `X` / `y` do not
            pass the scikit-learn input and classification-target checks.
        """
        n_neighbors = self.n_neighbors
        # A zero or negative value would otherwise produce an empty or
        # silently truncated neighbor slice at prediction time.
        if (not isinstance(n_neighbors, (int, np.integer))
                or n_neighbors < 1):
            raise ValueError(
                f"n_neighbors must be an integer >= 1, got {n_neighbors!r}."
            )

        # validate_data records the input metadata that the matching
        # `validate_data(..., reset=False)` call in `predict` checks.
        X, y = validate_data(self, X, y)
        check_classification_targets(y)

        self.X_train_ = X
        self.y_train_ = y
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        """Predict the label of each sample by majority vote.

        Parameters
        ----------
        X : ndarray, shape (n_test_samples, n_features)
            Data to predict on.

        Returns
        -------
        y : ndarray, shape (n_test_samples,)
            Predicted class labels for each test data sample.
        """
        check_is_fitted(self)
        X = validate_data(self, X, reset=False)

        distances = pairwise_distances(X, self.X_train_)
        # Stable sort: equidistant training samples are taken in their
        # training order, which makes predictions deterministic.
        order = np.argsort(distances, axis=1, kind="stable")
        nearest = order[:, :self.n_neighbors]

        # Encode labels as integers so that np.bincount can count votes
        # whatever the label dtype is.
        classes, codes = np.unique(self.y_train_, return_inverse=True)
        neighbor_codes = codes[nearest]
        winners = np.array(
            [np.bincount(row, minlength=len(classes)).argmax()
             for row in neighbor_codes],
            dtype=int,
        )
        return classes[winners]

    def score(self, X, y):
        """Compute the accuracy of the predictions on `X`.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Data to score on.
        y : ndarray, shape (n_samples,)
            Target values.

        Returns
        -------
        score : float
            Fraction of samples whose predicted label equals `y`.
        """
        y_pred = self.predict(X)
        return np.mean(y_pred == np.asarray(y))
103149

104150

105151
class MonthlySplit(BaseCrossValidator):
    """CrossValidator based on monthly split.

    Split data based on the given `time_col` (or default to index). Each split
    uses one month of data for training and the next month present in the
    data for testing.

    Parameters
    ----------
    time_col : str, defaults to 'index'
        Column of the input DataFrame that will be used to split the data. This
        column should be of type datetime. If split is called with a DataFrame
        for which this column is not a datetime, it will raise a ValueError.
        To use the index as column just set `time_col` to `'index'`.
    """

    def __init__(self, time_col='index'):  # noqa: D107
        self.time_col = time_col

    def _get_periods(self, X):
        """Return the monthly period of every row of `X`, in row order.

        Parameters
        ----------
        X : pandas DataFrame or Series
            Data holding the datetime information, either in its index or,
            for a DataFrame, in the column named `time_col`.

        Returns
        -------
        periods : pandas.PeriodIndex
            One monthly period per row, aligned with the rows of `X`.

        Raises
        ------
        ValueError
            If `X` is not a DataFrame or Series, or if the selected
            index/column is not of datetime type.
        """
        if isinstance(X, pd.DataFrame) and self.time_col != 'index':
            times = X[self.time_col]
        elif isinstance(X, (pd.DataFrame, pd.Series)):
            # A Series has no columns, so its index is always used.
            times = X.index
        else:
            raise ValueError("X should be a pandas DataFrame or Series.")

        if not pd.api.types.is_datetime64_any_dtype(times):
            raise ValueError("time_col must be a datetime column.")
        return pd.DatetimeIndex(times).to_period("M")

    def get_n_splits(self, X, y=None, groups=None):
        """Return the number of splitting iterations in the cross-validator.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Returns
        -------
        n_splits : int
            Number of distinct months in the data minus one (never negative).
        """
        months = self._get_periods(X).unique().dropna()
        # With zero or one month there is no (train, test) pair to form.
        return max(len(months) - 1, 0)

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.
        groups : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        Yields
        ------
        idx_train : ndarray
            The training set indices for that split.
        idx_test : ndarray
            The testing set indices for that split.
        """
        # Periods stay in the row order of X, so the positions found below
        # are valid for `X.iloc` even when X is not sorted by time.
        periods = self._get_periods(X)
        months = periods.unique().dropna().sort_values()

        for train_month, test_month in zip(months[:-1], months[1:]):
            idx_train = np.flatnonzero(periods == train_month)
            idx_test = np.flatnonzero(periods == test_month)
            yield idx_train, idx_test

0 commit comments

Comments
 (0)