diff --git a/benchmarks/different_model_options.py b/benchmarks/different_model_options.py index 6fdd6c7264..77212b9c1c 100644 --- a/benchmarks/different_model_options.py +++ b/benchmarks/different_model_options.py @@ -34,8 +34,8 @@ def build_model(parameter, model_, option, value): class SolveModel: solver: pybamm.BaseSolver model: pybamm.BaseModel - t_eval: npt.NDArray - t_interp: npt.NDArray | None + t_eval: npt.NDArray[np.float64] + t_interp: npt.NDArray[np.float64] | None def solve_setup(self, parameter, model_, option, value, solver_class): self.solver = solver_class() diff --git a/benchmarks/time_solve_models.py b/benchmarks/time_solve_models.py index 3aafca797f..42435acf18 100644 --- a/benchmarks/time_solve_models.py +++ b/benchmarks/time_solve_models.py @@ -31,8 +31,8 @@ class TimeSolveSPM: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray - t_interp: npt.NDArray | None + t_eval: npt.NDArray[np.float64] + t_interp: npt.NDArray[np.float64] | None def setup(self, solve_first, parameters, solver_class): set_random_seed() @@ -97,7 +97,7 @@ class TimeSolveSPMe: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray + t_eval: npt.NDArray[np.float64] def setup(self, solve_first, parameters, solver_class): set_random_seed() @@ -161,7 +161,7 @@ class TimeSolveDFN: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray + t_eval: npt.NDArray[np.float64] def setup(self, solve_first, parameters, solver_class): set_random_seed() diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index efd9874664..c7ac32041d 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -153,8 +153,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index 6f190aaf59..5ad32493ca 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -12,6 +12,7 @@ import sympy from scipy.sparse import issparse, vstack from collections.abc import Sequence +from typing import Any import pybamm @@ -114,7 +115,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]): return domains - def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): + def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]): """See :meth:`Concatenation._concatenation_evaluate()`.""" if len(children_eval) == 0: return np.array([]) @@ -124,8 +125,8 @@ def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" @@ -369,7 +370,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict: start = end return slices - def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): + def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]): """See :meth:`Concatenation._concatenation_evaluate()`.""" # preallocate vector vector = np.empty((self._size, 1)) diff --git a/src/pybamm/expression_tree/discrete_time_sum.py b/src/pybamm/expression_tree/discrete_time_sum.py index 977bdd2fad..715f527602 100644 --- a/src/pybamm/expression_tree/discrete_time_sum.py +++ b/src/pybamm/expression_tree/discrete_time_sum.py @@ -1,5 +1,7 @@ import pybamm import numpy.typing as npt +import numpy as np +from typing import Any class DiscreteTimeData(pybamm.Interpolant): @@ -19,7 +21,9 @@ class DiscreteTimeData(pybamm.Interpolant): """ - def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str): + def __init__( + self, time_points: npt.NDArray[np.float64], data: npt.NDArray[Any], name: str + ): super().__init__(time_points, data, pybamm.t, name) def create_copy(self, new_children=None, perform_simplifications=True): diff --git a/src/pybamm/expression_tree/functions.py b/src/pybamm/expression_tree/functions.py index a5a999a092..8e7fcf94e5 100644 --- a/src/pybamm/expression_tree/functions.py +++ b/src/pybamm/expression_tree/functions.py @@ -123,8 +123,8 @@ def _function_jac(self, children_jacs): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/independent_variable.py b/src/pybamm/expression_tree/independent_variable.py index 85554435a5..672b89d322 100644 --- a/src/pybamm/expression_tree/independent_variable.py +++ b/src/pybamm/expression_tree/independent_variable.py @@ -3,6 +3,7 @@ # from __future__ import annotations import sympy +import numpy as np import numpy.typing as npt import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType @@ -93,8 +94,8 @@ def create_copy( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/input_parameter.py b/src/pybamm/expression_tree/input_parameter.py index 57963e9d62..e82621f63d 100644 --- a/src/pybamm/expression_tree/input_parameter.py +++ b/src/pybamm/expression_tree/input_parameter.py @@ -89,8 +89,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix: def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): # inputs should be a dictionary diff --git a/src/pybamm/expression_tree/interpolant.py b/src/pybamm/expression_tree/interpolant.py index d8e1afa2da..31324a02bf 100644 --- a/src/pybamm/expression_tree/interpolant.py +++ b/src/pybamm/expression_tree/interpolant.py @@ -7,6 +7,7 @@ from scipy import interpolate from collections.abc import Sequence import numbers +from typing import Any import pybamm @@ -44,8 +45,8 @@ class Interpolant(pybamm.Function): def __init__( self, - x: npt.NDArray | Sequence[npt.NDArray], - y: npt.NDArray, + x: npt.NDArray[np.float64] | Sequence[npt.NDArray[np.float64]], + y: npt.NDArray[Any], children: Sequence[pybamm.Symbol] | pybamm.Time, name: str | None = None, interpolator: str | None = "linear", @@ -97,7 +98,7 @@ def __init__( x1 = x[0] else: x1 = x - x: list[npt.NDArray] = [x] # type: ignore[no-redef] + x: list[npt.NDArray[np.float64]] = [x] # type: ignore[no-redef] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( diff --git a/src/pybamm/expression_tree/matrix.py b/src/pybamm/expression_tree/matrix.py index af44c39a16..9af22a4895 100644 --- a/src/pybamm/expression_tree/matrix.py +++ b/src/pybamm/expression_tree/matrix.py @@ -5,6 +5,7 @@ import numpy as np import numpy.typing as npt from scipy.sparse import csr_matrix, issparse +from typing import Any import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType @@ -17,7 +18,7 @@ class Matrix(pybamm.Array): def __init__( self, - entries: npt.NDArray | list[float] | csr_matrix, + entries: npt.NDArray[Any] | list[float] | csr_matrix, name: str | None = None, domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, diff --git a/src/pybamm/expression_tree/scalar.py b/src/pybamm/expression_tree/scalar.py index c06a099da0..c2680c6711 100644 --- a/src/pybamm/expression_tree/scalar.py +++ b/src/pybamm/expression_tree/scalar.py @@ -67,8 +67,8 @@ def set_id(self): def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/state_vector.py b/src/pybamm/expression_tree/state_vector.py index c4f5649efc..a66d2665f9 100644 --- a/src/pybamm/expression_tree/state_vector.py +++ b/src/pybamm/expression_tree/state_vector.py @@ -282,8 +282,8 @@ def __init__( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" @@ -366,8 +366,8 @@ def __init__( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index 34ca9d627b..c970c04e1b 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -770,8 +770,8 @@ def _jac(self, variable): def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """ @@ -802,8 +802,8 @@ def _base_evaluate( def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ) -> ChildValue: """Evaluate expression tree (wrapper to allow using dict of known values). diff --git a/src/pybamm/expression_tree/unary_operators.py b/src/pybamm/expression_tree/unary_operators.py index 2f998c47d6..bc7a9168de 100644 --- a/src/pybamm/expression_tree/unary_operators.py +++ b/src/pybamm/expression_tree/unary_operators.py @@ -94,8 +94,8 @@ def _unary_evaluate(self, child): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/vector.py b/src/pybamm/expression_tree/vector.py index d9dcb3ef89..74fe8997b6 100644 --- a/src/pybamm/expression_tree/vector.py +++ b/src/pybamm/expression_tree/vector.py @@ -6,6 +6,7 @@ import numpy.typing as npt import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType +from typing import Any class Vector(pybamm.Array): @@ -15,7 +16,7 @@ class Vector(pybamm.Array): def __init__( self, - entries: npt.NDArray | list[float] | np.matrix, + entries: npt.NDArray[Any] | list[float] | np.matrix, name: str | None = None, domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, diff --git a/src/pybamm/models/event.py b/src/pybamm/models/event.py index a52148e8c0..61c4c47e91 100644 --- a/src/pybamm/models/event.py +++ b/src/pybamm/models/event.py @@ -3,6 +3,7 @@ from enum import Enum import numpy.typing as npt from typing import TypeVar +import numpy as np class EventType(Enum): @@ -74,8 +75,8 @@ def _from_json(cls: type[E], snippet: dict) -> E: def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | None = None, ): """ diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ef505570fa..701a4d1dfa 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -1,3 +1,4 @@ +from __future__ import annotations import pybamm import numpy as np import numpy.typing as npt @@ -5,7 +6,6 @@ import warnings import numbers import pybammsolvers.idaklu as idaklu -from typing import Union from functools import lru_cache @@ -259,9 +259,9 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: npt.NDArray = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + t: npt.NDArray[np.float64] | None = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression @@ -292,9 +292,9 @@ def jax_value( def jax_grad( self, - t: npt.NDArray = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + t: npt.NDArray[np.float64] | None = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression @@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array): def _jax_solve( self, - t: Union[float, npt.NDArray], + t: float | npt.NDArray[np.float64], *inputs, - ) -> npt.NDArray: + ) -> npt.NDArray[np.float64]: """Solver implementation used by f-bind""" logger.info("jax_solve") logger.debug(f" t: {type(t)}, {t}") @@ -410,7 +410,7 @@ def _jax_solve( def _jax_jvp_impl( self, - *args: Union[npt.NDArray], + *args: npt.NDArray[np.float64], ): """JVP implementation used by f_jvp bind""" primals = args[: len(args) // 2] @@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs( def _jax_vjp_impl( self, - y_bar: npt.NDArray, - invar: Union[str, int], # index or name of input variable - *primals: npt.NDArray, + y_bar: npt.NDArray[np.float64], + invar: str | int, # index or name of input variable + *primals: npt.NDArray[np.float64], ): """VJP implementation used by f_vjp bind""" logger.info("py:f_vjp_p_impl") diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index ce41c1796e..a7b3bd7560 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -1,19 +1,21 @@ +from __future__ import annotations from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal import numpy.typing as npt +import numpy as np import pybamm @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: npt.NDArray - discrete_times: Optional[npt.NDArray] + initial_condition: npt.NDArray[np.float64] | None | float + discrete_times: npt.NDArray[np.float64] | None @staticmethod def from_pybamm_var( - var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral], - ) -> "ProcessedVariableTimeIntegral": + var: pybamm.DiscreteTimeSum | pybamm.ExplicitTimeIntegral, + ) -> ProcessedVariableTimeIntegral: if isinstance(var, pybamm.DiscreteTimeSum): return ProcessedVariableTimeIntegral( method="discrete", initial_condition=0.0, discrete_times=var.sum_times diff --git a/src/pybamm/type_definitions.py b/src/pybamm/type_definitions.py index cd48fa4c81..bb7924e3d2 100644 --- a/src/pybamm/type_definitions.py +++ b/src/pybamm/type_definitions.py @@ -5,13 +5,14 @@ import numpy as np import numpy.typing as npt import pybamm +from typing import Any # numbers.Number should not be used for type hints Numeric: TypeAlias = Union[int, float, np.number] # expression tree -ChildValue: TypeAlias = Union[float, npt.NDArray] -ChildSymbol: TypeAlias = Union[float, npt.NDArray, pybamm.Symbol] +ChildValue: TypeAlias = Union[float, npt.NDArray[Any]] +ChildSymbol: TypeAlias = Union[float, npt.NDArray[Any], pybamm.Symbol] DomainType: TypeAlias = Union[list[str], str, None] AuxiliaryDomainType: TypeAlias = Union[dict[str, str], None]