Skip to content

Decompose Tridiagonal Solve into core steps #1382

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 10, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 29, 2025

This PR builds on top of #1396, extending it to the tridiagonal case.

It also adds a rewrite for when extracting diagonals out a symbolically allocated diagonal matrix (via set_subtensor or AllocDiag after it gets inlined during specialization)

TODO:

  • Numba implementation and exclusion from other backends rewrites

@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch from cc4fab6 to c72326e Compare May 7, 2025 05:27
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch 4 times, most recently from 2436cd1 to aceab09 Compare May 8, 2025 16:55
@jessegrabowski
Copy link
Member

jessegrabowski commented May 14, 2025

Is this blocked by anything? Or just needs a bit of polishing to make mypy happy etc

@ricardoV94
Copy link
Member Author

Needs a bit rework after the LU PR

@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch 4 times, most recently from 209f71f to 9652c5d Compare May 19, 2025 11:41
@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch 5 times, most recently from 70da029 to 6bc5484 Compare May 23, 2025 11:31
@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch 3 times, most recently from be7aa49 to 6c985c9 Compare June 9, 2025 12:50
Comment on lines +135 to +152
dummy_arrays = [
np.zeros((), dtype=inp.type.dtype) for inp in (dl, d, du, du2, ipiv)
]
# Seems to always be float64?
out_dtype = get_lapack_funcs("gttrs", dummy_arrays).dtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here -- get_lapack_output_dtype_from_inputs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? Is there such a helper? Google doesn't return anything

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant make that function to hold the logic for making dummy arrays and getting the out_dtype form get_lapack_funcs

@jessegrabowski jessegrabowski marked this pull request as ready for review June 9, 2025 13:50
@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch from 6c985c9 to 2e6c90f Compare June 9, 2025 17:44
Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving now, idk if there's more you want to adjust but I'm happy

@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch from 758fcdd to d037bd0 Compare June 10, 2025 08:12
@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch from d037bd0 to acd9219 Compare June 10, 2025 08:40
@ricardoV94 ricardoV94 force-pushed the decompose_tridiagonal_solve branch from acd9219 to 4847344 Compare June 10, 2025 13:10
Copy link

codecov bot commented Jun 10, 2025

Codecov Report

Attention: Patch coverage is 72.55814% with 59 lines in your changes missing coverage. Please review.

Project coverage is 82.09%. Comparing base (d10f245) to head (4847344).
Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/_linalg/solve/tridiagonal.py 76.57% 20 Missing and 6 partials ⚠️
...or/link/numba/dispatch/linalg/solve/tridiagonal.py 41.46% 24 Missing ⚠️
pytensor/tensor/rewriting/subtensor.py 84.09% 3 Missing and 4 partials ⚠️
pytensor/tensor/_linalg/solve/rewriting.py 89.47% 0 Missing and 2 partials ⚠️

❌ Your patch status has failed because the patch coverage (72.55%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1382      +/-   ##
==========================================
- Coverage   82.12%   82.09%   -0.03%     
==========================================
  Files         211      212       +1     
  Lines       49757    49965     +208     
  Branches     8819     8858      +39     
==========================================
+ Hits        40862    41018     +156     
- Misses       6715     6757      +42     
- Partials     2180     2190      +10     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 84.72% <ø> (ø)
pytensor/tensor/subtensor.py 89.98% <ø> (ø)
pytensor/tensor/_linalg/solve/rewriting.py 93.49% <89.47%> (+1.04%) ⬆️
pytensor/tensor/rewriting/subtensor.py 89.81% <84.09%> (-0.34%) ⬇️
...or/link/numba/dispatch/linalg/solve/tridiagonal.py 55.39% <41.46%> (-2.86%) ⬇️
pytensor/tensor/_linalg/solve/tridiagonal.py 76.57% <76.57%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 merged commit d88c735 into pymc-devs:main Jun 10, 2025
72 of 73 checks passed
@ricardoV94 ricardoV94 deleted the decompose_tridiagonal_solve branch June 10, 2025 14:10
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants