Create a new node for Block Quantization to NVFP4 and plumb it to a device function. #5266
base: main
Review updated until commit c0cd7f9
I didn't review all of the relaxed checks in detail.
My main question is how we are handling, or planning to handle, indexing in the runtime function. Are we going to just restrict the scheduler so that it complies with the index requirements of the runtime function? That feels too restrictive to me.
// This division should be replaced with a multiplication
// by a reciprocal for better performance.
float scaled_max = block_max / 6.000000000e+00f;
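A minimal sketch of the change the code comment above asks for, written as host-side C++ rather than as the actual device function. It assumes the divisor 6.0 is the largest finite magnitude of FP4 E2M1; the helper names `scaledMaxDiv` and `scaledMaxMul` and the constants are hypothetical and not part of the PR.

```cpp
#include <cassert>
#include <cmath>

// Assumption: 6.0 is the largest finite magnitude representable in FP4 E2M1.
constexpr float kFp4E2m1Max = 6.0f;
// Reciprocal computed once at compile time, so no division runs per block.
constexpr float kFp4E2m1MaxRcp = 1.0f / kFp4E2m1Max;

// Hypothetical helper: the current form, one division per block.
inline float scaledMaxDiv(float block_max) {
  return block_max / kFp4E2m1Max;
}

// Hypothetical helper: the suggested form, one multiplication per block.
inline float scaledMaxMul(float block_max) {
  return block_max * kFp4E2m1MaxRcp;
}
```

Because 1/6 is not exactly representable in binary floating point, the two forms can differ in the last bit, so any reference test that compares bit-exact scaling factors would need to use the same form.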
I'm uncertain how the math would work with a global scaling factor.
The math in the device function is not correct (it doesn't match the Python reference in test_narrow_precision.py).
In a branch I have support for a global scaling factor and a modified kernel with the new math.
I think I'll update this PR with the new math.
The only problem is that I'll need to update Xiang's old tests as well to reflect that.
I think I'll leave the math in the device function as is for now and have a separate PR which updates the math in the device function and the older C++ tests.
> My main question is how are we handling/planning to handle indexing in runtime function? Are we going to just restrict the scheduler to ensure they comply with the index requirement from the runtime function, that felt like too restrictive to me.

For now, that's what I was doing.

@jjsjann123 I updated the device function (and the codegen).
Make sure to revisit the PR and clean it up.
!test
FusionGuard fg(fusion.get());
createNVFP4QunatizationFusion(fusion.get(), DataType::Float);

FusionExecutorCache fec(std::move(fusion));
Didn't you mention the quantization originally done by @zasdfgbnm has some bugs? Is that fixed?
// I'd like to check that the inner dimension of the input
// is divisible by 16.
void handle(BlockQuantizationOp* bqop) final {
@naoyam This is the validation function.
Do you have a list of things to validate?
I have updated the comment with a list of things I want to check.
I need to rewrite the validation code; the second half is not correct. I'll send another commit to fix that.
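For illustration only, the one check named in the handler comment (inner dimension divisible by 16) could look roughly like the sketch below. It assumes concrete integer extents and a block size of 16; `innerDimIsBlockAligned`, `numBlockScales`, and `kBlockSize` are hypothetical names, not nvFuser API, and the real validation would work on the fusion IR rather than on a plain shape vector.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumption: one scaling factor is produced per block of 16 elements.
constexpr int64_t kBlockSize = 16;

// Hypothetical helper: true when the innermost extent is a positive
// multiple of the block size.
inline bool innerDimIsBlockAligned(const std::vector<int64_t>& shape) {
  if (shape.empty()) {
    return false;
  }
  const int64_t inner = shape.back();
  return inner > 0 && inner % kBlockSize == 0;
}

// Hypothetical helper: number of block scaling factors for a valid shape.
inline int64_t numBlockScales(const std::vector<int64_t>& shape) {
  assert(innerDimIsBlockAligned(shape));
  int64_t numel = 1;
  for (int64_t extent : shape) {
    numel *= extent;
  }
  return numel / kBlockSize;
}
```

Under these assumptions a `{4, 32}` input passes the check and yields 8 scaling factors, while a `{4, 24}` input is rejected.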
!test
The PR:
For the 2D scheduling: