varlen maba #352
@RaymondLi0 if you have some time to double check this
Approach looks great!
Thanks for looking into this! Added a few comments, otherwise LGTM
fast_llm/layers/ssm/mamba2.py
Outdated
delta_softplus=True,
)
if not _mamba_varlen:
Maybe throw an error here if cu_seqlens is not None, saying that this version of mamba does not support varlen inputs.
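A minimal sketch of the guard suggested in this comment. It assumes a boolean flag like the `_mamba_varlen` seen in the diff, which records whether the installed mamba kernel accepts variable-length inputs. The function name `check_varlen_support` is hypothetical and not part of the PR.

```python
def check_varlen_support(cu_seqlens, mamba_varlen: bool) -> None:
    """Fail fast instead of silently ignoring document boundaries.

    `mamba_varlen` stands in for the module-level `_mamba_varlen` flag
    from the diff (assumption: True when the varlen kernel is available).
    """
    if cu_seqlens is not None and not mamba_varlen:
        raise NotImplementedError(
            "This version of mamba does not support varlen inputs, "
            "but cu_seqlens was provided."
        )
```

Raising here makes a missing varlen kernel visible at the call site rather than producing outputs that mix state across packed documents.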
fast_llm/layers/ssm/preprocessors.py
Outdated
@@ -0,0 +1,65 @@
import logging
Let's call this file preprocessing.py to follow other layers.
fast_llm/layers/ssm/preprocessors.py
Outdated
"""
Simplified preprocessor that does not take into account micro-sequences.
"""
sequence_lengths = kwargs[TransformerKwargs.sequence_lengths]
If TransformerKwargs.sequence_lengths not in kwargs, there is nothing to do. (iiuc this is the case when cross_document_attention is set to True.)
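A rough sketch of the early return described in this comment, in plain Python so it runs without the project installed. The string constants are hypothetical stand-ins for `TransformerKwargs.sequence_lengths` and `TransformerKwargs.cu_seqlens`, and the layout of `sequence_lengths` (one list of document lengths per sample) is an assumption, not confirmed by the diff.

```python
from itertools import accumulate

# Hypothetical stand-ins for the TransformerKwargs constants named in the review.
SEQUENCE_LENGTHS = "sequence_lengths"
CU_SEQLENS = "cu_seqlens"


def preprocess(kwargs: dict) -> None:
    """Add cumulative sequence offsets to kwargs when document lengths are given."""
    if SEQUENCE_LENGTHS not in kwargs:
        # No per-document lengths were provided: nothing to do.
        return
    # Assumed layout: one list of document lengths per sample in the batch.
    lengths = [n for sample in kwargs[SEQUENCE_LENGTHS] for n in sample]
    # Offsets start at 0 and end at the total number of packed tokens.
    kwargs[CU_SEQLENS] = [0, *accumulate(lengths)]
```

With the guard in place, the preprocessor leaves `kwargs` untouched when no lengths are passed, instead of raising a `KeyError`.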
fast_llm/layers/ssm/preprocessors.py
Outdated
Simplified preprocessor that does not take into account micro-sequences.
"""
sequence_lengths = kwargs[TransformerKwargs.sequence_lengths]
if "cu_seqlens" in kwargs:
Suggested change:
- if "cu_seqlens" in kwargs:
+ if TransformerKwargs.cu_seqlens in kwargs:
Varlen mamba to enable packing.
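To illustrate what packing means here, the sketch below concatenates several documents into one sequence and records their boundaries as cumulative offsets (`cu_seqlens`), plus a per-token document index. This is a generic illustration in plain Python, not the PR's implementation. The helper names `pack` and `document_index` are hypothetical, and whether the kernel consumes offsets, a per-token index, or both is an assumption.

```python
from itertools import accumulate


def pack(documents):
    """Concatenate documents and return (tokens, cu_seqlens)."""
    tokens = [token for doc in documents for token in doc]
    # cu_seqlens[i] is the start of document i; the last entry is the total length.
    cu_seqlens = [0, *accumulate(len(doc) for doc in documents)]
    return tokens, cu_seqlens


def document_index(cu_seqlens):
    """Map every packed token position to the index of its document."""
    return [
        i
        for i, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:]))
        for _ in range(end - start)
    ]
```

A varlen-aware layer can use these boundaries to reset its recurrent state at the start of each document, so packed documents do not leak state into one another.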