Hacky patch to support correct vectorization factor of nvfp4 for pointwise scheduler #5428
Conversation
Review updated until commit d88ee26
!test
```python
# However, casting complex values to real discards the imaginary
# part, so skip complex dtypes.
if not ref_out.dtype.is_complex:
    if ref_out.dtype == torch.float4_e2m1fn_x2:
```
@rdspring1 probably not the cleanest thing, just a naive python patch
!test

!test
```cpp
// sub-byte length.
params->vectorization_factor = std::min(
    max_vect_factor,
    has_sub_byte ? std::max(2l, max_vect_factor) : max_vect_factor,
```
Thanks - I need something like this as well. I thought I might need this to have a minimum of 4/8 (bf16).
I guess these are all computed at compile time, so there's no information about alignment yet. Can increasing the vectorization width (or having any vectorization at all) lead to incorrect behavior?
This shouldn't cause more issues than would already have been there.
This only modifies the heuristic's lower bound on max_vect_factor when fp4 is present. Actual alignment is constrained on top of this with the outer std::min.
If we cannot vectorize at the proper width, it still won't do so, and codegen will fail later. (For fp4 types, we'll hit the assert in the runtime function.)
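As a rough illustration of the two-stage constraint described above, here is a minimal sketch (not the actual nvFuser scheduler code; `computeVectorizationFactor`, `heuristic_factor`, `align_vect_factor`, and `has_sub_byte` are hypothetical names): the fp4 lower bound only raises the heuristic value, while the alignment-derived factor is still applied on top via the outer `std::min`.

```cpp
#include <algorithm>
#include <cstdint>

// Minimal sketch of the two-stage clamp discussed above (names are
// hypothetical, not the actual nvFuser scheduler code).
int64_t computeVectorizationFactor(
    int64_t heuristic_factor,   // factor suggested by the pointwise heuristic
    int64_t align_vect_factor,  // factor allowed by pointer/extent alignment
    bool has_sub_byte) {        // true when a packed fp4 tensor participates
  // Packed fp4 needs at least 2 elements per access, so raise the lower
  // bound of the heuristic value when a sub-byte dtype is present.
  int64_t with_lower_bound =
      has_sub_byte ? std::max<int64_t>(2, heuristic_factor) : heuristic_factor;
  // Alignment is still constrained on top of this with the outer std::min,
  // so the lower bound never forces an access wider than alignment permits.
  return std::min(with_lower_bound, align_vect_factor);
}
```

If the alignment-derived factor ends up below 2 for a packed fp4 tensor, the result stays below 2 and, as noted above, codegen fails later at the runtime assert rather than producing incorrect results.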
Do you think we should assert on the params->vectorization_factor here in the case that we have a packed fp4 type and we aren't able to vectorize to 2 or more?
^ I would prefer not to.
Generally speaking, an earlier error seems like a good idea, but vectorization should not be the answer.
It's really inlining that's failing us, i.e. we don't necessarily need the fp4 TV to be vectorized if we have correct inlining. We really only need packed storage.
I don't know how to fix that properly yet and I wanted it to work for now. That's why I was calling it a hacky patch. 😰
```python
# Check that the values of all outputs match
for ref_out, cap_out in zip(reference_outputs, captured_outputs):
    # torch.allclose does not work with fp8 datatype, so cast to fp64.
```
Might be useful to also mention the packed fp4 handling in the comment here.
!test

!test
Looks like CI is broken again, but tests have cleared earlier (only comments were added since 2917c9d). I'll merge this as-is after the build clears.
!test
`Array<__e2m1, size>` requires `size % 2 == 0` since it's a packed dtype. Existing vectorization heuristics in the pointwise scheduler didn't consider this when capping the `vectorization_factor`. This PR is just a quick patch to add a lower bound of 2 when a packed dtype is seen on vectorized inputs/outputs.
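As a rough illustration of why the lower bound is 2, here is a sketch of packed fp4 storage (simplified and hypothetical; `PackedE2m1Array` is not nvFuser's actual `Array` type): two 4-bit e2m1 values share one byte, so any array of them must hold an even number of elements to remain byte-addressable.

```cpp
#include <cstdint>

// Hypothetical sketch of packed fp4 storage, not nvFuser's actual Array type.
// Two 4-bit e2m1 values are packed into one byte, so the element count must
// be even; the static_assert mirrors the `size % 2 == 0` requirement from
// the PR description.
template <int size>
struct PackedE2m1Array {
  static_assert(size % 2 == 0, "packed fp4 arrays need an even element count");
  // Each byte stores a pair of e2m1 values (low nibble, high nibble).
  uint8_t bytes[size / 2];
};

int main() {
  PackedE2m1Array<2> ok{};    // vectorization factor >= 2 keeps this valid
  (void)ok;
  // PackedE2m1Array<1> bad;  // would fail the static_assert above
  return 0;
}
```

This is why the pointwise heuristic should never pick a vectorization factor of 1 for a tensor of this dtype, which is exactly what the lower bound of 2 enforces.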