From cdb7be86251adfc944f65d0cca98db29f971e700 Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Sat, 8 Mar 2025 23:52:59 +0530 Subject: [PATCH 1/9] Make npt.NDArray type hints more specific with dtype --- benchmarks/different_model_options.py | 4 ++-- benchmarks/time_solve_models.py | 8 ++++---- src/pybamm/expression_tree/binary_operators.py | 4 ++-- src/pybamm/expression_tree/concatenations.py | 9 +++++---- src/pybamm/expression_tree/discrete_time_sum.py | 6 +++++- src/pybamm/expression_tree/functions.py | 4 ++-- src/pybamm/expression_tree/independent_variable.py | 5 +++-- src/pybamm/expression_tree/input_parameter.py | 4 ++-- src/pybamm/expression_tree/interpolant.py | 7 ++++--- src/pybamm/expression_tree/matrix.py | 3 ++- src/pybamm/expression_tree/scalar.py | 4 ++-- src/pybamm/expression_tree/state_vector.py | 8 ++++---- src/pybamm/expression_tree/symbol.py | 8 ++++---- src/pybamm/expression_tree/unary_operators.py | 4 ++-- src/pybamm/expression_tree/vector.py | 3 ++- src/pybamm/models/event.py | 5 +++-- src/pybamm/solvers/idaklu_jax.py | 14 +++++++------- .../solvers/processed_variable_time_integral.py | 5 +++-- src/pybamm/type_definitions.py | 5 +++-- 19 files changed, 61 insertions(+), 49 deletions(-) diff --git a/benchmarks/different_model_options.py b/benchmarks/different_model_options.py index 6fdd6c7264..77212b9c1c 100644 --- a/benchmarks/different_model_options.py +++ b/benchmarks/different_model_options.py @@ -34,8 +34,8 @@ def build_model(parameter, model_, option, value): class SolveModel: solver: pybamm.BaseSolver model: pybamm.BaseModel - t_eval: npt.NDArray - t_interp: npt.NDArray | None + t_eval: npt.NDArray[np.float64] + t_interp: npt.NDArray[np.float64] | None def solve_setup(self, parameter, model_, option, value, solver_class): self.solver = solver_class() diff --git a/benchmarks/time_solve_models.py b/benchmarks/time_solve_models.py index 3aafca797f..42435acf18 100644 --- a/benchmarks/time_solve_models.py +++ b/benchmarks/time_solve_models.py @@ -31,8 +31,8 @@ class TimeSolveSPM: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray - t_interp: npt.NDArray | None + t_eval: npt.NDArray[np.float64] + t_interp: npt.NDArray[np.float64] | None def setup(self, solve_first, parameters, solver_class): set_random_seed() @@ -97,7 +97,7 @@ class TimeSolveSPMe: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray + t_eval: npt.NDArray[np.float64] def setup(self, solve_first, parameters, solver_class): set_random_seed() @@ -161,7 +161,7 @@ class TimeSolveDFN: ) model: pybamm.BaseModel solver: pybamm.BaseSolver - t_eval: npt.NDArray + t_eval: npt.NDArray[np.float64] def setup(self, solve_first, parameters, solver_class): set_random_seed() diff --git a/src/pybamm/expression_tree/binary_operators.py b/src/pybamm/expression_tree/binary_operators.py index efd9874664..c7ac32041d 100644 --- a/src/pybamm/expression_tree/binary_operators.py +++ b/src/pybamm/expression_tree/binary_operators.py @@ -153,8 +153,8 @@ def _binary_new_copy(self, left: ChildSymbol, right: ChildSymbol): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/concatenations.py b/src/pybamm/expression_tree/concatenations.py index dc30cf4b5e..6de5ce7826 100644 --- a/src/pybamm/expression_tree/concatenations.py +++ b/src/pybamm/expression_tree/concatenations.py @@ -11,6 +11,7 @@ import sympy from scipy.sparse import issparse, vstack from collections.abc import Sequence +from typing import Any import pybamm @@ -113,7 +114,7 @@ def get_children_domains(self, children: Sequence[pybamm.Symbol]): return domains - def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): + def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]): """See :meth:`Concatenation._concatenation_evaluate()`.""" if len(children_eval) == 0: return np.array([]) @@ -123,8 +124,8 @@ def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" @@ -368,7 +369,7 @@ def create_slices(self, node: pybamm.Symbol) -> defaultdict: start = end return slices - def _concatenation_evaluate(self, children_eval: list[npt.NDArray]): + def _concatenation_evaluate(self, children_eval: list[npt.NDArray[Any]]): """See :meth:`Concatenation._concatenation_evaluate()`.""" # preallocate vector vector = np.empty((self._size, 1)) diff --git a/src/pybamm/expression_tree/discrete_time_sum.py b/src/pybamm/expression_tree/discrete_time_sum.py index 977bdd2fad..715f527602 100644 --- a/src/pybamm/expression_tree/discrete_time_sum.py +++ b/src/pybamm/expression_tree/discrete_time_sum.py @@ -1,5 +1,7 @@ import pybamm import numpy.typing as npt +import numpy as np +from typing import Any class DiscreteTimeData(pybamm.Interpolant): @@ -19,7 +21,9 @@ class DiscreteTimeData(pybamm.Interpolant): """ - def __init__(self, time_points: npt.NDArray, data: npt.NDArray, name: str): + def __init__( + self, time_points: npt.NDArray[np.float64], data: npt.NDArray[Any], name: str + ): super().__init__(time_points, data, pybamm.t, name) def create_copy(self, new_children=None, perform_simplifications=True): diff --git a/src/pybamm/expression_tree/functions.py b/src/pybamm/expression_tree/functions.py index a5a999a092..8e7fcf94e5 100644 --- a/src/pybamm/expression_tree/functions.py +++ b/src/pybamm/expression_tree/functions.py @@ -123,8 +123,8 @@ def _function_jac(self, children_jacs): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/independent_variable.py b/src/pybamm/expression_tree/independent_variable.py index 85554435a5..672b89d322 100644 --- a/src/pybamm/expression_tree/independent_variable.py +++ b/src/pybamm/expression_tree/independent_variable.py @@ -3,6 +3,7 @@ # from __future__ import annotations import sympy +import numpy as np import numpy.typing as npt import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType @@ -93,8 +94,8 @@ def create_copy( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/input_parameter.py b/src/pybamm/expression_tree/input_parameter.py index 57963e9d62..e82621f63d 100644 --- a/src/pybamm/expression_tree/input_parameter.py +++ b/src/pybamm/expression_tree/input_parameter.py @@ -89,8 +89,8 @@ def _jac(self, variable: pybamm.StateVector) -> pybamm.Matrix: def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): # inputs should be a dictionary diff --git a/src/pybamm/expression_tree/interpolant.py b/src/pybamm/expression_tree/interpolant.py index d8e1afa2da..31324a02bf 100644 --- a/src/pybamm/expression_tree/interpolant.py +++ b/src/pybamm/expression_tree/interpolant.py @@ -7,6 +7,7 @@ from scipy import interpolate from collections.abc import Sequence import numbers +from typing import Any import pybamm @@ -44,8 +45,8 @@ class Interpolant(pybamm.Function): def __init__( self, - x: npt.NDArray | Sequence[npt.NDArray], - y: npt.NDArray, + x: npt.NDArray[np.float64] | Sequence[npt.NDArray[np.float64]], + y: npt.NDArray[Any], children: Sequence[pybamm.Symbol] | pybamm.Time, name: str | None = None, interpolator: str | None = "linear", @@ -97,7 +98,7 @@ def __init__( x1 = x[0] else: x1 = x - x: list[npt.NDArray] = [x] # type: ignore[no-redef] + x: list[npt.NDArray[np.float64]] = [x] # type: ignore[no-redef] x2 = None if x1.shape[0] != y.shape[0]: raise ValueError( diff --git a/src/pybamm/expression_tree/matrix.py b/src/pybamm/expression_tree/matrix.py index af44c39a16..9af22a4895 100644 --- a/src/pybamm/expression_tree/matrix.py +++ b/src/pybamm/expression_tree/matrix.py @@ -5,6 +5,7 @@ import numpy as np import numpy.typing as npt from scipy.sparse import csr_matrix, issparse +from typing import Any import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType @@ -17,7 +18,7 @@ class Matrix(pybamm.Array): def __init__( self, - entries: npt.NDArray | list[float] | csr_matrix, + entries: npt.NDArray[Any] | list[float] | csr_matrix, name: str | None = None, domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, diff --git a/src/pybamm/expression_tree/scalar.py b/src/pybamm/expression_tree/scalar.py index c06a099da0..c2680c6711 100644 --- a/src/pybamm/expression_tree/scalar.py +++ b/src/pybamm/expression_tree/scalar.py @@ -67,8 +67,8 @@ def set_id(self): def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/state_vector.py b/src/pybamm/expression_tree/state_vector.py index c4f5649efc..a66d2665f9 100644 --- a/src/pybamm/expression_tree/state_vector.py +++ b/src/pybamm/expression_tree/state_vector.py @@ -282,8 +282,8 @@ def __init__( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" @@ -366,8 +366,8 @@ def __init__( def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol._base_evaluate()`.""" diff --git a/src/pybamm/expression_tree/symbol.py b/src/pybamm/expression_tree/symbol.py index 34ca9d627b..c970c04e1b 100644 --- a/src/pybamm/expression_tree/symbol.py +++ b/src/pybamm/expression_tree/symbol.py @@ -770,8 +770,8 @@ def _jac(self, variable): def _base_evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """ @@ -802,8 +802,8 @@ def _base_evaluate( def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ) -> ChildValue: """Evaluate expression tree (wrapper to allow using dict of known values). diff --git a/src/pybamm/expression_tree/unary_operators.py b/src/pybamm/expression_tree/unary_operators.py index 2f998c47d6..bc7a9168de 100644 --- a/src/pybamm/expression_tree/unary_operators.py +++ b/src/pybamm/expression_tree/unary_operators.py @@ -94,8 +94,8 @@ def _unary_evaluate(self, child): def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | str | None = None, ): """See :meth:`pybamm.Symbol.evaluate()`.""" diff --git a/src/pybamm/expression_tree/vector.py b/src/pybamm/expression_tree/vector.py index d9dcb3ef89..74fe8997b6 100644 --- a/src/pybamm/expression_tree/vector.py +++ b/src/pybamm/expression_tree/vector.py @@ -6,6 +6,7 @@ import numpy.typing as npt import pybamm from pybamm.type_definitions import DomainType, AuxiliaryDomainType, DomainsType +from typing import Any class Vector(pybamm.Array): @@ -15,7 +16,7 @@ class Vector(pybamm.Array): def __init__( self, - entries: npt.NDArray | list[float] | np.matrix, + entries: npt.NDArray[Any] | list[float] | np.matrix, name: str | None = None, domain: DomainType = None, auxiliary_domains: AuxiliaryDomainType = None, diff --git a/src/pybamm/models/event.py b/src/pybamm/models/event.py index a52148e8c0..61c4c47e91 100644 --- a/src/pybamm/models/event.py +++ b/src/pybamm/models/event.py @@ -3,6 +3,7 @@ from enum import Enum import numpy.typing as npt from typing import TypeVar +import numpy as np class EventType(Enum): @@ -74,8 +75,8 @@ def _from_json(cls: type[E], snippet: dict) -> E: def evaluate( self, t: float | None = None, - y: npt.NDArray | None = None, - y_dot: npt.NDArray | None = None, + y: npt.NDArray[np.float64] | None = None, + y_dot: npt.NDArray[np.float64] | None = None, inputs: dict | None = None, ): """ diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ef505570fa..3b7b866712 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: npt.NDArray = None, + t: npt.NDArray[np.float64] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -292,7 +292,7 @@ def jax_value( def jax_grad( self, - t: npt.NDArray = None, + t: npt.NDArray[np.float64] = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array): def _jax_solve( self, - t: Union[float, npt.NDArray], + t: Union[float, npt.NDArray[np.float64]], *inputs, - ) -> npt.NDArray: + ) -> npt.NDArray[np.float64]: """Solver implementation used by f-bind""" logger.info("jax_solve") logger.debug(f" t: {type(t)}, {t}") @@ -410,7 +410,7 @@ def _jax_solve( def _jax_jvp_impl( self, - *args: Union[npt.NDArray], + *args: Union[npt.NDArray[np.float64]], ): """JVP implementation used by f_jvp bind""" primals = args[: len(args) // 2] @@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs( def _jax_vjp_impl( self, - y_bar: npt.NDArray, + y_bar: npt.NDArray[np.float64], invar: Union[str, int], # index or name of input variable - *primals: npt.NDArray, + *primals: npt.NDArray[np.float64], ): """VJP implementation used by f_vjp bind""" logger.info("py:f_vjp_p_impl") diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index ce41c1796e..b6d588191c 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -1,14 +1,15 @@ from dataclasses import dataclass from typing import Literal, Optional, Union import numpy.typing as npt +import numpy as np import pybamm @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: npt.NDArray - discrete_times: Optional[npt.NDArray] + initial_condition: npt.NDArray[np.float64] + discrete_times: Optional[npt.NDArray[np.float64]] @staticmethod def from_pybamm_var( diff --git a/src/pybamm/type_definitions.py b/src/pybamm/type_definitions.py index cd48fa4c81..bb7924e3d2 100644 --- a/src/pybamm/type_definitions.py +++ b/src/pybamm/type_definitions.py @@ -5,13 +5,14 @@ import numpy as np import numpy.typing as npt import pybamm +from typing import Any # numbers.Number should not be used for type hints Numeric: TypeAlias = Union[int, float, np.number] # expression tree -ChildValue: TypeAlias = Union[float, npt.NDArray] -ChildSymbol: TypeAlias = Union[float, npt.NDArray, pybamm.Symbol] +ChildValue: TypeAlias = Union[float, npt.NDArray[Any]] +ChildSymbol: TypeAlias = Union[float, npt.NDArray[Any], pybamm.Symbol] DomainType: TypeAlias = Union[list[str], str, None] AuxiliaryDomainType: TypeAlias = Union[dict[str, str], None] From 863cbf3e8df3dbe87bfc4bd986ca86c44d9e981c Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:00:49 +0530 Subject: [PATCH 2/9] Refine npt.NDArray type hints per review --- src/pybamm/solvers/idaklu_jax.py | 10 +++++----- src/pybamm/solvers/processed_variable_time_integral.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index 3b7b866712..b1fc448ff7 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs): def jax_value( self, - t: npt.NDArray[np.float64] = None, + t: npt.NDArray[np.float64] | None = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -292,7 +292,7 @@ def jax_value( def jax_grad( self, - t: npt.NDArray[np.float64] = None, + t: npt.NDArray[np.float64] | None = None, inputs: Union[dict, None] = None, output_variables: Union[list[str], None] = None, ): @@ -396,7 +396,7 @@ def _jax_solve_array_inputs(self, t, inputs_array): def _jax_solve( self, - t: Union[float, npt.NDArray[np.float64]], + t: float | npt.NDArray[np.float64], *inputs, ) -> npt.NDArray[np.float64]: """Solver implementation used by f-bind""" @@ -410,7 +410,7 @@ def _jax_solve( def _jax_jvp_impl( self, - *args: Union[npt.NDArray[np.float64]], + *args: npt.NDArray[np.float64], ): """JVP implementation used by f_jvp bind""" primals = args[: len(args) // 2] @@ -456,7 +456,7 @@ def _jax_jvp_impl_array_inputs( def _jax_vjp_impl( self, y_bar: npt.NDArray[np.float64], - invar: Union[str, int], # index or name of input variable + invar: str | int, # index or name of input variable *primals: npt.NDArray[np.float64], ): """VJP implementation used by f_vjp bind""" diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index b6d588191c..742e408c17 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, Optional, Union +from typing import Literal, Union import numpy.typing as npt import numpy as np import pybamm @@ -8,8 +8,8 @@ @dataclass class ProcessedVariableTimeIntegral: method: Literal["discrete", "continuous"] - initial_condition: npt.NDArray[np.float64] - discrete_times: Optional[npt.NDArray[np.float64]] + initial_condition: npt.NDArray[np.float64] | None | float + discrete_times: npt.NDArray[np.float64] | None @staticmethod def from_pybamm_var( From c5dc8168fabe5f3aa81393a7f4da3f83548368d1 Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Sat, 29 Mar 2025 22:46:09 +0530 Subject: [PATCH 3/9] Update src/pybamm/solvers/idaklu_jax.py Co-authored-by: Saransh Chopra --- src/pybamm/solvers/idaklu_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index b1fc448ff7..fc251fe773 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -260,8 +260,8 @@ def f_isolated(*args, **kwargs): def jax_value( self, t: npt.NDArray[np.float64] | None = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression From 2d97a165f1ccf645f5157bbb634789281612f085 Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Sat, 29 Mar 2025 22:46:34 +0530 Subject: [PATCH 4/9] Update src/pybamm/solvers/idaklu_jax.py Co-authored-by: Saransh Chopra --- src/pybamm/solvers/idaklu_jax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index fc251fe773..731170872f 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -293,8 +293,8 @@ def jax_value( def jax_grad( self, t: npt.NDArray[np.float64] | None = None, - inputs: Union[dict, None] = None, - output_variables: Union[list[str], None] = None, + inputs: dict | None = None, + output_variables: list[str] | None = None, ): """Helper function to compute the gradient of a jaxified expression From 2981b161b5f45b847ff01ffda4dbd602e452f2f9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 29 Mar 2025 17:16:45 +0000 Subject: [PATCH 5/9] style: pre-commit fixes --- src/pybamm/solvers/idaklu_jax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index 731170872f..ac717785b7 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -5,7 +5,6 @@ import warnings import numbers import pybammsolvers.idaklu as idaklu -from typing import Union from functools import lru_cache From ea1c4c7f035514549de4b9b42ad5986e6d21f9eb Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Sat, 5 Apr 2025 00:58:35 +0530 Subject: [PATCH 6/9] Add from __future__ import annotations to resolve CI type hint issue --- src/pybamm/solvers/processed_variable_time_integral.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index 742e408c17..b947ad7a27 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -1,3 +1,4 @@ +from __future__ import annotations from dataclasses import dataclass from typing import Literal, Union import numpy.typing as npt @@ -14,7 +15,7 @@ class ProcessedVariableTimeIntegral: @staticmethod def from_pybamm_var( var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral], - ) -> "ProcessedVariableTimeIntegral": + ) -> ProcessedVariableTimeIntegral: if isinstance(var, pybamm.DiscreteTimeSum): return ProcessedVariableTimeIntegral( method="discrete", initial_condition=0.0, discrete_times=var.sum_times From e66c51c18a1f7b52b6a86d6a738981133051f897 Mon Sep 17 00:00:00 2001 From: Vidip Singh <112854574+vidipsingh@users.noreply.github.com> Date: Sat, 5 Apr 2025 23:25:17 +0530 Subject: [PATCH 7/9] Update src/pybamm/solvers/processed_variable_time_integral.py Co-authored-by: Saransh Chopra --- src/pybamm/solvers/processed_variable_time_integral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index b947ad7a27..d79440e8d8 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -14,7 +14,7 @@ class ProcessedVariableTimeIntegral: @staticmethod def from_pybamm_var( - var: Union[pybamm.DiscreteTimeSum, pybamm.ExplicitTimeIntegral], + var: pybamm.DiscreteTimeSum | pybamm.ExplicitTimeIntegral, ) -> ProcessedVariableTimeIntegral: if isinstance(var, pybamm.DiscreteTimeSum): return ProcessedVariableTimeIntegral( From 34422973c518f101110d8f9dca9dcfdbd9800134 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 5 Apr 2025 17:55:27 +0000 Subject: [PATCH 8/9] style: pre-commit fixes --- src/pybamm/solvers/processed_variable_time_integral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pybamm/solvers/processed_variable_time_integral.py b/src/pybamm/solvers/processed_variable_time_integral.py index d79440e8d8..a7b3bd7560 100644 --- a/src/pybamm/solvers/processed_variable_time_integral.py +++ b/src/pybamm/solvers/processed_variable_time_integral.py @@ -1,6 +1,6 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal, Union +from typing import Literal import numpy.typing as npt import numpy as np import pybamm From 3a4b938a2a77381fa9904465049c2832fe891cae Mon Sep 17 00:00:00 2001 From: vidipsingh01 Date: Mon, 28 Apr 2025 00:51:19 +0530 Subject: [PATCH 9/9] Add annotations import to idaklu_jax.py --- src/pybamm/solvers/idaklu_jax.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pybamm/solvers/idaklu_jax.py b/src/pybamm/solvers/idaklu_jax.py index ac717785b7..701a4d1dfa 100644 --- a/src/pybamm/solvers/idaklu_jax.py +++ b/src/pybamm/solvers/idaklu_jax.py @@ -1,3 +1,4 @@ +from __future__ import annotations import pybamm import numpy as np import numpy.typing as npt