Skip to content

Commit 704dc37

Browse files
authored
fix using 'nu' argument in n-d grid spline evaluator (#32)
fix using 'nu' argument for n-d grid spline evaluator * also tests for 'nu' and 'extrapolate' have been added
1 parent 84c448a commit 704dc37

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

csaps/_sspndg.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,19 @@ def __call__(self,
107107
interpolation axis in the original array with the shape of x.
108108
109109
"""
110+
110111
x = ndgrid_prepare_data_vectors(x, 'x', min_size=1)
111112

112113
if len(x) != self.ndim:
113114
raise ValueError(
114115
f"'x' sequence must have length {self.ndim} according to 'breaks'")
115116

117+
if nu is None:
118+
nu = (0,) * len(x)
119+
120+
if extrapolate is None:
121+
extrapolate = True
122+
116123
shape = tuple(x.size for x in x)
117124

118125
coeffs = ndg_coeffs_to_flatten(self.coeffs)
@@ -128,8 +135,9 @@ def __call__(self,
128135
coeffs = coeffs.reshape(c_shape)
129136

130137
coeffs_cnl = umv_coeffs_to_canonical(coeffs, self.pieces[i])
131-
coeffs = PPoly.construct_fast(coeffs_cnl, self.breaks[i],
132-
extrapolate=extrapolate, axis=1)(x[i])
138+
139+
spline = PPoly.construct_fast(coeffs_cnl, self.breaks[i], axis=1)
140+
coeffs = spline(x[i], nu=nu[i], extrapolate=extrapolate)
133141

134142
shape_r = (*coeffs_shape[:ndim_m1], shape[i])
135143
coeffs = coeffs.reshape(shape_r).transpose(permuted_axes)

tests/test_ndg.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
import numpy as np
6+
from scipy.interpolate import NdPPoly
67
import csaps
78

89

@@ -196,3 +197,29 @@ def test_auto_smooth_2d(ndgrid_2d_data):
196197

197198
assert s.smooth == pytest.approx(smooth_expected)
198199
assert zi == pytest.approx(zi_expected)
200+
201+
202+
@pytest.mark.parametrize('nu', [
203+
None,
204+
(0, 0),
205+
(1, 1),
206+
(2, 2),
207+
])
208+
@pytest.mark.parametrize('extrapolate', [
209+
None,
210+
True,
211+
False,
212+
])
213+
def test_evaluate_nu_extrapolate(nu: tuple, extrapolate: bool):
214+
x = ([1, 2, 3, 4], [1, 2, 3, 4])
215+
xi = ([0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5])
216+
y = np.arange(4 * 4).reshape((4, 4))
217+
218+
ss = csaps.NdGridCubicSmoothingSpline(x, y, smooth=1.0)
219+
y_ss = ss(xi, nu=nu, extrapolate=extrapolate)
220+
221+
pp = NdPPoly(ss.spline.c, x)
222+
xx = tuple(np.meshgrid(*xi, indexing='ij'))
223+
y_pp = pp(xx, nu=nu, extrapolate=extrapolate)
224+
225+
np.testing.assert_allclose(y_ss, y_pp, rtol=1e-05, atol=1e-08, equal_nan=True)

tests/test_umv.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,19 @@ def test_cubic_bc_natural():
241241

242242
assert cs.c == pytest.approx(ss.spline.c)
243243
assert y_cs == pytest.approx(y_ss)
244+
245+
246+
@pytest.mark.parametrize('nu', [0, 1, 2])
247+
@pytest.mark.parametrize('extrapolate', [None, True, False, 'periodic'])
248+
def test_evaluate_nu_extrapolate(nu, extrapolate):
249+
x = [1, 2, 3, 4]
250+
xi = [0, 1, 2, 3, 4, 5]
251+
y = [1, 2, 3, 4]
252+
253+
cs = CubicSpline(x, y)
254+
y_cs = cs(xi, nu=nu, extrapolate=extrapolate)
255+
256+
ss = csaps.CubicSmoothingSpline(x, y, smooth=1.0)
257+
y_ss = ss(xi, nu=nu, extrapolate=extrapolate)
258+
259+
np.testing.assert_allclose(y_ss, y_cs, rtol=1e-05, atol=1e-08, equal_nan=True)

0 commit comments

Comments
 (0)