-
Notifications
You must be signed in to change notification settings - Fork 128
Mla splitkv enhance split alg inte #1233
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…nhance_split_alg_inte
…nhance_split_alg_inte
…nhance_split_alg_inte
…_schema which doesn't allow dict as parameter. Thus, change the API for metadata.
…it's not compatible with hip graph
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR enhances the MLA (Multi-head Latent Attention) split key-value algorithm with significant new functionality including persistent thread group support, sparse attention capabilities, and fp8 quantization support. The changes introduce metadata generation for optimized work distribution and a reduce kernel for merging partial results.
- Adds persistent thread group implementation for variable query/output lengths
- Implements sparse attention with top-k token selection
- Integrates fp8 quantization support for both Q and KV tensors
Reviewed Changes
Copilot reviewed 52 out of 85 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| op_tests/test_mla_sparse.py | New test file for sparse MLA attention with top-k token selection |
| op_tests/test_mla_persistent*.py | New test files for persistent thread group MLA implementation |
| csrc/py_itfs_cu/asm_mla.cu | Updated to support persistent mode, fp8 datatypes, and new metadata parameters |
| csrc/kernels/mla/reduce.cu | New reduce kernel for merging partial attention outputs |
| csrc/kernels/mla/metadata*.cuh | New metadata generation kernels for work distribution |
| aiter/mla.py | Updated decode forward pass to use new metadata and reduce operations |
| csrc/include/mla.h | New header defining MLA data structures and function signatures |
| Various copyright headers | Updated copyright format from (c) to (C) |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist