Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/periodic_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ jobs:
- name: Disable Numba JIT
run: echo "NUMBA_DISABLE_JIT=1" >> $GITHUB_ENV

- name: Install aeon and dependencies
- name: Install
uses: nick-fields/retry@v3
with:
timeout_minutes: 30
Expand All @@ -101,7 +101,7 @@ jobs:
run: python -m pip list

- name: Run tests
run: python -m pytest -n logical --cov=aeon --cov-report=xml --timeout 1800
run: python -m pytest -n logical --cov=tsml --cov-report=xml --timeout 1800

- uses: codecov/codecov-action@v5
env:
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
args: [ "--create", "--python-folders", "tsml" ]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.10
rev: v0.12.12
hooks:
- id: ruff
args: [ "--fix" ]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ all_extras = [
"grailts",
"scikit-fda>=0.7.0; python_version > '3.9' and python_version < '3.13'",
"statsmodels>=0.12.1",
"wildboar",
"wildboar<=1.2.0",
]
unstable_extras = [
"mrsqm>=0.0.7; platform_system == 'Linux' and python_version < '3.12'", # requires gcc and fftw to be installed for Windows and some other OS (see http://www.fftw.org/index.html)
Expand Down
3 changes: 3 additions & 0 deletions tsml/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def _check_n_features(self, X: np.ndarray | list[np.ndarray], reset: bool):
def _more_tags(self) -> dict:
return _DEFAULT_TAGS

def _get_tags(self) -> dict:
return _safe_tags(self)

@classmethod
def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
"""Return unit test parameter settings for the estimator.
Expand Down
8 changes: 8 additions & 0 deletions tsml/compose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
__all__ = [
"ChannelEnsembleClassifier",
"ChannelEnsembleRegressor",
"SklearnToTsmlClassifier",
"SklearnToTsmlClusterer",
"SklearnToTsmlRegressor",
]

from tsml.compose._channel_ensemble import (
ChannelEnsembleClassifier,
ChannelEnsembleRegressor,
)
from tsml.compose._sklearn_to_tsml import (
SklearnToTsmlClassifier,
SklearnToTsmlClusterer,
SklearnToTsmlRegressor,
)
288 changes: 288 additions & 0 deletions tsml/compose/_sklearn_to_tsml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
"""A tsml wrapper for sklearn classifiers."""

__maintainer__ = ["MatthewMiddlehurst"]
__all__ = [
"SklearnToTsmlClassifier",
"SklearnToTsmlClusterer",
"SklearnToTsmlRegressor",
]

import numpy as np
from aeon.base._base import _clone_estimator
from sklearn.base import ClassifierMixin, ClusterMixin, RegressorMixin
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_is_fitted

from tsml.base import BaseTimeSeriesEstimator


class SklearnToTsmlClassifier(ClassifierMixin, BaseTimeSeriesEstimator):
"""Wrapper for sklearn estimators to use the tsml base class."""

def __init__(
self,
classifier=None,
pad_unequal=False,
concatenate_channels=False,
clone_estimator=True,
random_state=None,
):
self.classifier = classifier
self.pad_unequal = pad_unequal
self.concatenate_channels = concatenate_channels
self.clone_estimator = clone_estimator
self.random_state = random_state

super().__init__()

def fit(self, X, y):
"""Wrap fit."""
if self.classifier is None:
raise ValueError("Classifier not set")

X, y = self._validate_data(
X=X,
y=y,
ensure_univariate=not self.concatenate_channels,
ensure_equal_length=not self.pad_unequal,
)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

check_classification_targets(y)
self.classes_ = np.unique(y)

self._classifier = (
_clone_estimator(self.classifier, self.random_state)
if self.clone_estimator
else self.classifier
)
self._classifier.fit(X, y)

return self

def predict(self, X) -> np.ndarray:
"""Wrap predict."""
check_is_fitted(self)

X = self._validate_data(X=X, reset=False)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

return self._classifier.predict(X)

def predict_proba(self, X) -> np.ndarray:
"""Wrap predict_proba."""
check_is_fitted(self)

X = self._validate_data(X=X, reset=False)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

return self._classifier.predict_proba(X)

def _more_tags(self):
return {
"X_types": ["2darray"],
"equal_length_only": (False if self.pad_unequal else True),
"univariate_only": False if self.concatenate_channels else True,
}

@classmethod
def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
"""Return unit test parameter settings for the estimator.

Parameters
----------
parameter_set : None or str, default=None
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.

Returns
-------
params : dict or list of dict
Parameters to create testing instances of the class.
"""
from sklearn.ensemble import RandomForestClassifier

return {"classifier": RandomForestClassifier(n_estimators=5)}


class SklearnToTsmlClusterer(ClusterMixin, BaseTimeSeriesEstimator):
"""Wrapper for sklearn estimators to use the tsml base class."""

def __init__(
self,
clusterer=None,
pad_unequal=False,
concatenate_channels=False,
clone_estimator=True,
random_state=None,
):
self.clusterer = clusterer
self.pad_unequal = pad_unequal
self.concatenate_channels = concatenate_channels
self.clone_estimator = clone_estimator
self.random_state = random_state

super().__init__()

def fit(self, X, y=None):
"""Wrap fit."""
if self.clusterer is None:
raise ValueError("Clusterer not set")

X = self._validate_data(
X=X,
ensure_univariate=not self.concatenate_channels,
ensure_equal_length=not self.pad_unequal,
)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

self._clusterer = (
_clone_estimator(self.clusterer, self.random_state)
if self.clone_estimator
else self.clusterer
)
self._clusterer.fit(X, y)

self.labels_ = self._clusterer.labels_

return self

def predict(self, X) -> np.ndarray:
"""Wrap predict."""
check_is_fitted(self)

X = self._validate_data(X=X, reset=False)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

return self._clusterer.predict(X)

def _more_tags(self):
return {
"X_types": ["2darray"],
"equal_length_only": (False if self.pad_unequal else True),
"univariate_only": False if self.concatenate_channels else True,
}

@classmethod
def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
"""Return unit test parameter settings for the estimator.

Parameters
----------
parameter_set : None or str, default=None
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.

Returns
-------
params : dict or list of dict
Parameters to create testing instances of the class.
"""
from sklearn.cluster import KMeans

return {"clusterer": KMeans(n_clusters=2, max_iter=5)}


class SklearnToTsmlRegressor(RegressorMixin, BaseTimeSeriesEstimator):
"""Wrapper for sklearn estimators to use the tsml base class."""

def __init__(
self,
regressor=None,
pad_unequal=False,
concatenate_channels=False,
clone_estimator=True,
random_state=None,
):
self.regressor = regressor
self.pad_unequal = pad_unequal
self.concatenate_channels = concatenate_channels
self.clone_estimator = clone_estimator
self.random_state = random_state

super().__init__()

def fit(self, X, y):
"""Wrap fit."""
if self.regressor is None:
raise ValueError("Regressor not set")

X, y = self._validate_data(
X=X,
y=y,
ensure_univariate=not self.concatenate_channels,
ensure_equal_length=not self.pad_unequal,
)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

self._regressor = (
_clone_estimator(self.regressor, self.random_state)
if self.clone_estimator
else self.regressor
)
self._regressor.fit(X, y)

return self

def predict(self, X) -> np.ndarray:
"""Wrap predict."""
check_is_fitted(self)

X = self._validate_data(X=X, reset=False)
X = self._convert_X(
X,
pad_unequal=self.pad_unequal,
concatenate_channels=self.concatenate_channels,
)

return self._regressor.predict(X)

def _more_tags(self):
return {
"X_types": ["2darray"],
"equal_length_only": (False if self.pad_unequal else True),
"univariate_only": False if self.concatenate_channels else True,
}

@classmethod
def get_test_params(cls, parameter_set: str | None = None) -> dict | list[dict]:
"""Return unit test parameter settings for the estimator.

Parameters
----------
parameter_set : None or str, default=None
Name of the set of test parameters to return, for use in tests. If no
special parameters are defined for a value, will return `"default"` set.

Returns
-------
params : dict or list of dict
Parameters to create testing instances of the class.
"""
from sklearn.ensemble import RandomForestRegressor

return {"regressor": RandomForestRegressor(n_estimators=5)}
3 changes: 1 addition & 2 deletions tsml/tests/test_estimators_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import scale
from sklearn.utils._tags import _safe_tags as _safe_tags_sklearn
from sklearn.utils._testing import (
SkipTest,
assert_allclose,
Expand Down Expand Up @@ -1410,7 +1409,7 @@ def check_estimator_get_tags_default_keys(name, estimator_orig):
if not hasattr(estimator, "_get_tags"):
return

default_tags_keys = set(_safe_tags_sklearn(estimator).keys())
default_tags_keys = set(_safe_tags(estimator).keys())
tags_keys = set(estimator._get_tags().keys())
assert tags_keys.intersection(default_tags_keys) == default_tags_keys, (
f"{name}._get_tags() is missing entries for the following default tags: "
Expand Down
2 changes: 1 addition & 1 deletion tsml/tests/test_sklearn_compatability.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Unit tests for aeon classifier compatability with sklearn interfaces."""
"""Unit tests for tsml classifier compatability with sklearn interfaces."""

__maintainer__ = []
__all__ = [
Expand Down