Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion skore-hub-project/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"numpy",
"orjson",
"pydantic",
"rich",
"rich>=14.2.0",
"scikit-learn",
]

Expand Down
54 changes: 45 additions & 9 deletions skore-hub-project/src/skore_hub_project/artifact/artifact.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
"""Interface definition of the payload used to associate an artifact with a project."""

from __future__ import annotations

from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import AbstractContextManager, nullcontext
from functools import cached_property
from typing import ClassVar

from pydantic import BaseModel, ConfigDict, Field, computed_field

from skore_hub_project import Project
from skore_hub_project.artifact.upload import upload
from skore_hub_project.artifact.serializer import Serializer, TxtSerializer
from skore_hub_project.artifact.upload import upload as upload_content
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

Content = str | bytes | None
Content = EstimatorReport | CrossValidationReport | str | bytes | None


class Artifact(BaseModel, ABC):
Expand All @@ -29,9 +34,10 @@ class Artifact(BaseModel, ABC):
as a file to the ``hub`` artifacts storage.
"""

model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True)

project: Project = Field(repr=False, exclude=True)
serializer_cls: ClassVar[type[Serializer]] = TxtSerializer
content_type: str = Field(init=False)

@abstractmethod
Expand All @@ -56,20 +62,50 @@ def content_to_upload(self) -> Generator[str, None, None]:
"""

@computed_field # type: ignore[prop-decorator]
@cached_property
@property
def checksum(self) -> str | None:
"""Checksum used to identify the content of the artifact."""
try:
return self.__checksum
except AttributeError:
message = (
"You cannot access the checksum of an artifact "
"without explicitly uploading it. "
"Please use `artifact.upload()` before."
)

raise RuntimeError(message) from None

@checksum.setter
def checksum(self, checksum: str | None) -> None:
self.__checksum = checksum

def upload(
self,
*,
pool: ThreadPoolExecutor | None = None,
checksums_being_uploaded: set[str] | None = None,
) -> None:
"""Upload the artifact and set its checksum."""
contextmanager = self.content_to_upload()
pool = ThreadPoolExecutor(max_workers=6) if pool is None else pool
checksums_being_uploaded = (
set() if checksums_being_uploaded is None else checksums_being_uploaded
)

if not isinstance(contextmanager, AbstractContextManager):
contextmanager = nullcontext(contextmanager)

with contextmanager as content:
if content is not None:
return upload(
self.checksum = (
None
if content is None
else upload_content(
project=self.project,
serializer_cls=self.serializer_cls,
content=content,
content_type=self.content_type,
pool=pool,
checksums_being_uploaded=checksums_being_uploaded,
)

return None
)
17 changes: 6 additions & 11 deletions skore-hub-project/src/skore_hub_project/artifact/pickle/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from collections.abc import Generator
from contextlib import contextmanager
from io import BytesIO
from typing import Literal
from typing import ClassVar, Literal

from joblib import dump
from pydantic import Field

from skore_hub_project.artifact.artifact import Artifact
from skore_hub_project.artifact.serializer import ReportSerializer, Serializer
from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

Report = EstimatorReport | CrossValidationReport
Expand All @@ -35,11 +34,12 @@ class Pickle(Artifact):
The report is primarily pickled on disk to reduce RAM footprint.
"""

report: Report = Field(repr=False, exclude=True)
serializer_cls: ClassVar[type[Serializer]] = ReportSerializer
content_type: Literal["application/octet-stream"] = "application/octet-stream"
report: Report = Field(repr=False, exclude=True)

@contextmanager
def content_to_upload(self) -> Generator[bytes, None, None]:
def content_to_upload(self) -> Generator[Report, None, None]:
"""
Content of the pickled report.

Expand All @@ -53,12 +53,7 @@ def content_to_upload(self) -> Generator[bytes, None, None]:
self.report.clear_cache()

try:
with BytesIO() as stream:
dump(self.report, stream)

pickle_bytes = stream.getvalue()

yield pickle_bytes
yield self.report
finally:
for report, cache in zip(reports, caches, strict=True):
report._cache = cache
72 changes: 65 additions & 7 deletions skore-hub-project/src/skore_hub_project/artifact/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,37 @@

from __future__ import annotations

from abc import ABC, abstractmethod
from functools import cached_property
from io import BytesIO
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any

from blake3 import blake3 as Blake3
from joblib import dump

from skore_hub_project.protocol import CrossValidationReport, EstimatorReport

class Serializer:
"""Serialize a content directly on disk to reduce RAM footprint."""

def __init__(self, content: str | bytes):
if isinstance(content, str):
self.filepath.write_text(content, encoding="utf-8")
else:
self.filepath.write_bytes(content)
class Serializer(ABC):
"""Abstract class to serialize anything on disk."""

called: bool = False

@abstractmethod
def __init__(self, _: Any, /): ...

def __enter__(self) -> Serializer: # noqa: D105
return self

def __exit__(self, *args: Any) -> None: # noqa: D105
self.filepath.unlink(True)

@abstractmethod
def __call__(self) -> None:
"""Serialize anything on disk."""

@cached_property
def filepath(self) -> Path:
"""The filepath used to serialize the content."""
Expand All @@ -47,6 +55,13 @@ def checksum(self) -> str:
"""
from skore_hub_project import bytes_to_b64_str

if not self.called:
raise RuntimeError(
"You cannot access the checksum of a serializer without explicitly "
"called it. Please use `serializer()` before."
)

# Compute checksum with the appropriate number of threads
hasher = Blake3(max_threads=(1 if self.size < 1e6 else Blake3.AUTO))
checksum = hasher.update_mmap(self.filepath).digest()

Expand All @@ -55,4 +70,47 @@ def checksum(self) -> str:
@cached_property
def size(self) -> int:
"""The size of the serialized content, in bytes."""
if not self.called:
raise RuntimeError(
"You cannot access the size of a serializer without explicitly "
"called it. Please use `serializer()` before."
)

return self.filepath.stat().st_size


class TxtSerializer(Serializer):
    """Serializer persisting textual (``str``) or raw (``bytes``) content on disk.

    The content is written to ``filepath`` as soon as the serializer is
    instantiated, so calling the instance afterwards has nothing left to do.
    """

    def __init__(self, txt: str | bytes, /):
        # Text is stored UTF-8 encoded; bytes are written untouched.
        payload = txt.encode(encoding="utf-8") if isinstance(txt, str) else txt

        self.filepath.write_bytes(payload)
        self.called = True

    def __call__(self) -> None:
        """Do nothing: the content has already been written by ``__init__``."""


class ReportSerializer(Serializer):
    """Serialize a report using joblib on disk.

    Unlike the checksum, the serialization itself is deferred: nothing is
    written to ``filepath`` until the instance is explicitly called.
    """

    def __init__(self, report: CrossValidationReport | EstimatorReport, /):
        # Only keep a reference to the report; the dump happens in `__call__`.
        self.report = report

    def __call__(self) -> None:
        """
        Serialize a report using joblib on disk.

        The report is dumped at most once: subsequent calls return immediately.
        """
        if self.called:
            return

        # Stream the pickle straight into the file instead of building it in an
        # in-memory buffer first, so that the whole serialized report never has
        # to be held in RAM.
        with self.filepath.open("wb") as stream:
            dump(self.report, stream)

        self.called = True

    @cached_property
    def checksum(self) -> str:
        """
        The checksum of the serialized report.

        It is built from the report's class name and its ``_hash`` attribute, not
        from the serialized bytes, hence it is available without calling the
        serializer first.
        """
        return f"skore-{self.report.__class__.__name__}-{self.report._hash}"
Loading
Loading