Skip to content

Fix Periodogram pickling #532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

--
- A problem with pickling of `Periodogram` which caused wrong results from `.power` and `.freq_power` for a deserialized
object https://github.com/light-curve/light-curve-python/pull/532

### Security

Expand Down
87 changes: 64 additions & 23 deletions light-curve/src/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Deref;
// Details of pickle support implementation
// ----------------------------------------
// [PyFeatureEvaluator] implements __getstate__ and __setstate__ required for pickle serialisation,
Expand Down Expand Up @@ -588,28 +589,6 @@
self.feature_evaluator_f64.get_descriptions()
}

/// Used by pickle.load / pickle.loads
fn __setstate__(&mut self, state: Bound<PyBytes>) -> Res<()> {
*self = serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
.map_err(|err| {
Exception::UnpicklingError(format!(
r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
))
})?;
Ok(())
}

/// Used by pickle.dump / pickle.dumps
fn __getstate__<'py>(&self, py: Python<'py>) -> Res<Bound<'py, PyBytes>> {
let vec_bytes =
serde_pickle::to_vec(&self, serde_pickle::SerOptions::new()).map_err(|err| {
Exception::PicklingError(format!(
r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
))
})?;
Ok(PyBytes::new(py, &vec_bytes))
}

/// Used by copy.copy
fn __copy__(&self) -> Self {
self.clone()
Expand All @@ -621,9 +600,43 @@
}
}

macro_rules! impl_pickle_serialisation {
($name: ident) => {
#[pymethods]
impl $name {
/// Used by pickle.load / pickle.loads
fn __setstate__(mut slf: PyRefMut<'_, Self>, state: Bound<PyBytes>) -> Res<()> {
let (super_rust, self_rust): (PyFeatureEvaluator, Self) = serde_pickle::from_slice(state.as_bytes(), serde_pickle::DeOptions::new())
.map_err(|err| {
Exception::UnpicklingError(format!(
r#"Error happened on the Rust side when deserializing _FeatureEvaluator: "{err}""#
))

Check warning on line 613 in light-curve/src/features.rs

View check run for this annotation

Codecov / codecov/patch

light-curve/src/features.rs#L611-L613

Added lines #L611 - L613 were not covered by tests
})?;
*slf.as_mut() = super_rust;
*slf = self_rust;
Ok(())
}

/// Used by pickle.dump / pickle.dumps
fn __getstate__<'py>(slf: PyRef<'py, Self>) -> Res<Bound<'py, PyBytes>> {
let supr = slf.as_super();
let vec_bytes = serde_pickle::to_vec(&(supr.deref(), slf.deref()), serde_pickle::SerOptions::new()).map_err(|err| {
Exception::PicklingError(format!(
r#"Error happened on the Rust side when serializing _FeatureEvaluator: "{err}""#
))

Check warning on line 626 in light-curve/src/features.rs

View check run for this annotation

Codecov / codecov/patch

light-curve/src/features.rs#L624-L626

Added lines #L624 - L626 were not covered by tests
})?;
Ok(PyBytes::new(slf.py(), &vec_bytes))
}
}
}
}

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct Extractor {}

impl_pickle_serialisation!(Extractor);

#[pymethods]
impl Extractor {
#[new]
Expand Down Expand Up @@ -702,11 +715,14 @@

macro_rules! evaluator {
($name: ident, $eval: ty, $default_transform: expr $(,)?) => {
#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct $name {}

impl_stock_transform!($name, $default_transform);

impl_pickle_serialisation!($name);

#[pymethods]
impl $name {
#[new]
Expand Down Expand Up @@ -806,9 +822,12 @@

macro_rules! fit_evaluator {
($name: ident, $eval: ty, $ib: ty, $transform: expr, $nparam: literal, $ln_prior_by_str: tt, $ln_prior_doc: literal $(,)?) => {
#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct $name {}

impl_pickle_serialisation!($name);

impl $name {
fn supported_algorithms_str() -> String {
return SUPPORTED_ALGORITHMS_CURVE_FIT.join(", ");
Expand Down Expand Up @@ -1051,7 +1070,7 @@
Number of Ceres iterations, default is {niter}
ceres_loss_reg : float, optional
Ceres loss regularization, default is to use square norm as is, if set to
a number, the loss function is reqgualized to descriminate outlier
a number, the loss function is regularized to descriminate outlier
residuals larger than this value.
Default is None which means no regularization.
"#,
Expand Down Expand Up @@ -1158,10 +1177,12 @@
StockTransformer::Lg
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct BeyondNStd {}

impl_stock_transform!(BeyondNStd, StockTransformer::Identity);
impl_pickle_serialisation!(BeyondNStd);

#[pymethods]
impl BeyondNStd {
Expand Down Expand Up @@ -1219,9 +1240,12 @@
"'no': no prior",
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct Bins {}

impl_pickle_serialisation!(Bins);

#[pymethods]
impl Bins {
#[new]
Expand Down Expand Up @@ -1318,10 +1342,12 @@
StockTransformer::Identity
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct InterPercentileRange {}

impl_stock_transform!(InterPercentileRange, StockTransformer::Identity);
impl_pickle_serialisation!(InterPercentileRange);

#[pymethods]
impl InterPercentileRange {
Expand Down Expand Up @@ -1385,10 +1411,12 @@
"'no': no prior",
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct MagnitudePercentageRatio {}

impl_stock_transform!(MagnitudePercentageRatio, StockTransformer::Identity);
impl_pickle_serialisation!(MagnitudePercentageRatio);

#[pymethods]
impl MagnitudePercentageRatio {
Expand Down Expand Up @@ -1474,10 +1502,12 @@
StockTransformer::Identity
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct MedianBufferRangePercentage {}

impl_stock_transform!(MedianBufferRangePercentage, StockTransformer::Identity);
impl_pickle_serialisation!(MedianBufferRangePercentage);

#[pymethods]
impl MedianBufferRangePercentage {
Expand Down Expand Up @@ -1526,13 +1556,15 @@
StockTransformer::Identity
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct PercentDifferenceMagnitudePercentile {}

impl_stock_transform!(
PercentDifferenceMagnitudePercentile,
StockTransformer::ClippedLg
);
impl_pickle_serialisation!(PercentDifferenceMagnitudePercentile);

#[pymethods]
impl PercentDifferenceMagnitudePercentile {
Expand Down Expand Up @@ -1588,12 +1620,15 @@
Float(f32),
}

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct Periodogram {
eval_f32: LcfPeriodogram<f32>,
eval_f64: LcfPeriodogram<f64>,
}

impl_pickle_serialisation!(Periodogram);

impl Periodogram {
fn create_evals(
peaks: Option<usize>,
Expand Down Expand Up @@ -2005,9 +2040,12 @@
StockTransformer::Identity
);

#[derive(Serialize, Deserialize)]
#[pyclass(extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct OtsuSplit {}

impl_pickle_serialisation!(OtsuSplit);

#[pymethods]
impl OtsuSplit {
#[new]
Expand Down Expand Up @@ -2066,9 +2104,12 @@
);

/// Feature evaluator deserialized from JSON string
#[derive(Serialize, Deserialize)]
#[pyclass(name = "JSONDeserializedFeature", extends = PyFeatureEvaluator, module="light_curve.light_curve_ext")]
pub struct JsonDeserializedFeature {}

impl_pickle_serialisation!(JsonDeserializedFeature);

#[pymethods]
impl JsonDeserializedFeature {
#[new]
Expand Down
Loading
Loading