
Conversation

@rdspring1 (Collaborator) added the Direct Bindings (Python extension with direct mapping to NvFuser CPP objects) and Cutlass labels on Sep 2, 2025

github-actions bot commented Sep 2, 2025

Review updated until commit 34cc5f5

Description

  • Add grouped_mm support for bf16 and fp16 on Blackwell

  • Implement CUDA kernel for grouped GEMM memory layout

  • Add test suite for grouped_mm with multiple configurations

  • Update build system and headers for new kernel


Changes walkthrough 📝

Relevant files

Enhancement

  python/python_direct/cutlass.cpp: Add grouped_mm Python binding (+10/-0)
    • Add Python binding for grouped_mm function
    • Expose grouped_mm with full parameter documentation

  cutlass/group_mm.cu: Implement grouped_mm CUDA kernel (+524/-0)
    • Implement grouped_mm CUDA kernel for Blackwell
    • Support bf16 and fp16 through template specialization
    • Include memory layout and offset computation
    • Validate inputs and handle error cases

  cutlass/nvf_cutlass.h: Declare grouped_mm in header (+26/-0)
    • Add grouped_mm function declaration
    • Document parameters and return value

Tests

  tests/python/direct/test_cutlass_gemm.py: Add grouped_mm test suite (+76/-0)
    • Add comprehensive test for grouped_mm
    • Test multiple configs and dtypes (bf16, fp16)
    • Validate output against PyTorch reference

Formatting

  cutlass/nvfp4_scaled_group_mm.cu: Clean up includes and comments (+10/-23)
    • Remove unused includes
    • Fix comment formatting and parameter descriptions

  cutlass/nvfp4_scaled_mm_blockscale.cu: Remove unnecessary includes (+0/-4)
    • Remove unused includes
    • Clean up header dependencies

Configuration changes

  CMakeLists.txt: Include group_mm.cu in build (+1/-0)
    • Add group_mm.cu to build system

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Potential Overflow Risk

The kernel get_group_gemm_starts uses threadIdx.x as the expert_id, but the bounds check compares it against gridDim.x * blockDim.x (the total launch size) rather than the number of experts, so any thread whose expert_id exceeds the expert count could read out of bounds from expert_offsets and problem_sizes_as_shapes.

    int64_t expert_id = threadIdx.x;
    if (expert_id >= gridDim.x * blockDim.x) {
      return;
    }
    // Upcast from int32_t to int64_t to avoid overflow during offset calculations
    int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
    int64_t n = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 1]);
    int64_t k = static_cast<int64_t>(problem_sizes_as_shapes[expert_id * 3 + 2]);
    assert((n == N && k == K) && "Unexpected problem sizes");
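
A minimal sketch of the kind of guard this observation points at; num_experts is a hypothetical extra kernel parameter, not something the PR passes as written:

    #include <cstdint>

    __global__ void get_group_gemm_starts_guarded(
        const int32_t* expert_offsets,
        const int32_t* problem_sizes_as_shapes,
        int64_t num_experts /* hypothetical: total number of experts */) {
      int64_t expert_id = threadIdx.x;
      // Guard against the expert count, not the launch size, so surplus threads
      // never index past the end of expert_offsets / problem_sizes_as_shapes.
      if (expert_id >= num_experts) {
        return;
      }
      int64_t expert_offset = static_cast<int64_t>(expert_offsets[expert_id]);
      // ... pointer/offset setup for this expert continues as in the PR ...
    }
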
Incomplete Validation

    The function validateInputsGroupMm does not validate the ab_strides and c_strides tensors, which are critical for correct memory access patterns in the grouped GEMM operation.

    void validateInputsGroupMm(
        const torch::Tensor& a,
        const torch::Tensor& b,
        const torch::Tensor& problem_sizes,
        const torch::Tensor& expert_offsets) {
      // Check data types
      NVF_CHECK(
          a.scalar_type() == at::ScalarType::BFloat16 ||
              a.scalar_type() == at::ScalarType::Half,
          "Expected BFloat16 or Half for Operand A.")
      NVF_CHECK(
          b.scalar_type() == at::ScalarType::BFloat16 ||
              b.scalar_type() == at::ScalarType::Half,
          "Expected BFloat16 or Half for Operand B.")
    
      // Check CUDA device
      NVF_CHECK(a.is_cuda(), "Expected CUDA tensor for Operand A.")
      NVF_CHECK(b.is_cuda(), "Expected CUDA tensor for Operand B.")
    
      // Check contiguity
      NVF_CHECK(a.is_contiguous(), "Expected contiguous tensor for Operand A.")
      NVF_CHECK(b.is_contiguous(), "Expected contiguous tensor for Operand B.")
    
      // Check shapes
      NVF_CHECK(problem_sizes.dim() == 2, "problem_sizes must be  a 2D tensor");
      NVF_CHECK(
          problem_sizes.size(1) == 3,
          "problem_sizes must have the shape (num_experts, 3)");
      NVF_CHECK(
          problem_sizes.size(0) == expert_offsets.size(0),
          "Number of experts in problem_sizes must match expert_offsets");
      NVF_CHECK(
          problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32.");
    }
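
A hedged sketch of the additional stride checks this observation calls for; the dtype and per-expert shape expectations below are assumptions about how ab_strides and c_strides are laid out, not something stated in the PR:

    // Hypothetical companion to validateInputsGroupMm (assumes one stride entry
    // per expert and int64 stride tensors; adjust to the kernel's real layout).
    void validateStridesGroupMm(
        const torch::Tensor& ab_strides,
        const torch::Tensor& c_strides,
        const torch::Tensor& problem_sizes) {
      NVF_CHECK(ab_strides.is_cuda(), "Expected CUDA tensor for ab_strides.");
      NVF_CHECK(c_strides.is_cuda(), "Expected CUDA tensor for c_strides.");
      NVF_CHECK(
          ab_strides.dtype() == torch::kInt64, "ab_strides must be int64.");
      NVF_CHECK(c_strides.dtype() == torch::kInt64, "c_strides must be int64.");
      NVF_CHECK(
          ab_strides.size(0) == problem_sizes.size(0),
          "Number of experts in ab_strides must match problem_sizes");
      NVF_CHECK(
          c_strides.size(0) == problem_sizes.size(0),
          "Number of experts in c_strides must match problem_sizes");
    }
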
Missing Error Handling

The Python binding for grouped_mm does not include error handling or validation of input tensor properties beyond what is done in the C++ implementation, potentially allowing invalid inputs to reach the CUDA kernel.

    cutlass.def(
        "grouped_mm",
        &cutlass_kernels::grouped_mm,
        R"(Computes grouped matmul and returns bf16 or fp16 output tensor.
           grouped_mm(Tensor a,
                      Tensor b,
                      Tensor ab_strides,
                      Tensor c_strides,
                      Tensor problem_sizes,
                      Tensor expert_offsets) -> Tensor output)");

@jjsjann123 (Collaborator) left a comment

Same comment as in the other PR.
On top of that, a nitpick: the comments should drop the nvfp4-related material that is no longer relevant for the 16-bit types.

@rdspring1 force-pushed the cutlass_grouped_gemm_refactor branch from 4fcd1a3 to f43d0a5 on September 4, 2025 02:25
Base automatically changed from cutlass_grouped_gemm_refactor to main on September 5, 2025 01:58
@rdspring1 force-pushed the cutlass_grouped_gemm_bf16 branch 2 times, most recently from 3e2daf1 to 3eba5ae on September 5, 2025 23:19
@rdspring1 requested a review from jjsjann123 on September 6, 2025 17:04
@rdspring1 force-pushed the cutlass_grouped_gemm_bf16 branch from ff39085 to 21f32c9 on September 6, 2025 17:09

@rdspring1 (Collaborator, Author) commented:

!test

@jjsjann123 (Collaborator) left a comment:

lgtm~

    #include "cutlass/util/reference/host/tensor_compare.h"
    #include "cutlass/util/reference/host/tensor_fill.h"
    #include "cutlass/util/reference/host/tensor_norm.h"
    #include "cutlass/util/tensor_view_io.h"

Collaborator comment:

Thanks for cleaning up the header~~~

@jacobhinkle (Collaborator) left a comment

LGTM. Left some minor comments. These comments apply to the scaled input version as well; I just didn't notice previously.

Comment on lines +35 to +57:

    template <typename T, bool is_single_sm>
    struct KernelTraits;

    // Kernel traits for FP16 output
    template <>
    struct KernelTraits<cutlass::half_t, true> {
      using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(cutlass::half_t)>>;
      using ClusterShape = Shape<_1, _1, _1>;
      using KernelSchedule =
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
      using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
    };

    // Kernel traits for BFloat16 output
    template <>
    struct KernelTraits<cutlass::bfloat16_t, true> {
      using MmaTileShape =
          Shape<_128, _256, Int<128 / sizeof(cutlass::bfloat16_t)>>;
      using ClusterShape = Shape<_1, _1, _1>;
      using KernelSchedule =
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
      using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
    };

Collaborator comment:

I guess we could just handle this in a single template definition using sizeof(T) and using

      std::conditional<is_single_sm,
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100,
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100>::type
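
A sketch of that suggestion, assuming the cute shape aliases and CUTLASS schedule types already pulled in by group_mm.cu; the 2-SM epilogue name (PtrArrayTmaWarpSpecialized2Sm) is an unverified assumption by analogy with the 1-SM one, and this is not the code that was merged:

    #include <type_traits>

    // Single definition covering half_t/bfloat16_t and both SM modes.
    template <typename T, bool is_single_sm>
    struct KernelTraits {
      using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(T)>>;
      // Note: the 2-SM schedules would also need a compatible ClusterShape
      // (e.g. _2 in M); kept at 1x1x1 here to mirror the merged code.
      using ClusterShape = Shape<_1, _1, _1>;
      using KernelSchedule = typename std::conditional<
          is_single_sm,
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100,
          cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100>::type;
      // Assumed analogous name for the 2-SM epilogue schedule; verify against
      // the CUTLASS release in use.
      using EpilogueSchedule = typename std::conditional<
          is_single_sm,
          cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm,
          cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm>::type;
    };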

Comment on lines +299 to +311:

    run_get_group_gemm_starts(
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a,
        b,
        output,
        expert_offsets,
        problem_sizes,
        M,
        N,
        K,
        stream);

Collaborator comment:

Is it normal to need to launch two kernels for every invocation? I would imagine that for common use cases we'd have a static shape, meaning all the offsets could be precomputed (even if they live on the GPU) and we'd just change the base data pointer for each invocation. I wonder what is done for small problems or inference use cases.

@rdspring1 (Collaborator, Author) replied:

Yes. Can the CutlassExecutor cache based on the shapes to avoid calling run_get_group_gemm_starts?

Collaborator reply:

Good question. Maybe we could if we know the offsets are constant, but that data is on the device, right?

Comment on lines +326 to +334:

    scheduler.raster_order = RasterOrderOptions::AlongM;
    hw_info.device_id = a.get_device();
    static std::unordered_map<int, int> cached_sm_counts;
    if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) {
      cached_sm_counts[hw_info.device_id] =
          cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
              hw_info.device_id);
    }
    hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX);

Collaborator comment:

I assume we don't need to set hw_info.max_active_clusters since the cluster size is 1,1,1. If we make this adjustable I guess we'd need to update that here.

Collaborator comment:

BTW I think we can also just call KernelHardwareInfo<GemmKernel>::make_kernel_hardware_info, which will automatically initialize these.
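
A short sketch of that alternative; it assumes the CUTLASS helper is the kernel-templated static factory on KernelHardwareInfo that the reviewer describes, so treat the exact spelling as unverified, and GemmKernel stands in for whatever kernel type group_mm.cu defines:

    // Hypothetical replacement for the manual SM-count caching above.
    cutlass::KernelHardwareInfo hw_info =
        cutlass::KernelHardwareInfo::make_kernel_hardware_info<GemmKernel>(
            a.get_device());
    // If the helper behaves as described, hw_info.sm_count and
    // hw_info.max_active_clusters are filled in for us, so no separate
    // query_device_multiprocessor_count call would be needed.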

@rdspring1 force-pushed the cutlass_grouped_gemm_bf16 branch from 21f32c9 to 34cc5f5 on September 10, 2025 21:44

@rdspring1 (Collaborator, Author) commented:

!build

@rdspring1 merged commit 515a337 into main on Sep 10, 2025
17 checks passed
@rdspring1 deleted the cutlass_grouped_gemm_bf16 branch on September 10, 2025 23:18
wujingyue added a commit that referenced this pull request on Sep 11, 2025

Labels

Cutlass, Direct Bindings (Python extension with direct mapping to NvFuser CPP objects)

4 participants