Skip to content

Make npt.NDArray type hints more specific with dtype #4901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/different_model_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def build_model(parameter, model_, option, value):
class SolveModel:
solver: pybamm.BaseSolver
model: pybamm.BaseModel
t_eval: npt.NDArray
t_interp: npt.NDArray | None
t_eval: npt.NDArray[np.float64]
t_interp: npt.NDArray[np.float64] | None

def solve_setup(self, parameter, model_, option, value, solver_class):
self.solver = solver_class()
Expand Down
8 changes: 4 additions & 4 deletions benchmarks/time_solve_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class TimeSolveSPM:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: npt.NDArray
t_interp: npt.NDArray | None
t_eval: npt.NDArray[np.float64]
t_interp: npt.NDArray[np.float64] | None

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -97,7 +97,7 @@ class TimeSolveSPMe:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: npt.NDArray
t_eval: npt.NDArray[np.float64]

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down Expand Up @@ -161,7 +161,7 @@ class TimeSolveDFN:
)
model: pybamm.BaseModel
solver: pybamm.BaseSolver
t_eval: npt.NDArray
t_eval: npt.NDArray[np.float64]

def setup(self, solve_first, parameters, solver_class):
set_random_seed()
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol):
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
9 changes: 5 additions & 4 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sympy
from scipy.sparse import issparse, vstack
from collections.abc import Sequence
from typing import Any

import pybamm

Expand Down Expand Up @@ -114,7 +115,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]):

return domains

def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
if len(children_eval) == 0:
return np.array([])
Expand All @@ -124,8 +125,8 @@ def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down Expand Up @@ -369,7 +370,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict:
start = end
return slices

def _concatenation_evaluate(self, children_eval: list[npt.NDArray]):
def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]):
"""See :meth:`Concatenation._concatenation_evaluate()`."""
# preallocate vector
vector = np.empty((self._size, 1))
Expand Down
6 changes: 5 additions & 1 deletion src/pybamm/expression_tree/discrete_time_sum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pybamm
import numpy.typing as npt
import numpy as np
from typing import Any


class DiscreteTimeData(pybamm.Interpolant):
Expand All @@ -19,7 +21,9 @@ class DiscreteTimeData(pybamm.Interpolant):

"""

def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str):
def __init__(
self, time_points: npt.NDArray[np.float64], data: npt.NDArray[Any], name: str
):
super().__init__(time_points, data, pybamm.t, name)

def create_copy(self, new_children=None, perform_simplifications=True):
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ def _function_jac(self, children_jacs):
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/independent_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#
from __future__ import annotations
import sympy
import numpy as np
import numpy.typing as npt
import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
Expand Down Expand Up @@ -93,8 +94,8 @@ def create_copy(
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/input_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix:
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
# inputs should be a dictionary
Expand Down
7 changes: 4 additions & 3 deletions src/pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy import interpolate
from collections.abc import Sequence
import numbers
from typing import Any

import pybamm

Expand Down Expand Up @@ -44,8 +45,8 @@ class Interpolant(pybamm.Function):

def __init__(
self,
x: npt.NDArray | Sequence[npt.NDArray],
y: npt.NDArray,
x: npt.NDArray[np.float64] | Sequence[npt.NDArray[np.float64]],
y: npt.NDArray[Any],
children: Sequence[pybamm.Symbol] | pybamm.Time,
name: str | None = None,
interpolator: str | None = "linear",
Expand Down Expand Up @@ -97,7 +98,7 @@ def __init__(
x1 = x[0]
else:
x1 = x
x: list[npt.NDArray] = [x] # type: ignore[no-redef]
x: list[npt.NDArray[np.float64]] = [x] # type: ignore[no-redef]
x2 = None
if x1.shape[0] != y.shape[0]:
raise ValueError(
Expand Down
3 changes: 2 additions & 1 deletion src/pybamm/expression_tree/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import numpy.typing as npt
from scipy.sparse import csr_matrix, issparse
from typing import Any

import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
Expand All @@ -17,7 +18,7 @@ class Matrix(pybamm.Array):

def __init__(
self,
entries: npt.NDArray | list[float] | csr_matrix,
entries: npt.NDArray[Any] | list[float] | csr_matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def set_id(self):
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
8 changes: 4 additions & 4 deletions src/pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ def __init__(
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down Expand Up @@ -366,8 +366,8 @@ def __init__(
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol._base_evaluate()`."""
Expand Down
8 changes: 4 additions & 4 deletions src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,8 +770,8 @@ def _jac(self, variable):
def _base_evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""
Expand Down Expand Up @@ -802,8 +802,8 @@ def _base_evaluate(
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
) -> ChildValue:
"""Evaluate expression tree (wrapper to allow using dict of known values).
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def _unary_evaluate(self, child):
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | str | None = None,
):
"""See :meth:`pybamm.Symbol.evaluate()`."""
Expand Down
3 changes: 2 additions & 1 deletion src/pybamm/expression_tree/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy.typing as npt
import pybamm
from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType
from typing import Any


class Vector(pybamm.Array):
Expand All @@ -15,7 +16,7 @@ class Vector(pybamm.Array):

def __init__(
self,
entries: npt.NDArray | list[float] | np.matrix,
entries: npt.NDArray[Any] | list[float] | np.matrix,
name: str | None = None,
domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from enum import Enum
import numpy.typing as npt
from typing import TypeVar
import numpy as np


class EventType(Enum):
Expand Down Expand Up @@ -74,8 +75,8 @@ def _from_json(cls: type[E], snippet: dict) -> E:
def evaluate(
self,
t: float | None = None,
y: npt.NDArray | None = None,
y_dot: npt.NDArray | None = None,
y: npt.NDArray[np.float64] | None = None,
y_dot: npt.NDArray[np.float64] | None = None,
inputs: dict | None = None,
):
"""
Expand Down
26 changes: 13 additions & 13 deletions src/pybamm/solvers/idaklu_jax.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations
import pybamm
import numpy as np
import numpy.typing as npt
import logging
import warnings
import numbers
import pybammsolvers.idaklu as idaklu
from typing import Union

from functools import lru_cache

Expand Down Expand Up @@ -259,9 +259,9 @@ def f_isolated(*args, **kwargs):

def jax_value(
self,
t: npt.NDArray = None,
inputs: Union[dict, None] = None,
output_variables: Union[list[str], None] = None,
t: npt.NDArray[np.float64] | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression

Expand Down Expand Up @@ -292,9 +292,9 @@ def jax_value(

def jax_grad(
self,
t: npt.NDArray = None,
inputs: Union[dict, None] = None,
output_variables: Union[list[str], None] = None,
t: npt.NDArray[np.float64] | None = None,
inputs: dict | None = None,
output_variables: list[str] | None = None,
):
"""Helper function to compute the gradient of a jaxified expression

Expand Down Expand Up @@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array):

def _jax_solve(
self,
t: Union[float, npt.NDArray],
t: float | npt.NDArray[np.float64],
*inputs,
) -> npt.NDArray:
) -> npt.NDArray[np.float64]:
"""Solver implementation used by f-bind"""
logger.info("jax_solve")
logger.debug(f" t: {type(t)}, {t}")
Expand All @@ -410,7 +410,7 @@ def _jax_solve(

def _jax_jvp_impl(
self,
*args: Union[npt.NDArray],
*args: npt.NDArray[np.float64],
):
"""JVP implementation used by f_jvp bind"""
primals = args[: len(args) // 2]
Expand Down Expand Up @@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs(

def _jax_vjp_impl(
self,
y_bar: npt.NDArray,
invar: Union[str, int], # index or name of input variable
*primals: npt.NDArray,
y_bar: npt.NDArray[np.float64],
invar: str | int, # index or name of input variable
*primals: npt.NDArray[np.float64],
):
"""VJP implementation used by f_vjp bind"""
logger.info("py:f_vjp_p_impl")
Expand Down
12 changes: 7 additions & 5 deletions src/pybamm/solvers/processed_variable_time_integral.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Optional, Union
from typing import Literal
import numpy.typing as npt
import numpy as np
import pybamm


@dataclass
class ProcessedVariableTimeIntegral:
method: Literal["discrete", "continuous"]
initial_condition: npt.NDArray
discrete_times: Optional[npt.NDArray]
initial_condition: npt.NDArray[np.float64] | None | float
discrete_times: npt.NDArray[np.float64] | None

@staticmethod
def from_pybamm_var(
var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral],
) -> "ProcessedVariableTimeIntegral":
var: pybamm.DiscreteTimeSum | pybamm.ExplicitTimeIntegral,
) -> ProcessedVariableTimeIntegral:
if isinstance(var, pybamm.DiscreteTimeSum):
return ProcessedVariableTimeIntegral(
method="discrete", initial_condition=0.0, discrete_times=var.sum_times
Expand Down
Loading