Skip to content

Make npt.NDArray type hints more specific with dtype #4901

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

vidipsingh
Copy link
Contributor

@vidipsingh vidipsingh commented Mar 8, 2025

Description

This PR refines npt.NDArray type hints in PyBaMM by adding explicit dtype (e.g., np.float64 for time/state arrays, Any for variable cases).

Fixes: #4900

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)

Important checks:

Please confirm the following before marking the PR as ready for review:

  • No style issues: nox -s pre-commit
  • All tests pass: nox -s tests
  • The documentation builds: nox -s doctests
  • Code is commented for hard-to-understand areas
  • Tests added that prove fix is effective or that feature works
mypy Output (Before):

src/pybamm/telemetry.py:18: error: Incompatible types in assignment (expression has type "Posthog", variable has type "MockTelemetry")  [assignment]
src/pybamm/telemetry.py:23: error: "MockTelemetry" has no attribute "log"  [attr-defined]
src/pybamm/config.py:168: error: Library stubs not installed for "yaml"  [import-untyped]
src/pybamm/config.py:168: note: Hint: "python3 -m pip install types-PyYAML"
src/pybamm/config.py:168: note: (or run "mypy --install-types" to install all missing stub packages)
src/pybamm/config.py:168: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
src/pybamm/solvers/summary_variable.py:43: error: Need type annotation for "_variables" (hint: "_variables: dict[<type>, <type>] = ...")  [var-annotated]
src/pybamm/solvers/summary_variable.py:71: error: Incompatible types in assignment (expression has type "list[SummaryVariables]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:72: error: Argument 1 to "len" has incompatible type "None"; expected "Sized"  [arg-type]
src/pybamm/solvers/summary_variable.py:73: error: Value of type "None" is not indexable  [index]
src/pybamm/solvers/summary_variable.py:85: error: Cannot determine type of "_all_variables"  [has-type]
src/pybamm/solvers/summary_variable.py:103: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "_get_electrode_soh_sims_full"  [union-attr]
src/pybamm/solvers/summary_variable.py:105: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:126: error: Incompatible return value type (got "ndarray[Any, dtype[Any]]", expected "float | list[float]")  [return-value]
src/pybamm/solvers/summary_variable.py:151: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:153: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:184: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "solve"  [union-attr]
src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]
src/pybamm/solvers/processed_variable_time_integral.py:19: error: Argument "initial_condition" to "ProcessedVariableTimeIntegral" has incompatible type "float"; expected "ndarray[Any, dtype[Any]]"  [arg-type]
src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[Any]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:262: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:262: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/solvers/idaklu_jax.py:295: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[Any]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:295: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:295: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/expression_tree/functions.py:164: error: Expected iterable as variadic argument  [misc]
src/pybamm/expression_tree/functions.py:173: error: Argument 1 to "_function_new_copy" of "Function" has incompatible type "Symbol"; expected "list[Any]"  [arg-type]
src/pybamm/expression_tree/concatenations.py:479: error: Argument 1 to "intersect" has incompatible type "str | None"; expected "str"  [arg-type]
src/pybamm/expression_tree/concatenations.py:480: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:483: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:484: error: Value of type "str | None" is not indexable  [index]
src/pybamm/expression_tree/concatenations.py:545: error: Cannot determine type of "child"  [has-type]
src/pybamm/citations.py:34: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
src/pybamm/expression_tree/unary_operators.py:74: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/symbol.py:974: error: Incompatible return value type (got "list[Symbol]", expected "Symbol")  [return-value]
src/pybamm/expression_tree/symbol.py:996: error: Argument 2 to "Symbol" has incompatible type "Symbol"; expected "Sequence[Symbol] | None"  [arg-type]
src/pybamm/expression_tree/broadcasts.py:82: error: "Broadcast" has no attribute "broadcast_domain"  [attr-defined]
src/pybamm/expression_tree/binary_operators.py:131: error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]
src/pybamm/expression_tree/binary_operators.py:131: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/binary_operators.py:135: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/experiment/step/base_step.py:143: error: Argument 3 to "Interpolant" has incompatible type "Subtraction"; expected "Sequence[Symbol] | Time"  [arg-type]
src/pybamm/experiment/experiment.py:59: error: Incompatible types in assignment (expression has type "tuple[str | BaseStep]", variable has type "str | tuple[str] | BaseStep")  [assignment]
src/pybamm/experiment/experiment.py:62: error: Argument 1 to "len" has incompatible type "str | tuple[str] | BaseStep"; expected "Sized"  [arg-type]
src/pybamm/experiment/experiment.py:64: error: Item "BaseStep" of "str | tuple[str] | BaseStep" has no attribute "__iter__" (not iterable)  [union-attr]
src/pybamm/solvers/base_solver.py:97: error: Name "root_method" already defined on line 85  [no-redef]
src/pybamm/solvers/base_solver.py:97: error: "Callable[[BaseSolver], Any]" has no attribute "setter"  [attr-defined]
src/pybamm/solvers/base_solver.py:1124: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1125: error: Need type annotation for "initial_conditions"  [var-annotated]
src/pybamm/solvers/base_solver.py:1131: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1150: error: Incompatible return value type (got "list[Any]", expected "tuple[Any, ...]")  [return-value]
tests/unit/test_parameters/test_bpx.py:11: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
tests/unit/test_expression_tree/test_binary_operators.py:14: error: Need type annotation for "EMPTY_DOMAINS"  [var-annotated]
examples/scripts/run_ecmd.py:14: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/run_ecm.py:9: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/minimal_example_of_lookup_tables.py:37: error: Name "D_s_n" already defined on line 25  [no-redef]
examples/scripts/experiment_drive_cycle.py:32: error: List item 0 has incompatible type "tuple[str, str, str, Any, str, Any, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/SPMe_step.py:47: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPMe_step.py:49: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:59: error: Value of type "None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:60: error: Value of type "None" is not indexable  [index]
examples/scripts/MSMR.py:29: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/gitt.py:8: error: List item 0 has incompatible type "tuple[str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/cccv.py:10: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
docs/conf.py:110: error: Need type annotation for "suppress_warnings" (hint: "suppress_warnings: list[<type>] = ...")  [var-annotated]
docs/conf.py:202: error: Incompatible types in assignment (expression has type "bool", target has type "str")  [assignment]
docs/conf.py:221: error: Need type annotation for "latex_elements" (hint: "latex_elements: dict[<type>, <type>] = ...")  [var-annotated]
docs/conf.py:338: error: Incompatible types in assignment (expression has type "str | None", variable has type "str")  [assignment]
docs/conf.py:494: error: Dict entry 0 has incompatible type "str": "ParameterSets"; expected "str": "str"  [dict-item]
Found 59 errors in 26 files (checked 581 source files)

mypy Output (After):

src/pybamm/telemetry.py:18: error: Incompatible types in assignment (expression has type "Posthog", variable has type "MockTelemetry")  [assignment]
src/pybamm/telemetry.py:23: error: "MockTelemetry" has no attribute "log"  [attr-defined]
src/pybamm/config.py:168: error: Library stubs not installed for "yaml"  [import-untyped]
src/pybamm/config.py:168: note: Hint: "python3 -m pip install types-PyYAML"
src/pybamm/config.py:168: note: (or run "mypy --install-types" to install all missing stub packages)
src/pybamm/config.py:168: note: See https://mypy.readthedocs.io/en/stable/running_mypy.html#missing-imports
src/pybamm/solvers/summary_variable.py:43: error: Need type annotation for "_variables" (hint: "_variables: dict[<type>, <type>] = ...")  [var-annotated]
src/pybamm/solvers/summary_variable.py:71: error: Incompatible types in assignment (expression has type "list[SummaryVariables]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:72: error: Argument 1 to "len" has incompatible type "None"; expected "Sized"  [arg-type]
src/pybamm/solvers/summary_variable.py:73: error: Value of type "None" is not indexable  [index]
src/pybamm/solvers/summary_variable.py:85: error: Cannot determine type of "_all_variables"  [has-type]
src/pybamm/solvers/summary_variable.py:103: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "_get_electrode_soh_sims_full"  [union-attr]
src/pybamm/solvers/summary_variable.py:105: error: Incompatible types in assignment (expression has type "list[Any]", variable has type "None")  [assignment]
src/pybamm/solvers/summary_variable.py:126: error: Incompatible return value type (got "ndarray[Any, dtype[Any]]", expected "float | list[float]")  [return-value]
src/pybamm/solvers/summary_variable.py:151: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:153: error: "None" has no attribute "__iter__" (not iterable)  [attr-defined]
src/pybamm/solvers/summary_variable.py:184: error: Item "None" of "ElectrodeSOHSolver | None" has no attribute "solve"  [union-attr]
src/pybamm/solvers/solution.py:160: error: Missing return statement  [return]
src/pybamm/solvers/processed_variable_time_integral.py:20: error: Argument "initial_condition" to "ProcessedVariableTimeIntegral" has incompatible type "float"; expected "ndarray[Any, dtype[floating[_64Bit]]]"  [arg-type]
src/pybamm/solvers/idaklu_jax.py:262: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[floating[_64Bit]]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:262: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:262: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/solvers/idaklu_jax.py:295: error: Incompatible default for argument "t" (default has type "None", argument has type "ndarray[Any, dtype[floating[_64Bit]]]")  [assignment]
src/pybamm/solvers/idaklu_jax.py:295: note: PEP 484 prohibits implicit Optional. Accordingly, mypy has changed its default to no_implicit_optional=True
src/pybamm/solvers/idaklu_jax.py:295: note: Use https://github.com/hauntsaninja/no_implicit_optional to automatically upgrade your codebase
src/pybamm/expression_tree/functions.py:164: error: Expected iterable as variadic argument  [misc]
src/pybamm/expression_tree/functions.py:173: error: Argument 1 to "_function_new_copy" of "Function" has incompatible type "Symbol"; expected "list[Any]"  [arg-type]
src/pybamm/expression_tree/concatenations.py:480: error: Argument 1 to "intersect" has incompatible type "str | None"; expected "str"  [arg-type]
src/pybamm/expression_tree/concatenations.py:481: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:484: error: Argument 1 to "len" has incompatible type "str | None"; expected "Sized"  [arg-type]
src/pybamm/expression_tree/concatenations.py:485: error: Value of type "str | None" is not indexable  [index]
src/pybamm/expression_tree/concatenations.py:546: error: Cannot determine type of "child"  [has-type]
src/pybamm/citations.py:34: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
src/pybamm/expression_tree/unary_operators.py:74: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/symbol.py:974: error: Incompatible return value type (got "list[Symbol]", expected "Symbol")  [return-value]
src/pybamm/expression_tree/symbol.py:996: error: Argument 2 to "Symbol" has incompatible type "Symbol"; expected "Sequence[Symbol] | None"  [arg-type]
src/pybamm/expression_tree/broadcasts.py:82: error: "Broadcast" has no attribute "broadcast_domain"  [attr-defined]
src/pybamm/expression_tree/binary_operators.py:131: error: Missing positional argument "right_child" in call to "BinaryOperator"  [call-arg]
src/pybamm/expression_tree/binary_operators.py:131: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/expression_tree/binary_operators.py:135: error: Value of type "Symbol" is not indexable  [index]
src/pybamm/experiment/step/base_step.py:143: error: Argument 3 to "Interpolant" has incompatible type "Subtraction"; expected "Sequence[Symbol] | Time"  [arg-type]
src/pybamm/experiment/experiment.py:59: error: Incompatible types in assignment (expression has type "tuple[str | BaseStep]", variable has type "str | tuple[str] | BaseStep")  [assignment]
src/pybamm/experiment/experiment.py:62: error: Argument 1 to "len" has incompatible type "str | tuple[str] | BaseStep"; expected "Sized"  [arg-type]
src/pybamm/experiment/experiment.py:64: error: Item "BaseStep" of "str | tuple[str] | BaseStep" has no attribute "__iter__" (not iterable)  [union-attr]
src/pybamm/solvers/base_solver.py:97: error: Name "root_method" already defined on line 85  [no-redef]
src/pybamm/solvers/base_solver.py:97: error: "Callable[[BaseSolver], Any]" has no attribute "setter"  [attr-defined]
src/pybamm/solvers/base_solver.py:1124: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1125: error: Need type annotation for "initial_conditions"  [var-annotated]
src/pybamm/solvers/base_solver.py:1131: error: "BaseModel" has no attribute "calculate_sensitivities"  [attr-defined]
src/pybamm/solvers/base_solver.py:1150: error: Incompatible return value type (got "list[Any]", expected "tuple[Any, ...]")  [return-value]
tests/unit/test_parameters/test_bpx.py:11: note: By default the bodies of untyped functions are not checked, consider using --check-untyped-defs  [annotation-unchecked]
tests/unit/test_expression_tree/test_binary_operators.py:14: error: Need type annotation for "EMPTY_DOMAINS"  [var-annotated]
examples/scripts/run_ecmd.py:14: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/run_ecm.py:9: error: List item 0 has incompatible type "tuple[str, str, str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/minimal_example_of_lookup_tables.py:37: error: Name "D_s_n" already defined on line 25  [no-redef]
examples/scripts/experiment_drive_cycle.py:32: error: List item 0 has incompatible type "tuple[str, str, str, Any, str, Any, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/SPMe_step.py:47: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPMe_step.py:49: error: Value of type "Any | None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:59: error: Value of type "None" is not indexable  [index]
examples/scripts/SPM_compare_particle_grid.py:60: error: Value of type "None" is not indexable  [index]
examples/scripts/MSMR.py:29: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/gitt.py:8: error: List item 0 has incompatible type "tuple[str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
examples/scripts/experimental_protocols/cccv.py:10: error: List item 0 has incompatible type "tuple[str, str, str, str, str]"; expected "str | tuple[str] | BaseStep"  [list-item]
docs/conf.py:110: error: Need type annotation for "suppress_warnings" (hint: "suppress_warnings: list[<type>] = ...")  [var-annotated]
docs/conf.py:202: error: Incompatible types in assignment (expression has type "bool", target has type "str")  [assignment]
docs/conf.py:221: error: Need type annotation for "latex_elements" (hint: "latex_elements: dict[<type>, <type>] = ...")  [var-annotated]
docs/conf.py:338: error: Incompatible types in assignment (expression has type "str | None", variable has type "str")  [assignment]
docs/conf.py:494: error: Dict entry 0 has incompatible type "str": "ParameterSets"; expected "str": "str"  [dict-item]
Found 59 errors in 26 files (checked 581 source files)

@vidipsingh vidipsingh requested a review from a team as a code owner March 8, 2025 18:28
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?

@vidipsingh
Copy link
Contributor Author

Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?

@Saransh-cpp, I have attached the output of mypy run in the PR description.

@agriyakhetarpal
Copy link
Member

Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?

@Saransh-cpp, I have attached the output of mypy run in the PR description.

@vidipsingh, it would be nice if you could:

  • paste the output as text (wrapped in triple backticks, i.e., as code), rather than as an image – so that it is easier to read and copy from or quote
  • display what was fixed with a "Before" v.s. "After" comparison (you may use the GitHub / commands and choose "Details" to wrap the code blocks inside a collapsible dropdown section)

Thank you!

@vidipsingh
Copy link
Contributor Author

vidipsingh commented Mar 10, 2025

Thanks, @vidipsingh! Could you attach the output of mypy run in the PR description?

@Saransh-cpp, I have attached the output of mypy run in the PR description.

@vidipsingh, it would be nice if you could:

  • paste the output as text (wrapped in triple backticks, i.e., as code), rather than as an image – so that it is easier to read and copy from or quote
  • display what was fixed with a "Before" v.s. "After" comparison (you may use the GitHub / commands and choose "Details" to wrap the code blocks inside a collapsible dropdown section)

Thank you!

Thanks for the feedback, @agriyakhetarpal!

I’ll replace the image with the mypy output in code blocks and will also add "Before" vs. "After" comparison using collapsible sections.

Just to clarify, are you referring to the "Before" vs. "After" comparison of the mypy run output or the code changes?

@Saransh-cpp
Copy link
Member

Just to clarify, are you referring to the "Before" vs. "After" comparison of the mypy run output or the code changes?

The mypy run!

@vidipsingh
Copy link
Contributor Author

@Saransh-cpp @agriyakhetarpal

I’ve added the mypy run output for both "Before" and "After" in their respective collapsible sections in the PR description.
However, it seems that the outputs are the same for both.

Please let me know if any changes are needed!

Copy link

codecov bot commented Mar 10, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 98.57%. Comparing base (34186fe) to head (0a017db).
Report is 1 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #4901   +/-   ##
========================================
  Coverage    98.57%   98.57%           
========================================
  Files          304      304           
  Lines        23645    23656   +11     
========================================
+ Hits         23309    23320   +11     
  Misses         336      336           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any in all the places where you are using it? Thank you!

import pybamm


@dataclass
class ProcessedVariableTimeIntegral:
method: Literal["discrete", "continuous"]
initial_condition: npt.NDArray
discrete_times: Optional[npt.NDArray]
initial_condition: npt.NDArray[np.float64]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
initial_condition: npt.NDArray[np.float64]
initial_condition: float | npt.NDArray[np.float64]

@@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs):

def jax_value(
self,
t: npt.NDArray = None,
t: npt.NDArray[np.float64] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t: npt.NDArray[np.float64] = None,
t: npt.NDArray[np.float64] | None = None,

@@ -292,7 +292,7 @@ def jax_value(

def jax_grad(
self,
t: npt.NDArray = None,
t: npt.NDArray[np.float64] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t: npt.NDArray[np.float64] = None,
t: npt.NDArray[np.float64] | None = None,

@@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array):

def _jax_solve(
self,
t: Union[float, npt.NDArray],
t: Union[float, npt.NDArray[np.float64]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
t: Union[float, npt.NDArray[np.float64]],
t: float | npt.NDArray[np.float64],

@@ -410,7 +410,7 @@ def _jax_solve(

def _jax_jvp_impl(
self,
*args: Union[npt.NDArray],
*args: Union[npt.NDArray[np.float64]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
*args: Union[npt.NDArray[np.float64]],
*args: npt.NDArray[np.float64],

@@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs(

def _jax_vjp_impl(
self,
y_bar: npt.NDArray,
y_bar: npt.NDArray[np.float64],
invar: Union[str, int], # index or name of input variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
invar: Union[str, int], # index or name of input variable
invar: str | int, # index or name of input variable

initial_condition: npt.NDArray
discrete_times: Optional[npt.NDArray]
initial_condition: npt.NDArray[np.float64]
discrete_times: Optional[npt.NDArray[np.float64]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will -

npt.NDArray[np.float64] | None

work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will -

npt.NDArray[np.float64] | None

work?

I think npt.NDArray[np.float64] | None will work, let me look into it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will -

npt.NDArray[np.float64] | None

work?

npt.NDArray[np.float64] | None will not work for initial_condition because it excludes float (e.g., 0.0 or scalar from evaluate()).
It will work for discrete_times since it already matches npt.NDArray[np.float64] | None.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about npt.NDArray[np.float64] | None | float then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about npt.NDArray[np.float64] | None | float then?

Yes, npt.NDArray[np.float64] | None | float will work for initial_condition since it covers scalar float (e.g., 0.0) and arrays.

@vidipsingh
Copy link
Contributor Author

Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any in all the places where you are using it? Thank you!

Thank you for the feedback! I used Any because I wasn't entirely sure of the exact type to apply in those cases. If you have any suggestions, I'd be happy to make the changes.

I'll also work on the other suggested changes. Appreciate it!

@Saransh-cpp
Copy link
Member

@vidipsingh let me know when this is ready for a review again!

@Saransh-cpp Saransh-cpp marked this pull request as draft March 18, 2025 13:03
@vidipsingh
Copy link
Contributor Author

@vidipsingh let me know when this is ready for a review again!

@Saransh-cpp I think it is ready for review now!

Saransh-cpp
Saransh-cpp previously approved these changes Mar 20, 2025
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @vidipsingh! Maybe we should create a new issue that aims to narrow down the Any dtype.

@Saransh-cpp
Copy link
Member

The CI should work once you add from __future__ import annotations on top of the failing file.

@vidipsingh
Copy link
Contributor Author

Thanks, @vidipsingh! Maybe we should create a new issue that aims to narrow down the Any dtype.

Sure, I will create a new issue for it then.

@Saransh-cpp Saransh-cpp marked this pull request as ready for review March 29, 2025 22:12
@Saransh-cpp Saransh-cpp marked this pull request as draft March 29, 2025 22:13
@vidipsingh vidipsingh marked this pull request as ready for review April 5, 2025 14:29
@vidipsingh
Copy link
Contributor Author

The CI should work once you add from __future__ import annotations on top of the failing file.

Hi @Saransh-cpp, I have added from __future__ import annotations on top of the failing file as suggested.
Please review the PR and let me know if anything else is required.

@vidipsingh
Copy link
Contributor Author

Hi @Saransh-cpp, Just pinging you here to check if any changes are required for this PR.
It would be great if you could review it and let me know if anything needs to be updated.

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @vidipsingh, sorry for taking too long. The tests are failing with the error -

ImportError while loading conftest '/Users/runner/work/PyBaMM/PyBaMM/conftest.py'.
conftest.py:3: in <module>
    import pybamm
src/pybamm/__init__.py:176: in <module>
    from .solvers.idaklu_jax import IDAKLUJax
src/pybamm/solvers/idaklu_jax.py:25: in <module>
    class IDAKLUJax:
src/pybamm/solvers/idaklu_jax.py:261: in IDAKLUJax
    t: npt.NDArray[np.float64] | None = None,
E   TypeError: unsupported operand type(s) for |: 'types.GenericAlias' and 'NoneType'
nox > Command python -m pytest -m unit failed with exit code 4

which looks like a missing import. Could you please fix this? Thanks!

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now, thanks, @vidipsingh!

@vidipsingh
Copy link
Contributor Author

Looks good now, thanks, @vidipsingh!

Great! Is there anything else to be done, or is this PR good to be merged?

@agriyakhetarpal
Copy link
Member

agriyakhetarpal commented May 4, 2025

Thanks! It is good to be merged, but @Saransh-cpp and I moved ourselves out of the pybamm-team/maintainers team and into another one, and the permissions for both aren't in sync yet. @kratman should be able to merge this (and can review if needed, as he previously reviewed the PR).

@vidipsingh
Copy link
Contributor Author

Thanks! It is good to be merged, but @Saransh-cpp and I moved ourselves out of the pybamm-team/maintainers team and into another one, and the permissions for both aren't in sync yet. @kratman should be able to merge this (and can review if needed, as he previously reviewed the PR).

Got it, thanks for the heads-up! Let's wait for @kratman to review and merge.

Copy link
Contributor

@kratman kratman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be fine

@kratman kratman merged commit 7526b59 into pybamm-team:develop May 5, 2025
26 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Make npt.NDArray type hints more specific
4 participants