Conversation

jjsjann123 (Collaborator) commented Sep 4, 2025

#5118 PR3: enable codegen for layout op
#5115 PR2: add layout op runtime function <- this PR
#5114 PR1: add layout op

Runtime function signature.

  template <
      typename T,
      typename Index_T,
      int BLOCK_ROW_OUTER,
      int BLOCK_ROW_INNER,
      int BLOCK_COL,
      int UNROLL_FACTOR>
  __device__ void preprocessGroupedMatmulInputSf(
      T* output,
      const T* input,
      const nvfuser_index_t row_idx,
      const nvfuser_index_t col_idx,
      const Index_T* input_offsets,
      const Index_T* output_offsets,
      const nvfuser_index_t col_size,
      const nvfuser_index_t group_size)

where:
  BLOCK_ROW_OUTER, BLOCK_ROW_INNER, and BLOCK_COL are translated from
    BlockScalingFactorLayout, e.g. Block128x4 (128 rows x 4 columns per block)
    is translated to 32, 4, 4.

This function will be used by codegen for `PreprocessGroupedMatmulInputSf`.
  `output` is expected to be the beginning of the output buffer;
     indexing is done inside the function template with the help of `row_idx`,
     `col_idx`, `input_offsets`, `output_offsets`, and `col_size`.
  Meanwhile, indexing on `input` will have been resolved during device lowering.
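
For illustration only, here is a minimal sketch of how generated code might instantiate this runtime function for the Block128x4 case. The kernel scaffolding, buffer names, thread-to-(row, col) mapping, the uint8_t element type, the int32_t offset type, and the UNROLL_FACTOR value are assumptions made for the example; only the template signature is taken from this PR.

  // Hedged sketch of a hypothetical caller; not code generated by nvFuser.
  #include <cstdint>
  using nvfuser_index_t = int64_t;  // assumption for this example

  template <
      typename T,
      typename Index_T,
      int BLOCK_ROW_OUTER,
      int BLOCK_ROW_INNER,
      int BLOCK_COL,
      int UNROLL_FACTOR>
  __device__ void preprocessGroupedMatmulInputSf(
      T* output,
      const T* input,
      const nvfuser_index_t row_idx,
      const nvfuser_index_t col_idx,
      const Index_T* input_offsets,
      const Index_T* output_offsets,
      const nvfuser_index_t col_size,
      const nvfuser_index_t group_size);  // definition lives in runtime/block_layout.cu

  __global__ void exampleGeneratedKernel(
      uint8_t* out_sf,              // beginning of the output scaling-factor buffer
      const uint8_t* in_sf,         // input pointer, already offset by device lowering
      const int32_t* input_offsets,
      const int32_t* output_offsets,
      nvfuser_index_t col_size,
      nvfuser_index_t group_size) {
    // Hypothetical thread-to-element mapping; the real mapping is up to codegen.
    nvfuser_index_t row_idx = blockIdx.y;
    nvfuser_index_t col_idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    // Block128x4 -> BLOCK_ROW_OUTER=32, BLOCK_ROW_INNER=4, BLOCK_COL=4 (128 = 32 * 4 rows).
    preprocessGroupedMatmulInputSf<uint8_t, int32_t, 32, 4, 4, /*UNROLL_FACTOR=*/4>(
        out_sf,
        in_sf,
        row_idx,
        col_idx,
        input_offsets,
        output_offsets,
        col_size,
        group_size);
  }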

Todo for future PRs:
Add vectorization support.


github-actions bot commented Sep 4, 2025

Review updated until commit c340720

Description

  • Add runtime function for grouped block layout

  • Support preprocessing inputs for grouped matmul

  • Include block layout in kernel codegen

  • Update kernel summary for new op tracking


Changes walkthrough 📝

Relevant files

Enhancement

kernel.cpp (csrc/kernel.cpp): Track preprocess grouped matmul op (+4/-0)

  • Added handler for PreprocessGroupedMatmulInputSf in KernelIrScanner
  • Updated kernel summary to track the new op

compiled_kernel.cpp (csrc/runtime/compiled_kernel.cpp): Integrate block layout in codegen (+8/-2)

  • Include block_layout.h header
  • Extend _getStructuredCode to accept a has_block_layout flag
  • Add the block_layout_cu resource if the op is present
  • Pass the new op flag in getStructuredCode

block_layout.cu (runtime/block_layout.cu): Implement grouped block layout runtime (+102/-0)

  • Add new file with the preprocessGroupedMatmulInputSf device function
  • Implement swizzled layout indexing via outputOffsetAfterSwizzlePadding
  • Support grouped input handling with offsets and padding
  • Include TODOs for future vectorization

kernel.h (csrc/kernel.h): Extend kernel summary for new op (+3/-0)

  • Add has_preprocess_grouped_matmul_input_sf to KernelSummary
  • Enables tracking of the new op in kernel analysis

Configuration changes

CMakeLists.txt (CMakeLists.txt): Include block layout in build (+1/-0)

  • Add runtime/block_layout.cu to NVFUSER_RUNTIME_FILES
  • Ensures the new file is compiled into the runtime

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The loop in preprocessGroupedMatmulInputSf uses input[i] for loading data, but it should index into the correct position within the current expert's input segment. The current indexing may lead to incorrect memory access.

    for (int i = 0; i < UNROLL_FACTOR && col_idx + i < col_size; ++i) {
      nvfuser_index_t index = outputOffsetAfterSwizzlePadding<
          BLOCK_ROW_OUTER,
          BLOCK_ROW_INNER,
          BLOCK_COL>(c_row_idx, col_idx + i, padded_col_size);
      out_group_offset[index] = input[i];
    }
    Performance Issue

    The function outputOffsetAfterSwizzlePadding uses constexpr for its stride calculations, but the function itself is not marked constexpr, which could prevent compile-time evaluation and optimization.

    __device__ nvfuser_index_t outputOffsetAfterSwizzlePadding(
        const nvfuser_index_t row_idx,
        const nvfuser_index_t col_idx,
        const nvfuser_index_t padded_col_size) {
      constexpr nvfuser_index_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER;
    
      /* logical dimension of matrix [ row_size, col_size]
       *
       * while logical domain after padding can be viewed as
       *   [ (row_tile*BLOCK_ROW_INNER*BLOCK_ROW_OUTER), (col_tile*BLOCK_COL) ]
       * where
       *   row_tile = ceilDiv(row_size, BLOCK_ROW_INNER * BLOCK_ROW_OUTER)
       *   col_tile = ceilDiv(col_size, BLOCK_COL)
       */
    
      // we first convert `row_idx` and `col_idx` to the logical index on the 5d
      // tensor.
      nvfuser_index_t row_tile_idx = row_idx / BLOCK_ROW_SIZE;
      nvfuser_index_t row_block_idx = row_idx % BLOCK_ROW_SIZE;
      nvfuser_index_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER;
      nvfuser_index_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER;
      nvfuser_index_t col_tile_idx = col_idx / BLOCK_COL;
      nvfuser_index_t col_block_idx = col_idx % BLOCK_COL;
    
      /* layout for matrix [ row_size, col_size]
       * it is viewed
       *   [row_tile, BLOCK_ROW_INNER, BLOCK_ROW_OUTER, col_tile, BLOCK_COL]
       * then transposed with axis (1, 3)
       *   [row_tile, col_tile, BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL]
       * and then made contiguous
       * So we can compute the corresponding stride for each dimension
       */
      constexpr nvfuser_index_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL;
      constexpr nvfuser_index_t BLOCK_ROW_OUTER_STRIDE =
          BLOCK_ROW_INNER * BLOCK_COL;
      constexpr nvfuser_index_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL;
    
      return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE +
          col_tile_idx * COL_TILE_STRIDE +
          row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE +
          row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx;
    }
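
    As a review aid (not part of the PR), the same index arithmetic can be mirrored on the host with the template parameters fixed to the Block128x4 case (32, 4, 4), which makes it easy to spot-check where a given (row, col) element lands after swizzling. The constants and the main() driver below are illustration-only assumptions; the arithmetic is transcribed from the function above.

    // Host-side sketch mirroring outputOffsetAfterSwizzlePadding for Block128x4.
    #include <cstdint>
    #include <cstdio>

    constexpr int64_t BLOCK_ROW_OUTER = 32;
    constexpr int64_t BLOCK_ROW_INNER = 4;
    constexpr int64_t BLOCK_COL = 4;

    int64_t swizzledOffset(int64_t row_idx, int64_t col_idx, int64_t padded_col_size) {
      constexpr int64_t BLOCK_ROW_SIZE = BLOCK_ROW_OUTER * BLOCK_ROW_INNER;  // 128
      const int64_t row_tile_idx = row_idx / BLOCK_ROW_SIZE;
      const int64_t row_block_idx = row_idx % BLOCK_ROW_SIZE;
      const int64_t row_block_inner_idx = row_block_idx / BLOCK_ROW_OUTER;
      const int64_t row_block_outer_idx = row_block_idx % BLOCK_ROW_OUTER;
      const int64_t col_tile_idx = col_idx / BLOCK_COL;
      const int64_t col_block_idx = col_idx % BLOCK_COL;
      constexpr int64_t COL_TILE_STRIDE = BLOCK_ROW_SIZE * BLOCK_COL;          // 512
      constexpr int64_t BLOCK_ROW_OUTER_STRIDE = BLOCK_ROW_INNER * BLOCK_COL;  // 16
      constexpr int64_t BLOCK_ROW_INNER_STRIDE = BLOCK_COL;                    // 4
      return row_tile_idx * padded_col_size * BLOCK_ROW_SIZE +
          col_tile_idx * COL_TILE_STRIDE +
          row_block_outer_idx * BLOCK_ROW_OUTER_STRIDE +
          row_block_inner_idx * BLOCK_ROW_INNER_STRIDE + col_block_idx;
    }

    int main() {
      // With padded_col_size = 8 (two column tiles), element (row=1, col=0) maps to
      // offset 16: it has row_block_outer_idx == 1 within tile (0, 0), stride 16.
      std::printf("%lld\n", static_cast<long long>(swizzledOffset(1, 0, 8)));
      return 0;
    }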
    Function Signature Mismatch

    The function _getStructuredCode has been updated to take a new parameter named has_block_layout, but the caller passes has_preprocess_grouped_matmul_input_sf for it, which may indicate a mismatch between the parameter's name and the summary flag it represents.

    std::string _getStructuredCode(
        const std::string& kernel_str,
        PrimDataType index_type,
        std::string kernel_name,
        bool has_argsort = false,
        bool has_topk = false,
        bool has_scan = false,
        bool has_block_layout = false) {
      // generating cuda code;
      std::string code = "";
    
      if (has_argsort || has_scan || has_topk) {
        // Internally, CUB uses std::is_pointer, not
        // cuda::std::is_pointer, and it fails to compile as nvrtc does not
        // have <type_traits>. This doesn't seem to be the case with nvcc. A
        // WAR for nvrtc is to provide std::is_pointer as an alias of
        // cuda::std::is_pointer.
        code += "#ifndef __NVCC__\n";
        code += "#include <cuda/std/type_traits>\n";
        code += "namespace std {\n";
        code += "using cuda::std::is_pointer;\n";
        code += "} // namespace std\n";
        code += "#endif\n";
      }
    
      code += defineStdComplex();
      code += std::string("namespace ") + CompiledKernel::kernelNamespace() +
          "{\n" + defineTypes() + defineIndexType(index_type) + kernelPreamble() +
          "} // namespace " + CompiledKernel::kernelNamespace() + "\n";
    
      if (has_argsort || has_topk || has_scan) {
        code += nvfuser_resources::cub_utils_cu;
      }
    
      if (has_argsort) {
        code += nvfuser_resources::argsort_cu;
      }
      if (has_scan) {
        code += nvfuser_resources::scan_cu;
      }
      if (has_topk) {
        code += nvfuser_resources::topk_cu;
      }
      if (has_block_layout) {
        code += nvfuser_resources::block_layout_cu;
      }
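
    To make the naming concern concrete, the call site presumably forwards the new KernelSummary flag as the has_block_layout argument. The snippet below is a hedged sketch of that hand-off, not the PR's actual code: the struct shape, the other flag names, and the wrapper function are assumptions; only has_preprocess_grouped_matmul_input_sf and the _getStructuredCode parameter list come from this PR.

    // Hedged sketch of the presumed caller-side hand-off; stand-in types included
    // so the fragment is self-contained.
    #include <string>

    enum class PrimDataType { Int, Int32 };  // stand-in for nvFuser's real enum

    struct KernelSummarySketch {
      bool has_argsort = false;  // assumed existing flags
      bool has_topk = false;
      bool has_scan = false;
      bool has_preprocess_grouped_matmul_input_sf = false;  // added in csrc/kernel.h by this PR
    };

    std::string _getStructuredCode(
        const std::string& kernel_str,
        PrimDataType index_type,
        std::string kernel_name,
        bool has_argsort = false,
        bool has_topk = false,
        bool has_scan = false,
        bool has_block_layout = false);  // body quoted above

    std::string getStructuredCodeSketch(
        const std::string& kernel_str,
        PrimDataType index_type,
        const std::string& kernel_name,
        const KernelSummarySketch& summary) {
      // The op-specific summary flag is passed as the generically named
      // has_block_layout parameter; that gap is what the review comment flags.
      return _getStructuredCode(
          kernel_str,
          index_type,
          kernel_name,
          summary.has_argsort,
          summary.has_topk,
          summary.has_scan,
          summary.has_preprocess_grouped_matmul_input_sf);
    }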

    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch from 2c992ab to 1d72d32 on September 4, 2025 18:28
    jjsjann123 force-pushed the jj/layout_op_PR0_ir_node branch from 3ccfbde to 1d72d32 on September 4, 2025 18:30
    jjsjann123 mentioned this pull request Sep 4, 2025
    jjsjann123 changed the title from "block layout op runtime function added" to "add layout op runtime function" on Sep 4, 2025
    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch from ebd03f4 to 5ec9d72 on September 4, 2025 19:12
    jjsjann123 (Author) commented:

    !test

    1 similar comment
    jjsjann123 (Author) commented:

    !test

    jjsjann123 marked this pull request as ready for review September 4, 2025 19:24
    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch from 298ea2f to a86508c on September 4, 2025 19:45
    jjsjann123 (Author) commented:

    !test

    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch from a86508c to 7c327f6 on September 4, 2025 22:11
    jjsjann123 force-pushed the jj/layout_op_PR0_ir_node branch from c4e65d3 to 4deb4a9 on September 4, 2025 22:32
    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch 3 times, most recently from f1709fb to f5b464f, on September 5, 2025 00:14
    jjsjann123 (Author) commented:

    !test

    Runtime function signature.
    
      template <
          typename T,
          typename Index_T,
          int BLOCK_ROW_OUTER,
          int BLOCK_ROW_INNER,
          int BLOCK_COL,
          int UNROLL_FACTOR>
      __device__ void groupedBlockLayout(
          T* output,
          const T* input,
          const nvfuser_index_t row_idx,
          const nvfuser_index_t col_idx,
          const Index_T* expert_offsets,
          const Index_T* output_offsets,
          const nvfuser_index_t col_size,
          const nvfuser_index_t group_size)
    
    where:
      BLOCK_ROW_OUTER, BLOCK_ROW_INNER, BLOCK_COL will be translated from BlockScalingFactorLayout, e.g. Block128x4 is translated to 32, 4, 4.
    
    This function will be used by codegen for `GroupedBlockScalingFactorLayoutOp`
      `output` is expected to be the beginning of output buffer,
         indexing will be done inside the function template with help of `row_idx`,
         `col_idx`, `expert_offsets`, `output_offsets` and `col_size`
      Meanwhile, indexing on `input` would have been resolved during device lowering.
    jjsjann123 force-pushed the jj/layout_op_PR0_ir_node branch from a3afee4 to 150d3ee on September 5, 2025 17:47
    jjsjann123 force-pushed the jj/layout_op_PR1_runtime_function branch from f5b464f to c340720 on September 5, 2025 17:48