-
Notifications
You must be signed in to change notification settings - Fork 134
Added in GQA and 64-bit indexing #1226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for Grouped Query Attention (GQA) and 64-bit indexing to the lean attention implementation. GQA allows different numbers of query heads (hq) and key/value heads (hk), which improves efficiency in transformer models by reducing memory bandwidth for key/value tensors.
Key changes:
- Refactored function signatures to separate query heads (
hq) from key/value heads (hk) - Added GQA support with head expansion logic in reference implementation and kernel mapping
- Implemented 64-bit indexing to handle large tensor offsets that exceed 32-bit limits
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/triton_tests/test_la.py | Updated test parameters and helper functions to support separate hq and hk parameters; added GQA head expansion in reference implementation |
| op_tests/op_benchmarks/triton/bench_la.py | Updated benchmark configurations to include hq and hk parameters with new GQA test cases |
| aiter/ops/triton/lean_atten.py | Added GQA group size calculation, 64-bit indexing support, and enhanced runtime safety checks for buffer validation |
| aiter/ops/triton/_triton_kernels/lean_atten.py | Implemented GQA head mapping in kernel, added 64-bit pointer arithmetic, and padding mask support for irregular head dimensions |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Added in support for gqa