-
-
Notifications
You must be signed in to change notification settings - Fork 616
Make npt.NDArray type hints more specific with dtype #4901
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make npt.NDArray type hints more specific with dtype #4901
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @vidipsingh! Could you attach the output of mypy
run in the PR description?
@Saransh-cpp, I have attached the output of |
@vidipsingh, it would be nice if you could:
Thank you! |
Thanks for the feedback, @agriyakhetarpal! I’ll replace the image with the Just to clarify, are you referring to the "Before" vs. "After" comparison of the |
The |
I’ve added the Please let me know if any changes are needed! |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #4901 +/- ##
========================================
Coverage 98.57% 98.57%
========================================
Files 304 304
Lines 23645 23656 +11
========================================
+ Hits 23309 23320 +11
Misses 336 336 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @vidipsingh! See my comments below. Could you please also comment why you are using Any
in all the places where you are using it? Thank you!
import pybamm | ||
|
||
|
||
@dataclass | ||
class ProcessedVariableTimeIntegral: | ||
method: Literal["discrete", "continuous"] | ||
initial_condition: npt.NDArray | ||
discrete_times: Optional[npt.NDArray] | ||
initial_condition: npt.NDArray[np.float64] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
initial_condition: npt.NDArray[np.float64] | |
initial_condition: float | npt.NDArray[np.float64] |
src/pybamm/solvers/idaklu_jax.py
Outdated
@@ -259,7 +259,7 @@ def f_isolated(*args, **kwargs): | |||
|
|||
def jax_value( | |||
self, | |||
t: npt.NDArray = None, | |||
t: npt.NDArray[np.float64] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t: npt.NDArray[np.float64] = None, | |
t: npt.NDArray[np.float64] | None = None, |
src/pybamm/solvers/idaklu_jax.py
Outdated
@@ -292,7 +292,7 @@ def jax_value( | |||
|
|||
def jax_grad( | |||
self, | |||
t: npt.NDArray = None, | |||
t: npt.NDArray[np.float64] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t: npt.NDArray[np.float64] = None, | |
t: npt.NDArray[np.float64] | None = None, |
src/pybamm/solvers/idaklu_jax.py
Outdated
@@ -396,9 +396,9 @@ def _jax_solve_array_inputs(self, t, inputs_array): | |||
|
|||
def _jax_solve( | |||
self, | |||
t: Union[float, npt.NDArray], | |||
t: Union[float, npt.NDArray[np.float64]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
t: Union[float, npt.NDArray[np.float64]], | |
t: float | npt.NDArray[np.float64], |
src/pybamm/solvers/idaklu_jax.py
Outdated
@@ -410,7 +410,7 @@ def _jax_solve( | |||
|
|||
def _jax_jvp_impl( | |||
self, | |||
*args: Union[npt.NDArray], | |||
*args: Union[npt.NDArray[np.float64]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
*args: Union[npt.NDArray[np.float64]], | |
*args: npt.NDArray[np.float64], |
src/pybamm/solvers/idaklu_jax.py
Outdated
@@ -455,9 +455,9 @@ def _jax_jvp_impl_array_inputs( | |||
|
|||
def _jax_vjp_impl( | |||
self, | |||
y_bar: npt.NDArray, | |||
y_bar: npt.NDArray[np.float64], | |||
invar: Union[str, int], # index or name of input variable |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
invar: Union[str, int], # index or name of input variable | |
invar: str | int, # index or name of input variable |
initial_condition: npt.NDArray | ||
discrete_times: Optional[npt.NDArray] | ||
initial_condition: npt.NDArray[np.float64] | ||
discrete_times: Optional[npt.NDArray[np.float64]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will -
npt.NDArray[np.float64] | None
work?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will -
npt.NDArray[np.float64] | Nonework?
I think npt.NDArray[np.float64] | None
will work, let me look into it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will -
npt.NDArray[np.float64] | Nonework?
npt.NDArray[np.float64] | None
will not work for initial_condition
because it excludes float
(e.g., 0.0
or scalar from evaluate()
).
It will work for discrete_times
since it already matches npt.NDArray[np.float64] | None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about npt.NDArray[np.float64] | None | float
then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
npt.NDArray[np.float64] | None | float
then?
Yes, npt.NDArray[np.float64] | None | float
will work for initial_condition
since it covers scalar float
(e.g., 0.0
) and arrays.
Thank you for the feedback! I used I'll also work on the other suggested changes. Appreciate it! |
@vidipsingh let me know when this is ready for a review again! |
@Saransh-cpp I think it is ready for review now! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @vidipsingh! Maybe we should create a new issue that aims to narrow down the Any
dtype.
The CI should work once you add |
Sure, I will create a new issue for it then. |
Co-authored-by: Saransh Chopra <[email protected]>
Co-authored-by: Saransh Chopra <[email protected]>
Hi @Saransh-cpp, I have added |
Hi @Saransh-cpp, Just pinging you here to check if any changes are required for this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @vidipsingh, sorry for taking too long. The tests are failing with the error -
ImportError while loading conftest '/Users/runner/work/PyBaMM/PyBaMM/conftest.py'.
conftest.py:3: in <module>
import pybamm
src/pybamm/__init__.py:176: in <module>
from .solvers.idaklu_jax import IDAKLUJax
src/pybamm/solvers/idaklu_jax.py:25: in <module>
class IDAKLUJax:
src/pybamm/solvers/idaklu_jax.py:261: in IDAKLUJax
t: npt.NDArray[np.float64] | None = None,
E TypeError: unsupported operand type(s) for |: 'types.GenericAlias' and 'NoneType'
nox > Command python -m pytest -m unit failed with exit code 4
which looks like a missing import. Could you please fix this? Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now, thanks, @vidipsingh!
Great! Is there anything else to be done, or is this PR good to be merged? |
Thanks! It is good to be merged, but @Saransh-cpp and I moved ourselves out of the pybamm-team/maintainers team and into another one, and the permissions for both aren't in sync yet. @kratman should be able to merge this (and can review if needed, as he previously reviewed the PR). |
Got it, thanks for the heads-up! Let's wait for @kratman to review and merge. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be fine
Description
This PR refines
npt.NDArray
type hints in PyBaMM by adding explicitdtype
(e.g.,np.float64
for time/state arrays,Any
for variable cases).Fixes: #4900
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #)
Important checks:
Please confirm the following before marking the PR as ready for review:
nox -s pre-commit
nox -s tests
nox -s doctests
mypy Output (Before):
mypy Output (After):