Skip to content

Commit 7526b59

Browse files
vidipsinghSaransh-cpppre-commit-ci[bot]
authored
Make npt.NDArray type hints more specific with dtype (#4901)
* Make npt.NDArray type hints more specific with dtype * Refine npt.NDArray type hints per review * Update src/pybamm/solvers/idaklu_jax.py Co-authored-by: Saransh Chopra <[email protected]> * Update src/pybamm/solvers/idaklu_jax.py Co-authored-by: Saransh Chopra <[email protected]> * style: pre-commit fixes * Add from __future__ import annotations to resolve CI type hint issue * Update src/pybamm/solvers/processed_variable_time_integral.py Co-authored-by: Saransh Chopra <[email protected]> * style: pre-commit fixes * Add annotations import to idaklu_jax.py --------- Co-authored-by: Saransh Chopra <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 34186fe commit 7526b59

19 files changed

+71
-58
lines changed

benchmarks/different_model_options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def build_model(parameter, model_, option, value):
3434
class SolveModel:
3535
solver: pybamm.BaseSolver
3636
model: pybamm.BaseModel
37-
t_eval: npt.NDArray
38-
t_interp: npt.NDArray | None
37+
t_eval: npt.NDArray[np.float64]
38+
t_interp: npt.NDArray[np.float64] | None
3939

4040
def solve_setup(self, parameter, model_, option, value, solver_class):
4141
self.solver = solver_class()

benchmarks/time_solve_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ class TimeSolveSPM:
3131
)
3232
model: pybamm.BaseModel
3333
solver: pybamm.BaseSolver
34-
t_eval: npt.NDArray
35-
t_interp: npt.NDArray | None
34+
t_eval: npt.NDArray[np.float64]
35+
t_interp: npt.NDArray[np.float64] | None
3636

3737
def setup(self, solve_first, parameters, solver_class):
3838
set_random_seed()
@@ -97,7 +97,7 @@ class TimeSolveSPMe:
9797
)
9898
model: pybamm.BaseModel
9999
solver: pybamm.BaseSolver
100-
t_eval: npt.NDArray
100+
t_eval: npt.NDArray[np.float64]
101101

102102
def setup(self, solve_first, parameters, solver_class):
103103
set_random_seed()
@@ -161,7 +161,7 @@ class TimeSolveDFN:
161161
)
162162
model: pybamm.BaseModel
163163
solver: pybamm.BaseSolver
164-
t_eval: npt.NDArray
164+
t_eval: npt.NDArray[np.float64]
165165

166166
def setup(self, solve_first, parameters, solver_class):
167167
set_random_seed()

src/pybamm/expression_tree/binary_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol):
153153
def evaluate(
154154
self,
155155
t: float | None = None,
156-
y: npt.NDArray | None = None,
157-
y_dot: npt.NDArray | None = None,
156+
y: npt.NDArray[np.float64] | None = None,
157+
y_dot: npt.NDArray[np.float64] | None = None,
158158
inputs: dict | str | None = None,
159159
):
160160
"""See :meth:`pybamm.Symbol.evaluate()`."""

src/pybamm/expression_tree/concatenations.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import sympy
1313
from scipy.sparse import issparse, vstack
1414
from collections.abc import Sequence
15+
from typing import Any
1516

1617
import pybamm
1718

@@ -114,7 +115,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]):
114115

115116
return domains
116117

117-
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
118+
def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
118119
"""See :meth:`Concatenation._concatenation_evaluate()`."""
119120
if len(children_eval) == 0:
120121
return np.array([])
@@ -124,8 +125,8 @@ def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
124125
def evaluate(
125126
self,
126127
t: float | None = None,
127-
y: npt.NDArray | None = None,
128-
y_dot: npt.NDArray | None = None,
128+
y: npt.NDArray[np.float64] | None = None,
129+
y_dot: npt.NDArray[np.float64] | None = None,
129130
inputs: dict | str | None = None,
130131
):
131132
"""See :meth:`pybamm.Symbol.evaluate()`."""
@@ -369,7 +370,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict:
369370
start = end
370371
return slices
371372

372-
def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
373+
def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
373374
"""See :meth:`Concatenation._concatenation_evaluate()`."""
374375
# preallocate vector
375376
vector = np.empty((self._size, 1))

src/pybamm/expression_tree/discrete_time_sum.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pybamm
22
import numpy.typing as npt
3+
import numpy as np
4+
from typing import Any
35

46

57
class DiscreteTimeData(pybamm.Interpolant):
@@ -19,7 +21,9 @@ class DiscreteTimeData(pybamm.Interpolant):
1921
2022
"""
2123

22-
def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str):
24+
def __init__(
25+
self, time_points: npt.NDArray[np.float64], data: npt.NDArray[Any], name: str
26+
):
2327
super().__init__(time_points, data, pybamm.t, name)
2428

2529
def create_copy(self, new_children=None, perform_simplifications=True):

src/pybamm/expression_tree/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ def _function_jac(self, children_jacs):
123123
def evaluate(
124124
self,
125125
t: float | None = None,
126-
y: npt.NDArray | None = None,
127-
y_dot: npt.NDArray | None = None,
126+
y: npt.NDArray[np.float64] | None = None,
127+
y_dot: npt.NDArray[np.float64] | None = None,
128128
inputs: dict | str | None = None,
129129
):
130130
"""See :meth:`pybamm.Symbol.evaluate()`."""

src/pybamm/expression_tree/independent_variable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
from __future__ import annotations
55
import sympy
6+
import numpy as np
67
import numpy.typing as npt
78
import pybamm
89
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
@@ -93,8 +94,8 @@ def create_copy(
9394
def _base_evaluate(
9495
self,
9596
t: float | None = None,
96-
y: npt.NDArray | None = None,
97-
y_dot: npt.NDArray | None = None,
97+
y: npt.NDArray[np.float64] | None = None,
98+
y_dot: npt.NDArray[np.float64] | None = None,
9899
inputs: dict | str | None = None,
99100
):
100101
"""See :meth:`pybamm.Symbol._base_evaluate()`."""

src/pybamm/expression_tree/input_parameter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix:
8989
def _base_evaluate(
9090
self,
9191
t: float | None = None,
92-
y: npt.NDArray | None = None,
93-
y_dot: npt.NDArray | None = None,
92+
y: npt.NDArray[np.float64] | None = None,
93+
y_dot: npt.NDArray[np.float64] | None = None,
9494
inputs: dict | str | None = None,
9595
):
9696
# inputs should be a dictionary

src/pybamm/expression_tree/interpolant.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from scipy import interpolate
88
from collections.abc import Sequence
99
import numbers
10+
from typing import Any
1011

1112
import pybamm
1213

@@ -44,8 +45,8 @@ class Interpolant(pybamm.Function):
4445

4546
def __init__(
4647
self,
47-
x: npt.NDArray | Sequence[npt.NDArray],
48-
y: npt.NDArray,
48+
x: npt.NDArray[np.float64] | Sequence[npt.NDArray[np.float64]],
49+
y: npt.NDArray[Any],
4950
children: Sequence[pybamm.Symbol] | pybamm.Time,
5051
name: str | None = None,
5152
interpolator: str | None = "linear",
@@ -97,7 +98,7 @@ def __init__(
9798
x1 = x[0]
9899
else:
99100
x1 = x
100-
x: list[npt.NDArray] = [x] # type: ignore[no-redef]
101+
x: list[npt.NDArray[np.float64]] = [x] # type: ignore[no-redef]
101102
x2 = None
102103
if x1.shape[0] != y.shape[0]:
103104
raise ValueError(

src/pybamm/expression_tree/matrix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66
import numpy.typing as npt
77
from scipy.sparse import csr_matrix, issparse
8+
from typing import Any
89

910
import pybamm
1011
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
@@ -17,7 +18,7 @@ class Matrix(pybamm.Array):
1718

1819
def __init__(
1920
self,
20-
entries: npt.NDArray | list[float] | csr_matrix,
21+
entries: npt.NDArray[Any] | list[float] | csr_matrix,
2122
name: str | None = None,
2223
domain: DomainType = None,
2324
auxiliary_domains: AuxiliaryDomainType = None,

src/pybamm/expression_tree/scalar.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ def set_id(self):
6767
def _base_evaluate(
6868
self,
6969
t: float | None = None,
70-
y: npt.NDArray | None = None,
71-
y_dot: npt.NDArray | None = None,
70+
y: npt.NDArray[np.float64] | None = None,
71+
y_dot: npt.NDArray[np.float64] | None = None,
7272
inputs: dict | str | None = None,
7373
):
7474
"""See :meth:`pybamm.Symbol._base_evaluate()`."""

src/pybamm/expression_tree/state_vector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ def __init__(
282282
def _base_evaluate(
283283
self,
284284
t: float | None = None,
285-
y: npt.NDArray | None = None,
286-
y_dot: npt.NDArray | None = None,
285+
y: npt.NDArray[np.float64] | None = None,
286+
y_dot: npt.NDArray[np.float64] | None = None,
287287
inputs: dict | str | None = None,
288288
):
289289
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
@@ -366,8 +366,8 @@ def __init__(
366366
def _base_evaluate(
367367
self,
368368
t: float | None = None,
369-
y: npt.NDArray | None = None,
370-
y_dot: npt.NDArray | None = None,
369+
y: npt.NDArray[np.float64] | None = None,
370+
y_dot: npt.NDArray[np.float64] | None = None,
371371
inputs: dict | str | None = None,
372372
):
373373
"""See :meth:`pybamm.Symbol._base_evaluate()`."""

src/pybamm/expression_tree/symbol.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -770,8 +770,8 @@ def _jac(self, variable):
770770
def _base_evaluate(
771771
self,
772772
t: float | None = None,
773-
y: npt.NDArray | None = None,
774-
y_dot: npt.NDArray | None = None,
773+
y: npt.NDArray[np.float64] | None = None,
774+
y_dot: npt.NDArray[np.float64] | None = None,
775775
inputs: dict | str | None = None,
776776
):
777777
"""
@@ -802,8 +802,8 @@ def _base_evaluate(
802802
def evaluate(
803803
self,
804804
t: float | None = None,
805-
y: npt.NDArray | None = None,
806-
y_dot: npt.NDArray | None = None,
805+
y: npt.NDArray[np.float64] | None = None,
806+
y_dot: npt.NDArray[np.float64] | None = None,
807807
inputs: dict | str | None = None,
808808
) -> ChildValue:
809809
"""Evaluate expression tree (wrapper to allow using dict of known values).

src/pybamm/expression_tree/unary_operators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def _unary_evaluate(self, child):
9494
def evaluate(
9595
self,
9696
t: float | None = None,
97-
y: npt.NDArray | None = None,
98-
y_dot: npt.NDArray | None = None,
97+
y: npt.NDArray[np.float64] | None = None,
98+
y_dot: npt.NDArray[np.float64] | None = None,
9999
inputs: dict | str | None = None,
100100
):
101101
"""See :meth:`pybamm.Symbol.evaluate()`."""

src/pybamm/expression_tree/vector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy.typing as npt
77
import pybamm
88
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
9+
from typing import Any
910

1011

1112
class Vector(pybamm.Array):
@@ -15,7 +16,7 @@ class Vector(pybamm.Array):
1516

1617
def __init__(
1718
self,
18-
entries: npt.NDArray | list[float] | np.matrix,
19+
entries: npt.NDArray[Any] | list[float] | np.matrix,
1920
name: str | None = None,
2021
domain: DomainType = None,
2122
auxiliary_domains: AuxiliaryDomainType = None,

src/pybamm/models/event.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import Enum
44
import numpy.typing as npt
55
from typing import TypeVar
6+
import numpy as np
67

78

89
class EventType(Enum):
@@ -74,8 +75,8 @@ def _from_json(cls: type[E], snippet: dict) -> E:
7475
def evaluate(
7576
self,
7677
t: float | None = None,
77-
y: npt.NDArray | None = None,
78-
y_dot: npt.NDArray | None = None,
78+
y: npt.NDArray[np.float64] | None = None,
79+
y_dot: npt.NDArray[np.float64] | None = None,
7980
inputs: dict | None = None,
8081
):
8182
"""

src/pybamm/solvers/idaklu_jax.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from __future__ import annotations
12
import pybamm
23
import numpy as np
34
import numpy.typing as npt
45
import logging
56
import warnings
67
import numbers
78
import pybammsolvers.idaklu as idaklu
8-
from typing import Union
99

1010
from functools import lru_cache
1111

@@ -259,9 +259,9 @@ def f_isolated(*args, **kwargs):
259259

260260
def jax_value(
261261
self,
262-
t: npt.NDArray = None,
263-
inputs: Union[dict, None] = None,
264-
output_variables: Union[list[str], None] = None,
262+
t: npt.NDArray[np.float64] | None = None,
263+
inputs: dict | None = None,
264+
output_variables: list[str] | None = None,
265265
):
266266
"""Helper function to compute the gradient of a jaxified expression
267267
@@ -292,9 +292,9 @@ def jax_value(
292292

293293
def jax_grad(
294294
self,
295-
t: npt.NDArray = None,
296-
inputs: Union[dict, None] = None,
297-
output_variables: Union[list[str], None] = None,
295+
t: npt.NDArray[np.float64] | None = None,
296+
inputs: dict | None = None,
297+
output_variables: list[str] | None = None,
298298
):
299299
"""Helper function to compute the gradient of a jaxified expression
300300
@@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array):
396396

397397
def _jax_solve(
398398
self,
399-
t: Union[float, npt.NDArray],
399+
t: float | npt.NDArray[np.float64],
400400
*inputs,
401-
) -> npt.NDArray:
401+
) -> npt.NDArray[np.float64]:
402402
"""Solver implementation used by f-bind"""
403403
logger.info("jax_solve")
404404
logger.debug(f" t: {type(t)}, {t}")
@@ -410,7 +410,7 @@ def _jax_solve(
410410

411411
def _jax_jvp_impl(
412412
self,
413-
*args: Union[npt.NDArray],
413+
*args: npt.NDArray[np.float64],
414414
):
415415
"""JVP implementation used by f_jvp bind"""
416416
primals = args[: len(args) // 2]
@@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs(
455455

456456
def _jax_vjp_impl(
457457
self,
458-
y_bar: npt.NDArray,
459-
invar: Union[str, int], # index or name of input variable
460-
*primals: npt.NDArray,
458+
y_bar: npt.NDArray[np.float64],
459+
invar: str | int, # index or name of input variable
460+
*primals: npt.NDArray[np.float64],
461461
):
462462
"""VJP implementation used by f_vjp bind"""
463463
logger.info("py:f_vjp_p_impl")

src/pybamm/solvers/processed_variable_time_integral.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
1+
from __future__ import annotations
12
from dataclasses import dataclass
2-
from typing import Literal, Optional, Union
3+
from typing import Literal
34
import numpy.typing as npt
5+
import numpy as np
46
import pybamm
57

68

79
@dataclass
810
class ProcessedVariableTimeIntegral:
911
method: Literal["discrete", "continuous"]
10-
initial_condition: npt.NDArray
11-
discrete_times: Optional[npt.NDArray]
12+
initial_condition: npt.NDArray[np.float64] | None | float
13+
discrete_times: npt.NDArray[np.float64] | None
1214

1315
@staticmethod
1416
def from_pybamm_var(
15-
var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral],
16-
) -> "ProcessedVariableTimeIntegral":
17+
var: pybamm.DiscreteTimeSum | pybamm.ExplicitTimeIntegral,
18+
) -> ProcessedVariableTimeIntegral:
1719
if isinstance(var, pybamm.DiscreteTimeSum):
1820
return ProcessedVariableTimeIntegral(
1921
method="discrete", initial_condition=0.0, discrete_times=var.sum_times

0 commit comments

Comments
 (0)