Implement gemv numba dispatch #1418

Closed
jessegrabowski opened this issue May 24, 2025 · 2 comments · Fixed by #1426
Labels: linalg (Linear algebra), numba

Comments

@jessegrabowski
Member

Description

When working on #1416 I found that numba has terrible performance on matrix-vector multiplication:

[Benchmark image: matrix-vector multiplication timings, numba vs. C backend]

@ricardoV94 thinks this is because numba probably uses gemm for everything and never uses gemv. Since we already have a GEMV Op, it should be straightforward to follow the pattern I used in #1416 to write a dispatch for GEMV.

We might also have to add the GEMV rewrite to the numba mode -- I know the BLAS rewrites are disabled for jax, for example. I haven't checked for numba, but it's something to be aware of.
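For reference, a minimal NumPy sketch of the kernel such a dispatch would need to wrap. The argument order mirrors the CGemv inputs shown in the dprint below (output buffer, alpha, A, x, beta); in the real dispatch this would presumably be an njit-compiled function registered via numba_funcify, but the `gemv_kernel` name and plain-NumPy body here are just illustrative assumptions:

```python
import numpy as np

def gemv_kernel(y, alpha, A, x, beta):
    """Sketch of the GEMV computation: beta * y + alpha * (A @ x).

    Hypothetical helper; argument order follows PyTensor's CGemv Op
    (output buffer, alpha, A, x, beta). The actual numba dispatch would
    wrap an equivalent kernel with numba.njit.
    """
    return beta * y + alpha * (A @ x)

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3))
x = rng.normal(size=(3,))
y = rng.normal(size=(4,))

# With alpha=1 and beta=0 this reduces to a plain matrix-vector product,
# which is the case produced by the A @ x graph below
np.testing.assert_allclose(gemv_kernel(y, 1.0, A, x, 0.0), A @ x)
```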

@jessegrabowski jessegrabowski added numba linalg Linear algebra labels May 24, 2025
@jessegrabowski jessegrabowski changed the title Implement gmbv numba dispatch Implement gemv numba dispatch May 24, 2025
@ricardoV94
Member

Yes, all BLAS rewrites are disabled in NUMBA mode; those are the rewrites that introduce Ops like GEMM, GEMV, and GER.

@ricardoV94
Member

Here is an MRE (minimal reproducible example; the `%timeit` lines assume an IPython session):

import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
x = pt.vector("x")
out = A @ x

c_fn = pytensor.function([A, x], out, trust_input=True)
c_fn.dprint()
# CGemv{inplace} [id A] 2
#  ├─ AllocEmpty{dtype='float64'} [id B] 1
#  │  └─ Shape_i{0} [id C] 0
#  │     └─ A [id D]
#  ├─ 1.0 [id E]
#  ├─ A [id D]
#  ├─ x [id F]
#  └─ 0.0 [id G]

numba_fn = pytensor.function([A, x], out, mode="NUMBA", trust_input=True)
numba_fn.dprint()
# Squeeze{axis=1} [id A] 2
#  └─ dot [id B] 1
#     ├─ A [id C]
#     └─ ExpandDims{axis=1} [id D] 0
#        └─ x [id E]

rng = np.random.default_rng(1)
A_test = rng.normal(size=(1024, 512))
x_test = rng.normal(size=(512,))
np.testing.assert_allclose(c_fn(A_test, x_test), numba_fn(A_test, x_test))
%timeit c_fn(A_test, x_test) # 338 μs ± 8.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
%timeit numba_fn(A_test, x_test)  # 6.6 ms ± 978 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
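As a sanity check on the graph numba mode produces, the ExpandDims/Squeeze formulation is numerically equivalent to the direct matrix-vector product; it just routes a (512, 1) matrix through the general dot instead of gemv. A quick NumPy-only verification:

```python
import numpy as np

rng = np.random.default_rng(1)
A_test = rng.normal(size=(1024, 512))
x_test = rng.normal(size=(512,))

# What the numba graph computes: Squeeze(dot(A, ExpandDims(x, axis=1)), axis=1)
via_matmul = (A_test @ x_test[:, None]).squeeze(axis=1)
# The direct matrix-vector product that a gemv dispatch would compute
via_gemv = A_test @ x_test

np.testing.assert_allclose(via_matmul, via_gemv)
```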
