Create grouped_mm for bf16 and fp16 inputs on Blackwell #5101

Conversation
rdspring1 commented Sep 2, 2025
- Cutlass support for Issue #5007 (Better fallback for bf16 GroupedMmaOp)
jjsjann123 left a comment
Same comment as in the other PR.

On top of that, a few nitpicks about comments that should drop all the nvfp4-related stuff, which is no longer relevant with 16-bit types.
Force-pushed 4fcd1a3 to f43d0a5
Force-pushed 3e2daf1 to 3eba5ae
Force-pushed ff39085 to 21f32c9
!test
jjsjann123 left a comment
lgtm~
| #include "cutlass/util/reference/host/tensor_compare.h" | ||
| #include "cutlass/util/reference/host/tensor_fill.h" | ||
| #include "cutlass/util/reference/host/tensor_norm.h" | ||
| #include "cutlass/util/tensor_view_io.h" |
Thanks for cleaning up the header~~~
jacobhinkle left a comment
LGTM. Left some minor comments. These comments apply to the scaled input version as well. I just didn't notice previously.
```cpp
template <typename T, bool is_single_sm>
struct KernelTraits;

// Kernel traits for FP16 output
template <>
struct KernelTraits<cutlass::half_t, true> {
  using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(cutlass::half_t)>>;
  using ClusterShape = Shape<_1, _1, _1>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
};

// Kernel traits for BFloat16 output
template <>
struct KernelTraits<cutlass::bfloat16_t, true> {
  using MmaTileShape =
      Shape<_128, _256, Int<128 / sizeof(cutlass::bfloat16_t)>>;
  using ClusterShape = Shape<_1, _1, _1>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
};
```
I guess we could just handle this in a single template definition using `sizeof(T)` and using

```cpp
std::conditional<is_single_sm,
    cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100,
    cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100>::type
```
```cpp
run_get_group_gemm_starts(
    a_ptrs,
    b_ptrs,
    out_ptrs,
    a,
    b,
    output,
    expert_offsets,
    problem_sizes,
    M,
    N,
    K,
    stream);
```
Is it normal to need to launch two kernels for every invocation? I would imagine that for common use cases we'd have a static shape, meaning all the offsets could be precomputed even if they live on the GPU, and we'd just change the base data pointer for each invocation. I wonder what is done for small problems or inference use cases...
Yes. Can the CutlassExecutor cache based on the shapes to avoid calling run_get_group_gemm_starts?
Good question. Maybe we could if we know the offsets are constant, but that data is on the device right?
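To make the idea concrete, here is a rough host-side sketch of the kind of shape-keyed cache being discussed. Everything in it is illustrative: the key choice, the `GroupPtrArrays` type, and the class name are hypothetical, and it is only valid if the expert offsets are guaranteed unchanged between calls, which, as noted above, is hard to verify cheaply when they live on the device:

```cpp
#include <cstdint>
#include <map>
#include <tuple>
#include <vector>

// Hypothetical bundle of the per-group pointer arrays that
// run_get_group_gemm_starts fills in.
struct GroupPtrArrays {
  std::vector<const void*> a_ptrs, b_ptrs;
  std::vector<void*> out_ptrs;
};

// Shape-keyed cache: reuse the pointer arrays when the same
// (M, N, K, a base, b base) combination repeats. Only safe if the
// expert offsets are known not to change between invocations.
class GroupGemmStartsCache {
 public:
  using Key = std::tuple<int64_t, int64_t, int64_t, const void*, const void*>;

  template <typename FillFn>
  GroupPtrArrays& get_or_fill(const Key& key, FillFn&& fill) {
    auto [it, inserted] = cache_.try_emplace(key);
    if (inserted) {
      fill(it->second);  // launch run_get_group_gemm_starts only on a miss
    }
    return it->second;
  }

 private:
  std::map<Key, GroupPtrArrays> cache_;
};
```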
```cpp
scheduler.raster_order = RasterOrderOptions::AlongM;
hw_info.device_id = a.get_device();
static std::unordered_map<int, int> cached_sm_counts;
if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
  cached_sm_counts[hw_info.device_id] =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);
}
hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);
```
I assume we don't need to set hw_info.max_active_clusters since the cluster size is 1,1,1. If we make this adjustable I guess we'd need to update that here.
BTW I think we can also just call KernelHardwareInfo<GemmKernel>::make_kernel_hardware_info which will automatically initialize these.
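A minimal sketch of that call, assuming the helper in recent CUTLASS releases (spelled `KernelHardwareInfo::make_kernel_hardware_info<GemmKernel>` in the versions I've seen; the header location and exact signature may differ across releases):

```cpp
#include "cutlass/kernel_hardware_info.h"  // assumed header location

// Queries the SM count for the device and, for cluster-launch kernels,
// the maximum number of active clusters for GemmKernel, so sm_count and
// max_active_clusters don't have to be filled in by hand.
cutlass::KernelHardwareInfo hw_info =
    cutlass::KernelHardwareInfo::make_kernel_hardware_info<GemmKernel>(
        a.get_device());
```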
* Add support for bf16 and fp16 for Issue 5007
Force-pushed 21f32c9 to 34cc5f5
!build
)" This reverts commit 515a337.