3
3
"""
4
4
5
5
from typing import Literal , cast
6
- import functools
6
+ from functools import partial
7
7
8
8
import numpy as np
9
9
from scipy .interpolate import PPoly
14
14
from ._reshape import prod , to_2d
15
15
from ._types import FloatNDArrayType , MultivariateDataType , UnivariateDataType
16
16
17
+ diags_csr = partial (sp .diags , format = 'csr' )
18
+ vpad = partial (np .pad , pad_width = [(1 , 1 ), (0 , 0 )], mode = 'constant' )
19
+
17
20
18
21
class SplinePPForm (ISplinePPForm [np .ndarray , int ], PPoly ):
19
22
"""The base class for univariate/multivariate spline in piecewise polynomial form
@@ -282,13 +285,13 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
282
285
283
286
# Create diagonal sparse matrices
284
287
diags_r = np .vstack ((dx [1 :], 2 * (dx [1 :] + dx [:- 1 ]), dx [:- 1 ]))
285
- r = sp .spdiags (diags_r , [- 1 , 0 , 1 ], pcount - 2 , pcount - 2 )
288
+ r = sp .spdiags (diags_r , [- 1 , 0 , 1 ], pcount - 2 , pcount - 2 , format = 'csr' )
286
289
287
290
dx_recip = 1.0 / dx
288
291
diags_qtw = np .vstack ((dx_recip [:- 1 ], - (dx_recip [1 :] + dx_recip [:- 1 ]), dx_recip [1 :]))
289
292
diags_sqrw_recip = 1.0 / np .sqrt (w )
290
293
291
- qtw = sp . diags (diags_qtw , [0 , 1 , 2 ], (pcount - 2 , pcount )) @ sp . diags (diags_sqrw_recip , 0 , (pcount , pcount ))
294
+ qtw = diags_csr (diags_qtw , [0 , 1 , 2 ], (pcount - 2 , pcount )) @ diags_csr (diags_sqrw_recip , 0 , (pcount , pcount ))
292
295
qtw = qtw @ qtw .T
293
296
294
297
p = smooth
@@ -312,13 +315,11 @@ def _make_spline(x, y, w, smooth, shape, normalizedsmooth):
312
315
313
316
dx = dx [:, np .newaxis ]
314
317
315
- vpad = functools .partial (np .pad , pad_width = [(1 , 1 ), (0 , 0 )], mode = 'constant' )
316
-
317
318
d1 = np .diff (vpad (u ), axis = 0 ) / dx
318
319
d2 = np .diff (vpad (d1 ), axis = 0 )
319
320
320
321
diags_w_recip = 1.0 / w
321
- w = sp . diags (diags_w_recip , 0 , (pcount , pcount ))
322
+ w = diags_csr (diags_w_recip , 0 , (pcount , pcount ))
322
323
323
324
yi = y .T - (pp * w ) @ d2
324
325
pu = vpad (p * u )
0 commit comments