-
Notifications
You must be signed in to change notification settings - Fork 135
Decompose Tridiagonal Solve into core steps #1382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Decompose Tridiagonal Solve into core steps #1382
Conversation
cc4fab6
to
c72326e
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
2436cd1
to
aceab09
Compare
Is this blocked by anything? Or just needs a bit of polishing to make mypy happy etc |
Needs a bit rework after the LU PR |
209f71f
to
9652c5d
Compare
70da029
to
6bc5484
Compare
be7aa49
to
6c985c9
Compare
dummy_arrays = [ | ||
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv) | ||
] | ||
# Seems to always be float64? | ||
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here -- get_lapack_output_dtype_from_inputs
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean? Is there such a helper? Google doesn't return anything
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I meant make that function to hold the logic for making dummy arrays and getting the out_dtype form get_lapack_funcs
6c985c9
to
2e6c90f
Compare
2e6c90f
to
758fcdd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving now, idk if there's more you want to adjust but I'm happy
758fcdd
to
d037bd0
Compare
d037bd0
to
acd9219
Compare
acd9219
to
4847344
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (72.55%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1382 +/- ##
==========================================
- Coverage 82.12% 82.09% -0.03%
==========================================
Files 211 212 +1
Lines 49757 49965 +208
Branches 8819 8858 +39
==========================================
+ Hits 40862 41018 +156
- Misses 6715 6757 +42
- Partials 2180 2190 +10
🚀 New features to boost your workflow:
|
This PR builds on top of #1396, extending it to the tridiagonal case.
It also adds a rewrite for when extracting diagonals out a symbolically allocated diagonal matrix (via set_subtensor or AllocDiag after it gets inlined during specialization)
TODO: