diff --git a/skore-hub-project/pyproject.toml b/skore-hub-project/pyproject.toml
index f6be51e8d..42586c471 100644
--- a/skore-hub-project/pyproject.toml
+++ b/skore-hub-project/pyproject.toml
@@ -12,7 +12,7 @@ dependencies = [
     "numpy",
     "orjson",
     "pydantic",
-    "rich",
+    "rich>=14.2.0",
     "scikit-learn",
 ]
 
diff --git a/skore-hub-project/src/skore_hub_project/artifact/artifact.py b/skore-hub-project/src/skore_hub_project/artifact/artifact.py
index 3d3ea3ac2..9f7f1ac38 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/artifact.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/artifact.py
@@ -1,15 +1,20 @@
 """Interface definition of the payload used to associate an artifact with a project."""
 
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractContextManager, nullcontext
-from functools import cached_property
+from typing import ClassVar
 
 from pydantic import BaseModel, ConfigDict, Field, computed_field
 
 from skore_hub_project import Project
-from skore_hub_project.artifact.upload import upload
+from skore_hub_project.artifact.serializer import Serializer, TxtSerializer
+from skore_hub_project.artifact.upload import upload as upload_content
+from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
 
-Content = str | bytes | None
+Content = EstimatorReport | CrossValidationReport | str | bytes | None
 
 
 class Artifact(BaseModel, ABC):
@@ -29,47 +34,54 @@ class Artifact(BaseModel, ABC):
     as a file to the ``hub`` artifacts storage.
     """
 
-    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     project: Project = Field(repr=False, exclude=True)
+    serializer_cls: ClassVar[type[Serializer]] = TxtSerializer
    content_type: str = Field(init=False)
 
+    @property
     @abstractmethod
     def content_to_upload(self) -> Content | AbstractContextManager[Content]:
         """
         Content of the artifact to upload.
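+
+        Example
+        -------
+        A sketch: implement this ``abstractmethod`` as a ``property`` that
+        returns the content directly:
+
+            @property
+            def content_to_upload(self) -> bytes:
+                return b""
+
+        or that yields it, as a ``contextmanager`` would:
+
+            from contextlib import contextmanager
+
+            @property
+            @contextmanager
+            def content_to_upload(self) -> Generator[bytes, None, None]:
+                yield b""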
 
-        Example
-        -------
-        You can implement this ``abstractmethod`` to return directly the content:
-
-            def content_to_upload(self) -> str:
-                return ""
-
-        or to yield the content, as a ``contextmanager`` would:
-
-            from contextlib import contextmanager
-
-            @contextmanager
-            def content_to_upload(self) -> Generator[str, None, None]:
-                yield ""
         """
 
     @computed_field  # type: ignore[prop-decorator]
-    @cached_property
+    @property
+    @abstractmethod
     def checksum(self) -> str | None:
         """Checksum used to identify the content of the artifact."""
+
+    def upload(
+        self,
+        *,
+        pool: ThreadPoolExecutor | None = None,
+        checksums_being_uploaded: set[str] | None = None,
+    ) -> None:
+        """
+        Upload the content of the artifact, if any, to the artifacts storage.
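+
+        Example
+        -------
+        A sketch of uploading several artifacts with a shared pool, skipping
+        contents already being uploaded (``artifacts`` is illustrative):
+
+            with ThreadPoolExecutor(max_workers=6) as pool:
+                checksums_being_uploaded: set[str] = set()
+
+                for artifact in artifacts:
+                    artifact.upload(
+                        pool=pool,
+                        checksums_being_uploaded=checksums_being_uploaded,
+                    )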
+        """
-        contextmanager = self.content_to_upload()
+        contextmanager = self.content_to_upload
+        checksums_being_uploaded = (
+            set() if checksums_being_uploaded is None else checksums_being_uploaded
+        )
 
         if not isinstance(contextmanager, AbstractContextManager):
             contextmanager = nullcontext(contextmanager)
 
         with contextmanager as content:
-            if content is not None:
-                return upload(
-                    project=self.project,
-                    content=content,
-                    content_type=self.content_type,
-                )
-
-            return None
+            if content is None:
+                return
+
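+            # Serialize the content on disk, then upload it, unless an artifact
+            # with the same checksum is already being uploaded by a sibling task.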
+            with self.serializer_cls(content) as serializer:
+                if serializer.checksum not in checksums_being_uploaded:
+                    checksums_being_uploaded.add(serializer.checksum)
+                    upload_content(
+                        project=self.project,
+                        serializer=serializer,
+                        content=content,
+                        content_type=self.content_type,
+                        pool=(
+                            ThreadPoolExecutor(max_workers=6) if pool is None else pool
+                        ),
+                    )
diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/data.py b/skore-hub-project/src/skore_hub_project/artifact/media/data.py
index 11b5b14e0..5cfc2887f 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/media/data.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/media/data.py
@@ -1,5 +1,6 @@
 """Definition of the payload used to associate a data category media with report."""
 
+from functools import cached_property
 from typing import Literal
 
 from skore_hub_project import switch_mpl_backend
@@ -14,6 +15,7 @@ class TableReport(Media[Report]):  # noqa: D101
         "application/vnd.skrub.table-report.v1+json"
     )
 
+    @cached_property
     def content_to_upload(self) -> bytes:  # noqa: D102
         import orjson
diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py b/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py
index 80792ddc8..c61057c52 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/media/feature_importance.py
@@ -20,6 +20,7 @@ class FeatureImportance(Media[Report], ABC):  # noqa: D101
     accessor: ClassVar[str]
     content_type: Literal["application/vnd.dataframe"] = "application/vnd.dataframe"
 
+    @property
     def content_to_upload(self) -> bytes | None:  # noqa: D102
         try:
             function = cast(
@@ -44,8 +45,9 @@ class Permutation(FeatureImportance[EstimatorReport], ABC):  # noqa: D101
     accessor: ClassVar[str] = "feature_importance.permutation"
     name: Literal["permutation"] = "permutation"
 
+    @property
     def content_to_upload(self) -> bytes | None:  # noqa: D102
-        for key, obj in reversed(self.report._cache.items()):
+        for key, obj in reversed(list(self.report._cache.items())):
             if len(key) < 7:
                 continue
diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/media.py b/skore-hub-project/src/skore_hub_project/artifact/media/media.py
index 59764f81d..eea8e1443 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/media/media.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/media/media.py
@@ -3,8 +3,11 @@
 from abc import ABC
+from functools import cached_property
 from typing import Generic, TypeVar
 
-from pydantic import Field
+from blake3 import blake3 as Blake3
+from pydantic import Field, computed_field
 
+from skore_hub_project import bytes_to_b64_str
 from skore_hub_project.artifact.artifact import Artifact
 from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
 
@@ -32,3 +35,28 @@ class Media(Artifact, ABC, Generic[Report]):
     report: Report = Field(repr=False, exclude=True)
     name: str = Field(init=False)
     data_source: str | None = Field(init=False)
+
+    @computed_field  # type: ignore[prop-decorator]
+    @cached_property
+    def checksum(self) -> str | None:
+        """
+        Checksum used to identify the content of the media.
+
+        Notes
+        -----
+        Depending on the size of the serialized content, the checksum is computed
+        on one or several threads. As the ``blake3`` documentation notes,
+        multithreading "can be slower for inputs shorter than ~1 MB":
+
+            https://github.com/oconnor663/blake3-py
+        """
+        if (content := self.content_to_upload) is None:
+            return None
+
+        # Compute the checksum with the appropriate number of threads, evaluating
+        # the content only once.
+        threads = 1 if len(content) < 1e6 else Blake3.AUTO
+        hasher = Blake3(max_threads=threads)
+        checksum = hasher.update(content).digest()
+
+        return f"blake3-{bytes_to_b64_str(checksum)}"
diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/model.py b/skore-hub-project/src/skore_hub_project/artifact/media/model.py
index 34e60ecf0..ffbea9994 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/media/model.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/media/model.py
@@ -10,11 +10,12 @@ class EstimatorHtmlRepr(Media[Report]):  # noqa: D101
     data_source: None = None
     content_type: Literal["text/html"] = "text/html"
 
-    def content_to_upload(self) -> str:  # noqa: D102
+    @property
+    def content_to_upload(self) -> bytes:  # noqa: D102
         import sklearn.utils
 
         estimator_html_repr: str = sklearn.utils.estimator_html_repr(
             self.report.estimator
         )
 
-        return estimator_html_repr
+        return estimator_html_repr.encode(encoding="utf-8")
diff --git a/skore-hub-project/src/skore_hub_project/artifact/media/performance.py b/skore-hub-project/src/skore_hub_project/artifact/media/performance.py
index 7dd14686c..22f209cd5 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/media/performance.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/media/performance.py
@@ -19,6 +19,7 @@ class Performance(Media[Report], ABC):  # noqa: D101
     accessor: ClassVar[str]
     content_type: Literal["image/svg+xml"] = "image/svg+xml"
 
+    @property
     def content_to_upload(self) -> bytes | None:  # noqa: D102
         try:
             function = cast(
diff --git a/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py b/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py
index c53b86ce0..b7031871e 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py
@@ -2,13 +2,12 @@
 
 from collections.abc import Generator
 from contextlib import contextmanager
-from io import BytesIO
-from typing import Literal
+from typing import ClassVar, Literal
 
-from joblib import dump
-from pydantic import Field
+from pydantic import Field, computed_field
 
 from skore_hub_project.artifact.artifact import Artifact
+from skore_hub_project.artifact.serializer import ReportSerializer, Serializer
 from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
 
 Report = EstimatorReport | CrossValidationReport
@@ -35,11 +34,12 @@ class Pickle(Artifact):
     The report is primarily pickled on disk to reduce RAM footprint.
     """
 
-    report: Report = Field(repr=False, exclude=True)
+    serializer_cls: ClassVar[type[Serializer]] = ReportSerializer
     content_type: Literal["application/octet-stream"] = "application/octet-stream"
+    report: Report = Field(repr=False, exclude=True)
 
+    @computed_field  # type: ignore[prop-decorator]
+    @property
+    def checksum(self) -> str:
+        """Checksum of the report, aligned with ``ReportSerializer.checksum``."""
+        return f"skore-{self.report.__class__.__name__}-{self.report._hash}"
+
+    @property
     @contextmanager
-    def content_to_upload(self) -> Generator[bytes, None, None]:
+    def content_to_upload(self) -> Generator[Report, None, None]:
         """
         Content of the pickled report.
 
@@ -56,9 +56,7 @@ def content_to_upload(self) -> Generator[bytes, None, None]:
-            with BytesIO() as stream:
-                dump(self.report, stream)
-                pickle_bytes = stream.getvalue()
-
-                yield pickle_bytes
+            yield self.report
         finally:
             for report, cache in zip(reports, caches, strict=True):
                 report._cache = cache
diff --git a/skore-hub-project/src/skore_hub_project/artifact/serializer.py b/skore-hub-project/src/skore_hub_project/artifact/serializer.py
index 3b011ae4e..cbb85723c 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/serializer.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/serializer.py
@@ -2,22 +2,26 @@
 
 from __future__ import annotations
 
+from abc import ABC, abstractmethod
 from functools import cached_property
+from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from typing import Any
 
 from blake3 import blake3 as Blake3
+from joblib import dump
+
+from skore_hub_project.protocol import CrossValidationReport, EstimatorReport
 
-class Serializer:
-    """Serialize a content directly on disk to reduce RAM footprint."""
 
-    def __init__(self, content: str | bytes):
-        if isinstance(content, str):
-            self.filepath.write_text(content, encoding="utf-8")
-        else:
-            self.filepath.write_bytes(content)
+class Serializer(ABC):
+    """
+    Abstract class to serialize content on disk, to reduce RAM footprint.
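+
+    Example
+    -------
+    A sketch of the expected lifecycle (``TxtSerializer`` shown):
+
+        with TxtSerializer(b"content") as serializer:
+            serializer()            # serialize the content on disk
+            serializer.checksum     # "blake3-..."
+            serializer.size         # 7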
+    """
+
+    called: bool = False
+
+    @abstractmethod
+    def __init__(self, content: Any, /): ...
 
     def __enter__(self) -> Serializer:  # noqa: D105
         return self
@@ -25,6 +29,10 @@ def __enter__(self) -> Serializer:  # noqa: D105
     def __exit__(self, *args: Any) -> None:  # noqa: D105
         self.filepath.unlink(True)
 
+    @abstractmethod
+    def __call__(self) -> None:
+        """Serialize the content on disk."""
+
     @cached_property
     def filepath(self) -> Path:
         """The filepath used to serialize the content."""
@@ -47,6 +55,13 @@ def checksum(self) -> str:
         """
         from skore_hub_project import bytes_to_b64_str
 
+        if not self.called:
+            raise RuntimeError(
+                "You cannot access the checksum of a serializer without explicitly "
+                "calling it. Please use `serializer()` first."
+            )
+
+        # Compute the checksum with the appropriate number of threads.
         hasher = Blake3(max_threads=(1 if self.size < 1e6 else Blake3.AUTO))
         checksum = hasher.update_mmap(self.filepath).digest()
 
@@ -55,4 +70,47 @@ def checksum(self) -> str:
     @cached_property
     def size(self) -> int:
         """The size of the serialized content, in bytes."""
+        if not self.called:
+            raise RuntimeError(
+                "You cannot access the size of a serializer without explicitly "
+                "calling it. Please use `serializer()` first."
+            )
+
         return self.filepath.stat().st_size
+
+
+class TxtSerializer(Serializer):
+    """Serialize a str or bytes on disk."""
+
+    def __init__(self, txt: str | bytes, /):
+        if isinstance(txt, str):
+            txt = txt.encode(encoding="utf-8")
+
+        self.filepath.write_bytes(txt)
+        self.called = True
+
+    def __call__(self) -> None:
+        """No-op: the content is already serialized at construction time."""
+
+
+class ReportSerializer(Serializer):
+    """Serialize a report on disk, using joblib."""
+
+    def __init__(self, report: CrossValidationReport | EstimatorReport, /):
+        self.report = report
+
+    def __call__(self) -> None:
+        """Serialize the report on disk, using joblib."""
+        if self.called:
+            return
+
+        with BytesIO() as stream:
+            dump(self.report, stream)
+
+            self.filepath.write_bytes(stream.getvalue())
+            self.called = True
+
+    @cached_property
+    def checksum(self) -> str:
+        """The checksum of the report, available without serializing it."""
+        return f"skore-{self.report.__class__.__name__}-{self.report._hash}"
diff --git a/skore-hub-project/src/skore_hub_project/artifact/upload.py b/skore-hub-project/src/skore_hub_project/artifact/upload.py
index 2f7fc7ba4..9275bd4f8 100644
--- a/skore-hub-project/src/skore_hub_project/artifact/upload.py
+++ b/skore-hub-project/src/skore_hub_project/artifact/upload.py
@@ -3,41 +3,24 @@
 from __future__ import annotations
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from functools import partial
 from math import ceil
 from pathlib import Path
 from typing import TYPE_CHECKING
 
-from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn
-
-from ..client.client import Client, HUBClient
-from .serializer import Serializer
+from skore_hub_project.client.client import Client, HUBClient
 
 if TYPE_CHECKING:
-    from typing import Final
-
-    import httpx
-
-    from ..project.project import Project
+    from typing import Any, Final
 
+    from httpx import Client as httpx_Client
 
-SkinnedProgress = partial(
-    Progress,
-    TextColumn("[bold cyan blink]Uploading..."),
-    BarColumn(
-        complete_style="dark_orange",
-        finished_style="dark_orange",
-        pulse_style="orange1",
-    ),
-    TextColumn("[orange1]{task.percentage:>3.0f}%"),
-    TimeElapsedColumn(),
-    transient=True,
-)
+    from skore_hub_project.artifact.serializer import Serializer
+    from skore_hub_project.project.project import Project
 
 
 def upload_chunk(
     filepath: Path,
-    client: httpx.Client,
+    client: httpx_Client,
     url: str,
     offset: int,
     length: int,
@@ -89,7 +72,13 @@ def upload_chunk(
 CHUNK_SIZE: Final[int] = int(1e7)  # ~10mb
 
 
-def upload(project: Project, content: str | bytes, content_type: str) -> str:
+def upload(
+    project: Project,
+    serializer: Serializer,
+    content: Any,
+    content_type: str,
+    pool: ThreadPoolExecutor,
+) -> None:
     """
     Upload content to the artifacts storage.
 
     Parameters
     ----------
     project : ``Project``
         The project where to upload the content.
-    content : str | bytes
+    serializer : Serializer
+        The serializer to use for the content serialization.
+    content : Any
         The content to upload.
     content_type : str
         The type of content to upload.
-
-    Returns
-    -------
-    checksum : str
-        The checksum of the content before upload to the artifacts storage, based on its
-        serialization.
+    pool : ThreadPoolExecutor
+        The pool used to execute the `upload_chunk` threads.
 
     Notes
     -----
     A content that was already uploaded in its whole will be ignored.
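+
+    Example
+    -------
+    A sketch of the chunking arithmetic: a content serialized to ~25mb is
+    split into ``ceil(25e6 / CHUNK_SIZE) == 3`` chunks, and e.g. chunk 3
+    reads ``CHUNK_SIZE`` bytes from ``offset == (3 - 1) * CHUNK_SIZE``.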
""" - with ( - Serializer(content) as serializer, - HUBClient() as hub_client, - Client() as standard_client, - ThreadPoolExecutor() as pool, - ): + with HUBClient() as hub_client, Client() as standard_client: + # Ask for the artifact. + # + # An non-empty response means that an artifact with the same checksum already + # exists. The content doesn't have to be re-uploaded. + response = hub_client.get( + url=f"projects/{project.quoted_tenant}/{project.quoted_name}/artifacts", + params={"artifact_checksum": serializer.checksum, "status": "uploaded"}, + ) + + if response.json(): + return None + + serializer() + # Ask for upload urls. response = hub_client.post( url=f"projects/{project.quoted_tenant}/{project.quoted_name}/artifacts", @@ -130,64 +126,52 @@ def upload(project: Project, content: str | bytes, content_type: str) -> str: ], ) - # An empty response means that an artifact with the same checksum already - # exists. The content doesn't have to be re-uploaded. - if urls := response.json(): - task_to_chunk_id = {} - - # Upload each chunk of the serialized content to the artifacts storage, - # using a disk temporary file. - # - # Each task is in charge of reading its own file chunk at runtime, to reduce - # RAM footprint. - # - # Use `threading` over `asyncio` to ensure compatibility with Jupyter - # notebooks, where the event loop is already running. - for url in urls: - chunk_id = url["chunk_id"] or 1 - task = pool.submit( - upload_chunk, - filepath=serializer.filepath, - client=standard_client, - url=url["upload_url"], - offset=((chunk_id - 1) * CHUNK_SIZE), - length=CHUNK_SIZE, - content_type=( - content_type if len(urls) == 1 else "application/octet-stream" - ), - ) - - task_to_chunk_id[task] = chunk_id - - try: - with SkinnedProgress() as progress: - tasks = as_completed(task_to_chunk_id) - total = len(task_to_chunk_id) - etags = dict( - sorted( - ( - task_to_chunk_id[task], - task.result(), - ) - for task in progress.track(tasks, total=total) - ) - ) - except BaseException: - # Cancel all remaining tasks, especially on `KeyboardInterrupt`. - for task in task_to_chunk_id: - task.cancel() - - raise - - # Acknowledge the upload, to let the hub/storage rebuild the whole. - hub_client.post( - url=f"projects/{project.quoted_tenant}/{project.quoted_name}/artifacts/complete", - json=[ - { - "checksum": serializer.checksum, - "etags": etags, - } - ], + urls = response.json() + task_to_chunk_id = {} + + # Upload each chunk of the serialized content to the artifacts storage, + # using a disk temporary file. + # + # Each task is in charge of reading its own file chunk at runtime, to reduce + # RAM footprint. + # + # Use `threading` over `asyncio` to ensure compatibility with Jupyter + # notebooks, where the event loop is already running. + for url in urls: + chunk_id = url["chunk_id"] or 1 + task = pool.submit( + upload_chunk, + filepath=serializer.filepath, + client=standard_client, + url=url["upload_url"], + offset=((chunk_id - 1) * CHUNK_SIZE), + length=CHUNK_SIZE, + content_type=( + content_type if len(urls) == 1 else "application/octet-stream" + ), + ) + + task_to_chunk_id[task] = chunk_id + + try: + tasks = as_completed(task_to_chunk_id) + etags = dict( + sorted((task_to_chunk_id[task], task.result()) for task in tasks) ) + except BaseException: + # Cancel all remaining tasks, especially on `KeyboardInterrupt`. + for task in task_to_chunk_id: + task.cancel() + + raise - return serializer.checksum + # Acknowledge the upload, to let the hub/storage rebuild the whole. 
+        hub_client.post(
+            url=f"projects/{project.quoted_tenant}/{project.quoted_name}/artifacts/complete",
+            json=[
+                {
+                    "checksum": serializer.checksum,
+                    "etags": etags,
+                }
+            ],
+        )
diff --git a/skore-hub-project/src/skore_hub_project/metric/metric.py b/skore-hub-project/src/skore_hub_project/metric/metric.py
index a90200771..296914d66 100644
--- a/skore-hub-project/src/skore_hub_project/metric/metric.py
+++ b/skore-hub-project/src/skore_hub_project/metric/metric.py
@@ -5,7 +5,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from contextlib import suppress
-from functools import cached_property, reduce
+from functools import reduce
 from math import isfinite
 from typing import (
     TYPE_CHECKING,
@@ -59,7 +59,7 @@ class Metric(BaseModel, ABC, Generic[Report]):
         default None to disable its display.
     """
 
-    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
     report: Report = Field(repr=False, exclude=True)
     name: str = Field(init=False)
@@ -70,9 +70,26 @@ class Metric(BaseModel, ABC, Generic[Report]):
 
     @computed_field  # type: ignore[prop-decorator]
     @property
-    @abstractmethod
     def value(self) -> float | None:
         """The value of the metric."""
+        try:
+            return self.__value
+        except AttributeError:
+            message = (
+                "You cannot access the value of a metric "
+                "without explicitly computing it. "
+                "Please use `metric.compute()` first."
+            )
+
+            raise RuntimeError(message) from None
+
+    @value.setter
+    def value(self, value: float | None) -> None:
+        self.__value = value
+
+    @abstractmethod
+    def compute(self) -> None:
+        """Compute the value of the metric."""
 
 
 class EstimatorReportMetric(Metric[EstimatorReport]):
@@ -100,19 +117,17 @@ class EstimatorReportMetric(Metric[EstimatorReport]):
 
     accessor: ClassVar[str]
 
-    @computed_field  # type: ignore[prop-decorator]
-    @cached_property
-    def value(self) -> float | None:
-        """The value of the metric."""
+    def compute(self) -> None:
+        """Compute the value of the metric."""
         try:
             function = cast(
                 Callable[..., float | None],
                 reduce(getattr, self.accessor.split("."), self.report),
             )
         except AttributeError:
-            return None
-
-        return cast_to_float(function(data_source=self.data_source))
+            self.value = None
+        else:
+            self.value = cast_to_float(function(data_source=self.data_source))
 
 
 class CrossValidationReportMetric(Metric[CrossValidationReport]):
@@ -143,18 +158,15 @@ class CrossValidationReportMetric(Metric[CrossValidationReport]):
 
     accessor: ClassVar[str]
     aggregate: ClassVar[Literal["mean", "std"]]
 
-    @computed_field  # type: ignore[prop-decorator]
-    @cached_property
-    def value(self) -> float | None:
-        """The value of the metric."""
+    def compute(self) -> None:
+        """Compute the value of the metric."""
         try:
             function = cast(
                 "Callable[..., DataFrame]",
                 reduce(getattr, self.accessor.split("."), self.report),
             )
         except AttributeError:
-            return None
-
-        dataframe = function(data_source=self.data_source, aggregate=self.aggregate)
-
-        return cast_to_float(dataframe.iloc[0, 0])
+            self.value = None
+        else:
+            dataframe = function(
+                data_source=self.data_source, aggregate=self.aggregate
+            )
+            self.value = cast_to_float(dataframe.iloc[0, 0])
diff --git a/skore-hub-project/src/skore_hub_project/metric/precision.py b/skore-hub-project/src/skore_hub_project/metric/precision.py
index ddee0d784..54272adbf 100644
--- a/skore-hub-project/src/skore_hub_project/metric/precision.py
+++ b/skore-hub-project/src/skore_hub_project/metric/precision.py
@@ -2,11 +2,8 @@
 
 from __future__ import annotations
 
-from functools import cached_property from typing import ClassVar, Literal -from pydantic import computed_field - from .metric import CrossValidationReportMetric, EstimatorReportMetric, cast_to_float @@ -17,15 +14,16 @@ class Precision(EstimatorReportMetric): # noqa: D101 greater_is_better: bool = True position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.precision except AttributeError: - return None - - return cast_to_float(function(data_source=self.data_source, average="macro")) + self.value = None + else: + self.value = cast_to_float( + function(data_source=self.data_source, average="macro") + ) class PrecisionTrain(Precision): # noqa: D101 @@ -43,19 +41,17 @@ class PrecisionMean(CrossValidationReportMetric): # noqa: D101 greater_is_better: bool = True position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.precision except AttributeError: - return None - - dataframe = function( - data_source=self.data_source, aggregate="mean", average="macro" - ) - - return cast_to_float(dataframe.iloc[0, 0]) + self.value = None + else: + dataframe = function( + data_source=self.data_source, aggregate="mean", average="macro" + ) + self.value = cast_to_float(dataframe.iloc[0, 0]) class PrecisionTrainMean(PrecisionMean): # noqa: D101 @@ -73,19 +69,17 @@ class PrecisionStd(CrossValidationReportMetric): # noqa: D101 greater_is_better: bool = False position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.precision except AttributeError: - return None - - dataframe = function( - data_source=self.data_source, aggregate="std", average="macro" - ) - - return cast_to_float(dataframe.iloc[0, 0]) + self.value = None + else: + dataframe = function( + data_source=self.data_source, aggregate="std", average="macro" + ) + self.value = cast_to_float(dataframe.iloc[0, 0]) class PrecisionTrainStd(PrecisionStd): # noqa: D101 diff --git a/skore-hub-project/src/skore_hub_project/metric/recall.py b/skore-hub-project/src/skore_hub_project/metric/recall.py index fea638a39..caa62fc92 100644 --- a/skore-hub-project/src/skore_hub_project/metric/recall.py +++ b/skore-hub-project/src/skore_hub_project/metric/recall.py @@ -2,11 +2,8 @@ from __future__ import annotations -from functools import cached_property from typing import ClassVar, Literal -from pydantic import computed_field - from .metric import CrossValidationReportMetric, EstimatorReportMetric, cast_to_float @@ -17,15 +14,16 @@ class Recall(EstimatorReportMetric): # noqa: D101 greater_is_better: bool = True position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.recall except AttributeError: - return None - - return cast_to_float(function(data_source=self.data_source, average="macro")) + self.value = None + else: + self.value = cast_to_float( + function(data_source=self.data_source, average="macro") + ) class RecallTrain(Recall): # 
noqa: D101 @@ -44,19 +42,17 @@ class RecallMean(CrossValidationReportMetric): # noqa: D101 greater_is_better: bool = True position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.recall except AttributeError: - return None - - dataframe = function( - data_source=self.data_source, aggregate=self.aggregate, average="macro" - ) - - return cast_to_float(dataframe.iloc[0, 0]) + self.value = None + else: + dataframe = function( + data_source=self.data_source, aggregate=self.aggregate, average="macro" + ) + self.value = cast_to_float(dataframe.iloc[0, 0]) class RecallTrainMean(RecallMean): # noqa: D101 @@ -75,19 +71,17 @@ class RecallStd(CrossValidationReportMetric): # noqa: D101 greater_is_better: bool = False position: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" try: function = self.report.metrics.recall except AttributeError: - return None - - dataframe = function( - data_source=self.data_source, aggregate=self.aggregate, average="macro" - ) - - return cast_to_float(dataframe.iloc[0, 0]) + self.value = None + else: + dataframe = function( + data_source=self.data_source, aggregate=self.aggregate, average="macro" + ) + self.value = cast_to_float(dataframe.iloc[0, 0]) class RecallTrainStd(RecallStd): # noqa: D101 diff --git a/skore-hub-project/src/skore_hub_project/metric/timing.py b/skore-hub-project/src/skore_hub_project/metric/timing.py index 795ec2df7..a7baff580 100644 --- a/skore-hub-project/src/skore_hub_project/metric/timing.py +++ b/skore-hub-project/src/skore_hub_project/metric/timing.py @@ -2,11 +2,8 @@ from __future__ import annotations -from functools import cached_property from typing import ClassVar, Literal -from pydantic import computed_field - from skore_hub_project.protocol import CrossValidationReport, EstimatorReport from .metric import Metric, cast_to_float @@ -19,13 +16,12 @@ class FitTime(Metric[EstimatorReport]): # noqa: D101 position: int = 1 data_source: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" timings = self.report.metrics.timings() fit_time = timings.get("fit_time") - return cast_to_float(fit_time) + self.value = cast_to_float(fit_time) class FitTimeAggregate(Metric[CrossValidationReport]): # noqa: D101 @@ -40,17 +36,16 @@ class FitTimeAggregate(Metric[CrossValidationReport]): # noqa: D101 greater_is_better: bool = False data_source: None = None - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" timings = self.report.metrics.timings(aggregate=self.aggregate) try: fit_times = timings.loc["Fit time (s)"] except KeyError: - return None - - return cast_to_float(fit_times.iloc[0]) + self.value = None + else: + self.value = cast_to_float(fit_times.iloc[0]) class FitTimeMean(FitTimeAggregate): # noqa: D101 @@ -73,13 +68,12 @@ class PredictTime(Metric[EstimatorReport]): # noqa: D101 greater_is_better: bool = False position: int = 2 - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def 
compute(self) -> None: + """Compute the value of the metric.""" timings = self.report.metrics.timings() predict_time = timings.get(f"predict_time_{self.data_source}") - return cast_to_float(predict_time) + self.value = cast_to_float(predict_time) class PredictTimeTrain(PredictTime): # noqa: D101 @@ -101,17 +95,16 @@ class PredictTimeAggregate(Metric[CrossValidationReport]): # noqa: D101 aggregate: ClassVar[Literal["mean", "std"]] greater_is_better: bool = False - @computed_field # type: ignore[prop-decorator] - @cached_property - def value(self) -> float | None: # noqa: D102 + def compute(self) -> None: + """Compute the value of the metric.""" timings = self.report.metrics.timings(aggregate=self.aggregate) try: predict_times = timings.loc[f"Predict time {self.data_source} (s)"] except KeyError: - return None - - return cast_to_float(predict_times.iloc[0]) + self.value = None + else: + self.value = cast_to_float(predict_times.iloc[0]) class PredictTimeMean(PredictTimeAggregate): # noqa: D101 diff --git a/skore-hub-project/src/skore_hub_project/project/project.py b/skore-hub-project/src/skore_hub_project/project/project.py index 93d57d88d..a98eed3dd 100644 --- a/skore-hub-project/src/skore_hub_project/project/project.py +++ b/skore-hub-project/src/skore_hub_project/project/project.py @@ -22,6 +22,7 @@ import joblib import orjson from httpx import HTTPStatusError +from rich.progress import Progress, TextColumn, TimeElapsedColumn from skore_hub_project.client.client import Client, HUBClient from skore_hub_project.protocol import CrossValidationReport, EstimatorReport @@ -200,19 +201,34 @@ def put(self, key: str, report: EstimatorReport | CrossValidationReport) -> None f"`skore.CrossValidationReport` (found '{type(report)}')" ) - payload_dict = payload.model_dump() - payload_json_bytes = orjson.dumps(payload_dict, option=orjson.OPT_NON_STR_KEYS) + progress = Progress( + TextColumn("{task.description}"), + TimeElapsedColumn(), + transient=True, + ) - with HUBClient() as hub_client: - hub_client.post( - url=f"projects/{self.quoted_tenant}/{self.quoted_name}/{endpoint}", - content=payload_json_bytes, - headers={ - "Content-Length": str(len(payload_json_bytes)), - "Content-Type": "application/json", - }, + with progress: + task = progress.add_task( + description=f"[bold red1]Putting [bright_white on red1 blink]{key}" + ) + + payload_dict = payload.model_dump() + payload_json_bytes = orjson.dumps( + payload_dict, option=orjson.OPT_NON_STR_KEYS ) + with HUBClient() as hub_client: + hub_client.post( + url=f"projects/{self.quoted_tenant}/{self.quoted_name}/{endpoint}", + content=payload_json_bytes, + headers={ + "Content-Length": str(len(payload_json_bytes)), + "Content-Type": "application/json", + }, + ) + + progress.update(task, completed=1) + @ensure_project_is_created def get(self, urn: str) -> EstimatorReport | CrossValidationReport: """Get a persisted report by its URN.""" diff --git a/skore-hub-project/src/skore_hub_project/report/report.py b/skore-hub-project/src/skore_hub_project/report/report.py index 1ee1deba0..0246ac1de 100644 --- a/skore-hub-project/src/skore_hub_project/report/report.py +++ b/skore-hub-project/src/skore_hub_project/report/report.py @@ -3,17 +3,34 @@ from __future__ import annotations from abc import ABC -from functools import cached_property +from collections import deque +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import cached_property, partial +from threading import RLock from typing import ClassVar, Generic, TypeVar, cast from 
pydantic import BaseModel, ConfigDict, Field, computed_field +from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn -from skore_hub_project import Project +from skore_hub_project import Project, switch_mpl_backend from skore_hub_project.artifact.media.media import Media from skore_hub_project.artifact.pickle import Pickle from skore_hub_project.metric.metric import Metric from skore_hub_project.protocol import CrossValidationReport, EstimatorReport +SkinnedProgress = partial( + Progress, + TextColumn("[bold cyan]{task.description}..."), + BarColumn( + complete_style="dark_orange", + finished_style="dark_orange", + pulse_style="orange1", + ), + TextColumn("[orange1]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + transient=True, +) + Report = TypeVar("Report", bound=(EstimatorReport | CrossValidationReport)) @@ -89,15 +106,27 @@ def metrics(self) -> list[Metric[Report]]: - int [0, inf[, to be displayed at the position, - None, not to be displayed. """ - payloads = [] - - for metric_cls in self.METRICS: - payload = metric_cls(report=self.report) - - if payload.value is not None: - payloads.append(payload) - - return payloads + metrics = [metric_cls(report=self.report) for metric_cls in self.METRICS] + + with ( + switch_mpl_backend(), + SkinnedProgress() as progress, + ThreadPoolExecutor() as pool, + ): + tasks = [ + pool.submit(lambda metric: metric.compute(), metric) + for metric in metrics + ] + + deque( + progress.track( + as_completed(tasks), + description=f"Computing {self.report.__class__.__name__} metrics", + total=len(tasks), + ) + ) + + return [metric for metric in metrics if metric.value is not None] @computed_field # type: ignore[prop-decorator] @cached_property @@ -111,15 +140,47 @@ def medias(self) -> list[Media[Report]]: ----- Unavailable medias have been filtered out. """ - payloads = [] - for media_cls in self.MEDIAS: - payload = media_cls(project=self.project, report=self.report) - - if payload.checksum is not None: - payloads.append(payload) - - return payloads + class ThreadSafeSet(set[str]): + def __init__(self) -> None: + self.__lock = RLock() + + def add(self, item: str) -> None: + with self.__lock: + super().add(item) + + checksums_being_uploaded = ThreadSafeSet() + medias = [ + media_cls(project=self.project, report=self.report) + for media_cls in self.MEDIAS + ] + + with ( + switch_mpl_backend(), + SkinnedProgress() as progress, + ThreadPoolExecutor() as compute_pool, + ThreadPoolExecutor(max_workers=6) as upload_pool, + ): + tasks = [ + compute_pool.submit( + lambda media: media.upload( + pool=upload_pool, + checksums_being_uploaded=checksums_being_uploaded, + ), + media, + ) + for media in medias + ] + + deque( + progress.track( + as_completed(tasks), + description=f"Uploading {self.report.__class__.__name__} media", + total=len(tasks), + ) + ) + + return [media for media in medias if media.checksum is not None] @computed_field # type: ignore[prop-decorator] @cached_property @@ -131,4 +192,7 @@ def pickle(self) -> Pickle: artifact storage. It is based on its ``joblib`` serialization and mainly used to retrieve it from the artifacts storage. 
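+
+        Example
+        -------
+        A sketch of the resulting checksum format, assuming the
+        ``ReportSerializer`` scheme (``payload`` is illustrative):
+
+            payload.pickle.checksum  # "skore-EstimatorReport-<hash>"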
""" - return Pickle(project=self.project, report=self.report) + pickle = Pickle(project=self.project, report=self.report) + pickle.upload() + + return pickle diff --git a/skore-hub-project/tests/conftest.py b/skore-hub-project/tests/conftest.py index 83a2b1949..436f68b69 100644 --- a/skore-hub-project/tests/conftest.py +++ b/skore-hub-project/tests/conftest.py @@ -20,11 +20,14 @@ def upload_mock(): @fixture def monkeypatch_upload_with_mock(monkeypatch, upload_mock): - monkeypatch.setattr("skore_hub_project.artifact.artifact.upload", upload_mock) + monkeypatch.setattr( + "skore_hub_project.artifact.artifact.upload_content", upload_mock + ) @fixture def monkeypatch_upload_routes(respx_mock): + respx_mock.get("projects///artifacts").mock(Response(201, json=[])) respx_mock.post("projects///artifacts").mock( Response(201, json=[{"upload_url": "http://chunk1.com/", "chunk_id": 1}]) ) diff --git a/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py b/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py index 373e6c233..65b4dd5ca 100644 --- a/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py +++ b/skore-hub-project/tests/unit/artifact/media/test_feature_importance.py @@ -10,7 +10,7 @@ PermutationTest, PermutationTrain, ) -from skore_hub_project.artifact.serializer import Serializer +from skore_hub_project.artifact.serializer import TxtSerializer def serialize(result) -> bytes: @@ -98,11 +98,16 @@ def test_feature_importance( result = function(**function_kwargs) content = serialize(result) - with Serializer(content) as serializer: + with TxtSerializer(content) as serializer: checksum = serializer.checksum # available accessor - assert Media(project=project, report=report).model_dump() == { + media = Media(project=project, report=report) + media.upload() # pool=ThreadPoolExecutor, checksums_being_uploaded={}) + + media_payload = media.model_dump() + + assert media_payload == { "content_type": "application/vnd.dataframe", "name": accessor, "data_source": data_source, @@ -114,8 +119,10 @@ def test_feature_importance( assert not upload_mock.call_args.args assert upload_mock.call_args.kwargs == { "project": project, + "serializer_cls": TxtSerializer, "content": content, "content_type": "application/vnd.dataframe", + "checksums_being_uploaded": {checksum}, } # unavailable accessor diff --git a/skore-hub-project/tests/unit/artifact/media/test_model.py b/skore-hub-project/tests/unit/artifact/media/test_model.py index c60de5c1a..c24c65126 100644 --- a/skore-hub-project/tests/unit/artifact/media/test_model.py +++ b/skore-hub-project/tests/unit/artifact/media/test_model.py @@ -19,7 +19,9 @@ def test_estimator_html_repr(respx_mock, binary_classification, upload_mock): # create media project = Project("", "") media = EstimatorHtmlRepr(project=project, report=binary_classification) - media_dict = media.model_dump() + media.upload() + + media_payload = media.model_dump() # ensure `upload` is well called assert upload_mock.called @@ -31,7 +33,7 @@ def test_estimator_html_repr(respx_mock, binary_classification, upload_mock): } # ensure payload is well constructed - assert media_dict == { + assert media_payload == { "content_type": "text/html", "name": "estimator_html_repr", "data_source": None, diff --git a/skore-hub-project/tests/unit/metric/test_accuracy.py b/skore-hub-project/tests/unit/metric/test_accuracy.py deleted file mode 100644 index eb986dd68..000000000 --- a/skore-hub-project/tests/unit/metric/test_accuracy.py +++ /dev/null @@ -1,132 +0,0 @@ 
-from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - AccuracyTest, - AccuracyTestMean, - AccuracyTestStd, - AccuracyTrain, - AccuracyTrainMean, - AccuracyTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "binary_classification", - AccuracyTrain, - "accuracy", - "Accuracy", - True, - "train", - None, - 1.0, - id="AccuracyTrain", - ), - param( - "binary_classification", - AccuracyTest, - "accuracy", - "Accuracy", - True, - "test", - None, - 0.9, - id="AccuracyTest", - ), - param( - "cv_binary_classification", - AccuracyTrainMean, - "accuracy_mean", - "Accuracy - MEAN", - True, - "train", - None, - 1.0, - id="AccuracyTrainMean", - ), - param( - "cv_binary_classification", - AccuracyTestMean, - "accuracy_mean", - "Accuracy - MEAN", - True, - "test", - None, - 0.93, - id="AccuracyTestMean", - ), - param( - "cv_binary_classification", - AccuracyTrainStd, - "accuracy_std", - "Accuracy - STD", - False, - "train", - None, - 0.0, - id="AccuracyTrainStd", - ), - param( - "cv_binary_classification", - AccuracyTestStd, - "accuracy_std", - "Accuracy - STD", - False, - "test", - None, - 0.04472135954999579, - id="AccuracyTestStd", - ), - ), -) -def test_accuracy( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "accuracy") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_brier_score.py b/skore-hub-project/tests/unit/metric/test_brier_score.py deleted file mode 100644 index 41e2e580d..000000000 --- a/skore-hub-project/tests/unit/metric/test_brier_score.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - BrierScoreTest, - BrierScoreTestMean, - BrierScoreTestStd, - BrierScoreTrain, - BrierScoreTrainMean, - BrierScoreTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "binary_classification", - BrierScoreTrain, - "brier_score", - "Brier score", - False, - "train", - None, - 0.007277500000000001, - id="BrierScoreTrain", - ), - param( - "binary_classification", - BrierScoreTest, - "brier_score", - "Brier score", - False, - "test", - None, - 0.09025999999999999, - id="BrierScoreTest", - ), - param( - "cv_binary_classification", - BrierScoreTrainMean, - "brier_score_mean", - "Brier score - MEAN", - False, - "train", - None, - 
0.008439499999999999, - id="BrierScoreTrainMean", - ), - param( - "cv_binary_classification", - BrierScoreTestMean, - "brier_score_mean", - "Brier score - MEAN", - False, - "test", - None, - 0.060865999999999996, - id="BrierScoreTestMean", - ), - param( - "cv_binary_classification", - BrierScoreTrainStd, - "brier_score_std", - "Brier score - STD", - False, - "train", - None, - 0.0004868365678027891, - id="BrierScoreTrainStd", - ), - param( - "cv_binary_classification", - BrierScoreTestStd, - "brier_score_std", - "Brier score - STD", - False, - "test", - None, - 0.015918303694175455, - id="BrierScoreTestStd", - ), - ), -) -def test_brier_score( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "brier_score") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_log_loss.py b/skore-hub-project/tests/unit/metric/test_log_loss.py deleted file mode 100644 index 3e50f1b7e..000000000 --- a/skore-hub-project/tests/unit/metric/test_log_loss.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - LogLossTest, - LogLossTestMean, - LogLossTestStd, - LogLossTrain, - LogLossTrainMean, - LogLossTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "binary_classification", - LogLossTrain, - "log_loss", - "Log loss", - False, - "train", - 4, - 0.06911280690412243, - id="LogLossTrain", - ), - param( - "binary_classification", - LogLossTest, - "log_loss", - "Log loss", - False, - "test", - 4, - 0.3168690248138036, - id="LogLossTest", - ), - param( - "cv_binary_classification", - LogLossTrainMean, - "log_loss_mean", - "Log loss - MEAN", - False, - "train", - 4, - 0.07706302057996325, - id="LogLossTrainMean", - ), - param( - "cv_binary_classification", - LogLossTestMean, - "log_loss_mean", - "Log loss - MEAN", - False, - "test", - 4, - 0.23938382759923754, - id="LogLossTestMean", - ), - param( - "cv_binary_classification", - LogLossTrainStd, - "log_loss_std", - "Log loss - STD", - False, - "train", - None, - 0.003611541148136149, - id="LogLossTrainStd", - ), - param( - "cv_binary_classification", - LogLossTestStd, - "log_loss_std", - "Log loss - STD", - False, - "test", - None, - 0.030545861791452432, - id="LogLossTestStd", - ), - ), -) -def test_log_loss( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = 
Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "log_loss") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_precision.py b/skore-hub-project/tests/unit/metric/test_precision.py deleted file mode 100644 index 6a7dec3ef..000000000 --- a/skore-hub-project/tests/unit/metric/test_precision.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - PrecisionTest, - PrecisionTestMean, - PrecisionTestStd, - PrecisionTrain, - PrecisionTrainMean, - PrecisionTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "binary_classification", - PrecisionTrain, - "precision", - "Precision (macro)", - True, - "train", - None, - 1.0, - id="PrecisionTrain", - ), - param( - "binary_classification", - PrecisionTest, - "precision", - "Precision (macro)", - True, - "test", - None, - 0.8888888888888888, - id="PrecisionTest", - ), - param( - "cv_binary_classification", - PrecisionTrainMean, - "precision_mean", - "Precision (macro) - MEAN", - True, - "train", - None, - 1.0, - id="PrecisionTrainMean", - ), - param( - "cv_binary_classification", - PrecisionTestMean, - "precision_mean", - "Precision (macro) - MEAN", - True, - "test", - None, - 0.9343434343434345, - id="PrecisionTestMean", - ), - param( - "cv_binary_classification", - PrecisionTrainStd, - "precision_std", - "Precision (macro) - STD", - False, - "train", - None, - 0.0, - id="PrecisionTrainStd", - ), - param( - "cv_binary_classification", - PrecisionTestStd, - "precision_std", - "Precision (macro) - STD", - False, - "test", - None, - 0.045173090454541195, - id="PrecisionTestStd", - ), - ), -) -def test_precision( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "precision") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_r2.py 
b/skore-hub-project/tests/unit/metric/test_r2.py deleted file mode 100644 index 4c50442ca..000000000 --- a/skore-hub-project/tests/unit/metric/test_r2.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - R2Test, - R2TestMean, - R2TestStd, - R2Train, - R2TrainMean, - R2TrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "regression", - R2Train, - "r2", - "R²", - True, - "train", - None, - 0.9997075936149707, - id="R2Train", - ), - param( - "regression", - R2Test, - "r2", - "R²", - True, - "test", - None, - 0.6757085221095596, - id="R2Test", - ), - param( - "cv_regression", - R2TrainMean, - "r2_mean", - "R² - MEAN", - True, - "train", - None, - 0.9996757990105992, - id="R2TrainMean", - ), - param( - "cv_regression", - R2TestMean, - "r2_mean", - "R² - MEAN", - True, - "test", - None, - 0.7999871077321583, - id="R2TestMean", - ), - param( - "cv_regression", - R2TrainStd, - "r2_std", - "R² - STD", - False, - "train", - None, - 4.8687384179271224e-05, - id="R2TrainStd", - ), - param( - "cv_regression", - R2TestStd, - "r2_std", - "R² - STD", - False, - "test", - None, - 0.0528129534702117, - id="R2TestStd", - ), - ), -) -def test_r2( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "r2") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_recall.py b/skore-hub-project/tests/unit/metric/test_recall.py deleted file mode 100644 index 684f16563..000000000 --- a/skore-hub-project/tests/unit/metric/test_recall.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - RecallTest, - RecallTestMean, - RecallTestStd, - RecallTrain, - RecallTrainMean, - RecallTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "binary_classification", - RecallTrain, - "recall", - "Recall (macro)", - True, - "train", - None, - 1.0, - id="RecallTrain", - ), - param( - "binary_classification", - RecallTest, - "recall", - "Recall (macro)", - True, - "test", - None, - 0.9230769230769231, - id="RecallTest", - ), - param( - "cv_binary_classification", - RecallTrainMean, - "recall_mean", - "Recall (macro) - MEAN", - True, - "train", - None, - 1.0, - id="RecallTrainMean", - ), - param( - "cv_binary_classification", - RecallTestMean, - 
"recall_mean", - "Recall (macro) - MEAN", - True, - "test", - None, - 0.93, - id="RecallTestMean", - ), - param( - "cv_binary_classification", - RecallTrainStd, - "recall_std", - "Recall (macro) - STD", - False, - "train", - None, - 0.0, - id="RecallTrainStd", - ), - param( - "cv_binary_classification", - RecallTestStd, - "recall_std", - "Recall (macro) - STD", - False, - "test", - None, - 0.04472135954999573, - id="RecallTestStd", - ), - ), -) -def test_recall( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "recall") - - assert Metric(report=report).model_dump() == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - "value": None, - } - - # wrong type - with raises( - ValidationError, - match=f"Input should be an instance of {report.__class__.__name__}", - ): - Metric(report=None) diff --git a/skore-hub-project/tests/unit/metric/test_rmse.py b/skore-hub-project/tests/unit/metric/test_rmse.py deleted file mode 100644 index df46b402c..000000000 --- a/skore-hub-project/tests/unit/metric/test_rmse.py +++ /dev/null @@ -1,132 +0,0 @@ -from __future__ import annotations - -from numpy.testing import assert_almost_equal -from pydantic import ValidationError -from pytest import mark, param, raises - -from skore_hub_project.metric import ( - RmseTest, - RmseTestMean, - RmseTestStd, - RmseTrain, - RmseTrainMean, - RmseTrainStd, -) - - -@mark.parametrize( - "report,Metric,name,verbose_name,greater_is_better,data_source,position,value", - ( - param( - "regression", - RmseTrain, - "rmse", - "RMSE", - False, - "train", - 3, - 2.4355616448506994, - id="RmseTrain", - ), - param( - "regression", - RmseTest, - "rmse", - "RMSE", - False, - "test", - 3, - 73.15561429220227, - id="RmseTest", - ), - param( - "cv_regression", - RmseTrainMean, - "rmse_mean", - "RMSE - MEAN", - False, - "train", - 3, - 2.517350366068551, - id="RmseTrainMean", - ), - param( - "cv_regression", - RmseTestMean, - "rmse_mean", - "RMSE - MEAN", - False, - "test", - 3, - 61.4227542951946, - id="RmseTestMean", - ), - param( - "cv_regression", - RmseTrainStd, - "rmse_std", - "RMSE - STD", - False, - "train", - None, - 0.19476154817956934, - id="RmseTrainStd", - ), - param( - "cv_regression", - RmseTestStd, - "rmse_std", - "RMSE - STD", - False, - "test", - None, - 12.103352184447452, - id="RmseTestStd", - ), - ), -) -def test_rmse( - monkeypatch, - report, - Metric, - name, - verbose_name, - greater_is_better, - data_source, - position, - value, - request, -): - report = request.getfixturevalue(report) - - # available accessor - metric = Metric(report=report).model_dump() - metric_value = metric.pop("value") - - assert_almost_equal(metric_value, value) - assert metric == { - "name": name, - "verbose_name": verbose_name, - "greater_is_better": greater_is_better, - "data_source": data_source, - "position": position, - } - - # unavailable accessor - monkeypatch.delattr(report.metrics.__class__, "rmse") - - assert 
-    assert Metric(report=report).model_dump() == {
-        "name": name,
-        "verbose_name": verbose_name,
-        "greater_is_better": greater_is_better,
-        "data_source": data_source,
-        "position": position,
-        "value": None,
-    }
-
-    # wrong type
-    with raises(
-        ValidationError,
-        match=f"Input should be an instance of {report.__class__.__name__}",
-    ):
-        Metric(report=None)
diff --git a/skore-hub-project/tests/unit/metric/test_roc_auc.py b/skore-hub-project/tests/unit/metric/test_roc_auc.py
deleted file mode 100644
index 57bff9c5a..000000000
--- a/skore-hub-project/tests/unit/metric/test_roc_auc.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from __future__ import annotations
-
-from numpy.testing import assert_almost_equal
-from pydantic import ValidationError
-from pytest import mark, param, raises
-
-from skore_hub_project.metric import (
-    RocAucTest,
-    RocAucTestMean,
-    RocAucTestStd,
-    RocAucTrain,
-    RocAucTrainMean,
-    RocAucTrainStd,
-)
-
-
-@mark.parametrize(
-    "report,Metric,name,verbose_name,greater_is_better,data_source,position,value",
-    (
-        param(
-            "binary_classification",
-            RocAucTrain,
-            "roc_auc",
-            "ROC AUC",
-            True,
-            "train",
-            3,
-            1.0,
-            id="RocAucTrain",
-        ),
-        param(
-            "binary_classification",
-            RocAucTest,
-            "roc_auc",
-            "ROC AUC",
-            True,
-            "test",
-            3,
-            0.989010989010989,
-            id="RocAucTest",
-        ),
-        param(
-            "cv_binary_classification",
-            RocAucTrainMean,
-            "roc_auc_mean",
-            "ROC AUC - MEAN",
-            True,
-            "train",
-            3,
-            1.0,
-            id="RocAucTrainMean",
-        ),
-        param(
-            "cv_binary_classification",
-            RocAucTestMean,
-            "roc_auc_mean",
-            "ROC AUC - MEAN",
-            True,
-            "test",
-            3,
-            0.986,
-            id="RocAucTestMean",
-        ),
-        param(
-            "cv_binary_classification",
-            RocAucTrainStd,
-            "roc_auc_std",
-            "ROC AUC - STD",
-            False,
-            "train",
-            None,
-            5.551115123125783e-17,
-            id="RocAucTrainStd",
-        ),
-        param(
-            "cv_binary_classification",
-            RocAucTestStd,
-            "roc_auc_std",
-            "ROC AUC - STD",
-            False,
-            "test",
-            None,
-            0.015165750888103078,
-            id="RocAucTestStd",
-        ),
-    ),
-)
-def test_roc_auc(
-    monkeypatch,
-    report,
-    Metric,
-    name,
-    verbose_name,
-    greater_is_better,
-    data_source,
-    position,
-    value,
-    request,
-):
-    report = request.getfixturevalue(report)
-
-    # available accessor
-    metric = Metric(report=report).model_dump()
-    metric_value = metric.pop("value")
-
-    assert_almost_equal(metric_value, value)
-    assert metric == {
-        "name": name,
-        "verbose_name": verbose_name,
-        "greater_is_better": greater_is_better,
-        "data_source": data_source,
-        "position": position,
-    }
-
-    # unavailable accessor
-    monkeypatch.delattr(report.metrics.__class__, "roc_auc")
-
-    assert Metric(report=report).model_dump() == {
-        "name": name,
-        "verbose_name": verbose_name,
-        "greater_is_better": greater_is_better,
-        "data_source": data_source,
-        "position": position,
-        "value": None,
-    }
-
-    # wrong type
-    with raises(
-        ValidationError,
-        match=f"Input should be an instance of {report.__class__.__name__}",
-    ):
-        Metric(report=None)
diff --git a/skore-hub-project/tests/unit/test_metric.py b/skore-hub-project/tests/unit/test_metric.py
new file mode 100644
index 000000000..44de27ed1
--- /dev/null
+++ b/skore-hub-project/tests/unit/test_metric.py
@@ -0,0 +1,726 @@
+from __future__ import annotations
+
+from numpy.testing import assert_almost_equal
+from pydantic import ValidationError
+from pytest import mark, param, raises
+
+from skore_hub_project.metric import (
+    AccuracyTest,
+    AccuracyTestMean,
+    AccuracyTestStd,
+    AccuracyTrain,
+    AccuracyTrainMean,
+    AccuracyTrainStd,
+    BrierScoreTest,
+    BrierScoreTestMean,
+    BrierScoreTestStd,
+    BrierScoreTrain,
+    BrierScoreTrainMean,
+    BrierScoreTrainStd,
+    LogLossTest,
+    LogLossTestMean,
+    LogLossTestStd,
+    LogLossTrain,
+    LogLossTrainMean,
+    LogLossTrainStd,
+    PrecisionTest,
+    PrecisionTestMean,
+    PrecisionTestStd,
+    PrecisionTrain,
+    PrecisionTrainMean,
+    PrecisionTrainStd,
+    R2Test,
+    R2TestMean,
+    R2TestStd,
+    R2Train,
+    R2TrainMean,
+    R2TrainStd,
+    RecallTest,
+    RecallTestMean,
+    RecallTestStd,
+    RecallTrain,
+    RecallTrainMean,
+    RecallTrainStd,
+    RmseTest,
+    RmseTestMean,
+    RmseTestStd,
+    RmseTrain,
+    RmseTrainMean,
+    RmseTrainStd,
+    RocAucTest,
+    RocAucTestMean,
+    RocAucTestStd,
+    RocAucTrain,
+    RocAucTrainMean,
+    RocAucTrainStd,
+)
+
+
+@mark.parametrize(
+    "report,Metric,accessor,name,verbose_name,greater_is_better,data_source,position,value",
+    (
+        param(
+            "binary_classification",
+            AccuracyTrain,
+            "accuracy",
+            "accuracy",
+            "Accuracy",
+            True,
+            "train",
+            None,
+            1.0,
+            id="AccuracyTrain",
+        ),
+        param(
+            "binary_classification",
+            AccuracyTest,
+            "accuracy",
+            "accuracy",
+            "Accuracy",
+            True,
+            "test",
+            None,
+            0.9,
+            id="AccuracyTest",
+        ),
+        param(
+            "cv_binary_classification",
+            AccuracyTrainMean,
+            "accuracy",
+            "accuracy_mean",
+            "Accuracy - MEAN",
+            True,
+            "train",
+            None,
+            1.0,
+            id="AccuracyTrainMean",
+        ),
+        param(
+            "cv_binary_classification",
+            AccuracyTestMean,
+            "accuracy",
+            "accuracy_mean",
+            "Accuracy - MEAN",
+            True,
+            "test",
+            None,
+            0.93,
+            id="AccuracyTestMean",
+        ),
+        param(
+            "cv_binary_classification",
+            AccuracyTrainStd,
+            "accuracy",
+            "accuracy_std",
+            "Accuracy - STD",
+            False,
+            "train",
+            None,
+            0.0,
+            id="AccuracyTrainStd",
+        ),
+        param(
+            "cv_binary_classification",
+            AccuracyTestStd,
+            "accuracy",
+            "accuracy_std",
+            "Accuracy - STD",
+            False,
+            "test",
+            None,
+            0.04472135954999579,
+            id="AccuracyTestStd",
+        ),
+        param(
+            "binary_classification",
+            BrierScoreTrain,
+            "brier_score",
+            "brier_score",
+            "Brier score",
+            False,
+            "train",
+            None,
+            0.007277500000000001,
+            id="BrierScoreTrain",
+        ),
+        param(
+            "binary_classification",
+            BrierScoreTest,
+            "brier_score",
+            "brier_score",
+            "Brier score",
+            False,
+            "test",
+            None,
+            0.09025999999999999,
+            id="BrierScoreTest",
+        ),
+        param(
+            "cv_binary_classification",
+            BrierScoreTrainMean,
+            "brier_score",
+            "brier_score_mean",
+            "Brier score - MEAN",
+            False,
+            "train",
+            None,
+            0.008439499999999999,
+            id="BrierScoreTrainMean",
+        ),
+        param(
+            "cv_binary_classification",
+            BrierScoreTestMean,
+            "brier_score",
+            "brier_score_mean",
+            "Brier score - MEAN",
+            False,
+            "test",
+            None,
+            0.060865999999999996,
+            id="BrierScoreTestMean",
+        ),
+        param(
+            "cv_binary_classification",
+            BrierScoreTrainStd,
+            "brier_score",
+            "brier_score_std",
+            "Brier score - STD",
+            False,
+            "train",
+            None,
+            0.0004868365678027891,
+            id="BrierScoreTrainStd",
+        ),
+        param(
+            "cv_binary_classification",
+            BrierScoreTestStd,
+            "brier_score",
+            "brier_score_std",
+            "Brier score - STD",
+            False,
+            "test",
+            None,
+            0.015918303694175455,
+            id="BrierScoreTestStd",
+        ),
+        param(
+            "binary_classification",
+            LogLossTrain,
+            "log_loss",
+            "log_loss",
+            "Log loss",
+            False,
+            "train",
+            4,
+            0.06911280690412243,
+            id="LogLossTrain",
+        ),
+        param(
+            "binary_classification",
+            LogLossTest,
+            "log_loss",
+            "log_loss",
+            "Log loss",
+            False,
+            "test",
+            4,
+            0.3168690248138036,
+            id="LogLossTest",
+        ),
+        param(
+            "cv_binary_classification",
+            LogLossTrainMean,
+            "log_loss",
+            "log_loss_mean",
+            "Log loss - MEAN",
+            False,
+            "train",
+            4,
+            0.07706302057996325,
+            id="LogLossTrainMean",
+        ),
+        param(
"cv_binary_classification", + LogLossTestMean, + "log_loss", + "log_loss_mean", + "Log loss - MEAN", + False, + "test", + 4, + 0.23938382759923754, + id="LogLossTestMean", + ), + param( + "cv_binary_classification", + LogLossTrainStd, + "log_loss", + "log_loss_std", + "Log loss - STD", + False, + "train", + None, + 0.003611541148136149, + id="LogLossTrainStd", + ), + param( + "cv_binary_classification", + LogLossTestStd, + "log_loss", + "log_loss_std", + "Log loss - STD", + False, + "test", + None, + 0.030545861791452432, + id="LogLossTestStd", + ), + param( + "binary_classification", + PrecisionTrain, + "precision", + "precision", + "Precision (macro)", + True, + "train", + None, + 1.0, + id="PrecisionTrain", + ), + param( + "binary_classification", + PrecisionTest, + "precision", + "precision", + "Precision (macro)", + True, + "test", + None, + 0.8888888888888888, + id="PrecisionTest", + ), + param( + "cv_binary_classification", + PrecisionTrainMean, + "precision", + "precision_mean", + "Precision (macro) - MEAN", + True, + "train", + None, + 1.0, + id="PrecisionTrainMean", + ), + param( + "cv_binary_classification", + PrecisionTestMean, + "precision", + "precision_mean", + "Precision (macro) - MEAN", + True, + "test", + None, + 0.9343434343434345, + id="PrecisionTestMean", + ), + param( + "cv_binary_classification", + PrecisionTrainStd, + "precision", + "precision_std", + "Precision (macro) - STD", + False, + "train", + None, + 0.0, + id="PrecisionTrainStd", + ), + param( + "cv_binary_classification", + PrecisionTestStd, + "precision", + "precision_std", + "Precision (macro) - STD", + False, + "test", + None, + 0.045173090454541195, + id="PrecisionTestStd", + ), + param( + "regression", + R2Train, + "r2", + "r2", + "R²", + True, + "train", + None, + 0.9997075936149707, + id="R2Train", + ), + param( + "regression", + R2Test, + "r2", + "r2", + "R²", + True, + "test", + None, + 0.6757085221095596, + id="R2Test", + ), + param( + "cv_regression", + R2TrainMean, + "r2", + "r2_mean", + "R² - MEAN", + True, + "train", + None, + 0.9996757990105992, + id="R2TrainMean", + ), + param( + "cv_regression", + R2TestMean, + "r2", + "r2_mean", + "R² - MEAN", + True, + "test", + None, + 0.7999871077321583, + id="R2TestMean", + ), + param( + "cv_regression", + R2TrainStd, + "r2", + "r2_std", + "R² - STD", + False, + "train", + None, + 4.8687384179271224e-05, + id="R2TrainStd", + ), + param( + "cv_regression", + R2TestStd, + "r2", + "r2_std", + "R² - STD", + False, + "test", + None, + 0.0528129534702117, + id="R2TestStd", + ), + param( + "binary_classification", + RecallTrain, + "recall", + "recall", + "Recall (macro)", + True, + "train", + None, + 1.0, + id="RecallTrain", + ), + param( + "binary_classification", + RecallTest, + "recall", + "recall", + "Recall (macro)", + True, + "test", + None, + 0.9230769230769231, + id="RecallTest", + ), + param( + "cv_binary_classification", + RecallTrainMean, + "recall", + "recall_mean", + "Recall (macro) - MEAN", + True, + "train", + None, + 1.0, + id="RecallTrainMean", + ), + param( + "cv_binary_classification", + RecallTestMean, + "recall", + "recall_mean", + "Recall (macro) - MEAN", + True, + "test", + None, + 0.93, + id="RecallTestMean", + ), + param( + "cv_binary_classification", + RecallTrainStd, + "recall", + "recall_std", + "Recall (macro) - STD", + False, + "train", + None, + 0.0, + id="RecallTrainStd", + ), + param( + "cv_binary_classification", + RecallTestStd, + "recall", + "recall_std", + "Recall (macro) - STD", + False, + "test", + None, + 
+            0.04472135954999573,
+            id="RecallTestStd",
+        ),
+        param(
+            "regression",
+            RmseTrain,
+            "rmse",
+            "rmse",
+            "RMSE",
+            False,
+            "train",
+            3,
+            2.4355616448506994,
+            id="RmseTrain",
+        ),
+        param(
+            "regression",
+            RmseTest,
+            "rmse",
+            "rmse",
+            "RMSE",
+            False,
+            "test",
+            3,
+            73.15561429220227,
+            id="RmseTest",
+        ),
+        param(
+            "cv_regression",
+            RmseTrainMean,
+            "rmse",
+            "rmse_mean",
+            "RMSE - MEAN",
+            False,
+            "train",
+            3,
+            2.517350366068551,
+            id="RmseTrainMean",
+        ),
+        param(
+            "cv_regression",
+            RmseTestMean,
+            "rmse",
+            "rmse_mean",
+            "RMSE - MEAN",
+            False,
+            "test",
+            3,
+            61.4227542951946,
+            id="RmseTestMean",
+        ),
+        param(
+            "cv_regression",
+            RmseTrainStd,
+            "rmse",
+            "rmse_std",
+            "RMSE - STD",
+            False,
+            "train",
+            None,
+            0.19476154817956934,
+            id="RmseTrainStd",
+        ),
+        param(
+            "cv_regression",
+            RmseTestStd,
+            "rmse",
+            "rmse_std",
+            "RMSE - STD",
+            False,
+            "test",
+            None,
+            12.103352184447452,
+            id="RmseTestStd",
+        ),
+        param(
+            "binary_classification",
+            RocAucTrain,
+            "roc_auc",
+            "roc_auc",
+            "ROC AUC",
+            True,
+            "train",
+            3,
+            1.0,
+            id="RocAucTrain",
+        ),
+        param(
+            "binary_classification",
+            RocAucTest,
+            "roc_auc",
+            "roc_auc",
+            "ROC AUC",
+            True,
+            "test",
+            3,
+            0.989010989010989,
+            id="RocAucTest",
+        ),
+        param(
+            "cv_binary_classification",
+            RocAucTrainMean,
+            "roc_auc",
+            "roc_auc_mean",
+            "ROC AUC - MEAN",
+            True,
+            "train",
+            3,
+            1.0,
+            id="RocAucTrainMean",
+        ),
+        param(
+            "cv_binary_classification",
+            RocAucTestMean,
+            "roc_auc",
+            "roc_auc_mean",
+            "ROC AUC - MEAN",
+            True,
+            "test",
+            3,
+            0.986,
+            id="RocAucTestMean",
+        ),
+        param(
+            "cv_binary_classification",
+            RocAucTrainStd,
+            "roc_auc",
+            "roc_auc_std",
+            "ROC AUC - STD",
+            False,
+            "train",
+            None,
+            5.551115123125783e-17,
+            id="RocAucTrainStd",
+        ),
+        param(
+            "cv_binary_classification",
+            RocAucTestStd,
+            "roc_auc",
+            "roc_auc_std",
+            "ROC AUC - STD",
+            False,
+            "test",
+            None,
+            0.015165750888103078,
+            id="RocAucTestStd",
+        ),
+    ),
+)
+class TestMetric:
+    def test_metric_available_accessor(
+        self,
+        monkeypatch,
+        report,
+        Metric,
+        accessor,
+        name,
+        verbose_name,
+        greater_is_better,
+        data_source,
+        position,
+        value,
+        request,
+    ):
+        report = request.getfixturevalue(report)
+
+        # available accessor
+        metric = Metric(report=report)
+        metric.compute()
+
+        metric_payload = metric.model_dump()
+
+        assert_almost_equal(metric_payload.pop("value"), value)
+        assert metric_payload == {
+            "name": name,
+            "verbose_name": verbose_name,
+            "greater_is_better": greater_is_better,
+            "data_source": data_source,
+            "position": position,
+        }
+
+    def test_metric_unavailable_accessor(
+        self,
+        monkeypatch,
+        report,
+        Metric,
+        accessor,
+        name,
+        verbose_name,
+        greater_is_better,
+        data_source,
+        position,
+        value,
+        request,
+    ):
+        report = request.getfixturevalue(report)
+
+        # unavailable accessor
+        monkeypatch.delattr(report.metrics.__class__, accessor)
+
+        metric = Metric(report=report)
+        metric.compute()
+
+        metric_payload = metric.model_dump()
+
+        assert metric_payload == {
+            "name": name,
+            "verbose_name": verbose_name,
+            "greater_is_better": greater_is_better,
+            "data_source": data_source,
+            "position": position,
+            "value": None,
+        }
+
+    def test_metric_exception(
+        self,
+        monkeypatch,
+        report,
+        Metric,
+        accessor,
+        name,
+        verbose_name,
+        greater_is_better,
+        data_source,
+        position,
+        value,
+        request,
+    ):
+        report = request.getfixturevalue(report)
+
+        # wrong type
+        with raises(
+            ValidationError,
+            match=f"Input should be an instance of {report.__class__.__name__}",
+        ):
+            Metric(report=None)