Conversation

@jacobhinkle jacobhinkle commented Aug 26, 2025

Status: generates the ITE tree, but the kernel hangs. The TwoAsyncWarps test runs successfully with NVFUSER_ENABLE=kernel_debug, so there's a race that is mitigated by slowing down the kernel.

Generated kernel:

// Codegen generated code
template <nvfuser_index_t in0>
__device__ __inline__ void decreaseRegisters() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n"::"n"(in0));
}
template <nvfuser_index_t in0>
__device__ __inline__ void increaseRegisters() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n"::"n"(in0));
}

__global__ void __launch_bounds__(/*maxThreadsPerBlock=*/384, /*minBlocksPerMultiprocessor=*/1) nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 2, 2> T1, Tensor<float, 2, 2> T4) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  nvfuser_index_t i0;
  i0 = ceilDiv(T0.logical_size[0LL], 2);
  nvfuser_index_t i1;
  i1 = ceilDiv((ceilDiv(T0.logical_size[1LL], 128)), 2);
  nvfuser_index_t i2;
  i2 = 4 * T0.logical_size[1LL];
  uint32_t i3;
  i3 = __to_uint32(i2);
  nvfuser_index_t i4;
  i4 = (T0.logical_size[1LL] * i0) * ((nvfuser_index_t)blockIdx.x);
  float* ptr5;
  ptr5 = T0.data + i4;
  float* T3 = reinterpret_cast<float*>(array + smem_offset + (((((T0.logical_size[1LL] * 2) * 4) + 128) + 127) & -128));
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 128);
  uint32_t i6;
  i6 = toSmem(T2);
  float* ptr7;
  ptr7 = T1.data + i4;
  uint32_t i8;
  i8 = toSmem(T3);
  nvfuser_index_t i9;
  i9 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i10;
  i10 = i9 + i4;
  bool b11;
  b11 = ((nvfuser_index_t)threadIdx.x) < 32ULL;
  bool b12;
  b12 = ((nvfuser_index_t)threadIdx.y) == 0ULL;
  bool b13;
  b13 = ((nvfuser_index_t)threadIdx.z) == 0ULL;
  bool b14;
  b14 = ((nvfuser_index_t)threadIdx.y) == 2;
  nvfuser_index_t i15;
  i15 = ((nvfuser_index_t)threadIdx.x) / 32;
  bool b16;
  b16 = (((nvfuser_index_t)threadIdx.x) / 32ULL) == 0ULL;
  nvfuser_index_t i17;
  i17 = i0 * ((nvfuser_index_t)blockIdx.x);
  bool b18;
  b18 = ((nvfuser_index_t)threadIdx.y) < 2;
  uint64_t* T5 = reinterpret_cast<uint64_t*>(array + smem_offset + 0);
  #pragma unroll
  for(nvfuser_index_t i19 = 0; i19 < 2; ++i19) {
    if ((((Hopper::electSync(4294967295U) && b11) && b12) && b13)) {
      mbarrier::init(toSmem((&T5[i19])), 2U);
    }
  }
  #pragma unroll
  for(nvfuser_index_t i20 = 0; i20 < 2; ++i20) {
    if ((((Hopper::electSync(4294967295U) && b11) && b12) && b13)) {
      mbarrier::init(toSmem((&T5[(i20 + 2LL)])), 256U);
    }
  }
  __syncthreads();
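  // First async branch: warp 0 of the threadIdx.y == 2 group issues TMA loads of T0 into T2.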
  if ((b14 && (i15 == 0))) {
    decreaseRegisters<152>();
    #pragma unroll 1
    for(nvfuser_index_t i21 = 0; i21 < i0; ++i21) {
      if (((b16 && Hopper::electSync(4294967295U)) && ((i17 + i21) < T0.logical_size[0LL]))) {
        mbarrier::waitParity(toSmem((&T5[((i21 % 2) + 2LL)])), __to_uint32(((i21 / 2) % 2)));
        mbarrier::arriveExpectTX(toSmem((&T5[(i21 % 2)])), i3);
        Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr5 + (T0.logical_size[1LL] * i21)), i3, toSmem((&T5[(i21 % 2)])) }), (i6 + (i2 * (i21 % 2))));
      }
    }
    return;
  } else {
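    // Second async branch: warp 1 of the threadIdx.y == 2 group issues TMA loads of T1 into T3.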
    if ((b14 && (i15 == 1))) {
      decreaseRegisters<152>();
      #pragma unroll 1
      for(nvfuser_index_t i21 = 0; i21 < i0; ++i21) {
        if ((b16 && Hopper::electSync(4294967295U))) {
          mbarrier::waitParity(toSmem((&T5[((i21 % 2) + 2LL)])), __to_uint32(((i21 / 2) % 2)));
          mbarrier::arriveExpectTX(toSmem((&T5[(i21 % 2)])), i3);
          Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr7 + (T0.logical_size[1LL] * i21)), i3, toSmem((&T5[(i21 % 2)])) }), (i8 + (i2 * (i21 % 2))));
        }
      }
      return;
    } else {
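      // Compute branch: remaining warps wait on the slot-full barriers,
      // compute T4 = T2 * T3, then arrive on the slot-empty barriers.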
      increaseRegisters<176>();
      #pragma unroll
      for(nvfuser_index_t i22 = 0; i22 < 2; ++i22) {
        mbarrier::arrive(toSmem((&T5[(i22 + 2LL)])));
      }
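      // (The register-increase/arrive block above repeats below: the
      // "compute warp arriving twice" bug noted after this dump.)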
      increaseRegisters<176>();
      #pragma unroll
      for(nvfuser_index_t i23 = 0; i23 < 2; ++i23) {
        mbarrier::arrive(toSmem((&T5[(i23 + 2LL)])));
      }
      #pragma unroll 1
      for(nvfuser_index_t i24 = 0; i24 < i0; ++i24) {
        nvfuser_index_t i25;
        i25 = i9 + (T0.logical_size[1LL] * (i24 % 2));
        nvfuser_index_t i26;
        i26 = i10 + (T0.logical_size[1LL] * i24);
        bool b27;
        b27 = (i17 + i24) < T0.logical_size[0LL];
        bool b28;
        b28 = b18 && b27;
        if (b27) {
          mbarrier::waitParity(toSmem((&T5[(i24 % 2)])), __to_uint32(((i24 / 2) % 2)));
        }
        #pragma unroll 1
        for(nvfuser_index_t i29 = 0; i29 < i1; ++i29) {
          nvfuser_index_t i30;
          i30 = 256 * i29;
          nvfuser_index_t i31;
          i31 = i25 + i30;
          if ((b28 && ((i9 + i30) < T0.logical_size[1LL]))) {
            T4[(i26 + i30)]
              = T2[i31]
              * T3[i31];
          }
        }
        mbarrier::arrive(toSmem((&T5[((i24 % 2) + 2LL)])));
      }
    }
  }
  #pragma unroll
  for(nvfuser_index_t i32 = 0; i32 < 2; ++i32) {
    if ((((Hopper::electSync(4294967295U) && b11) && b12) && b13)) {
      mbarrier::inval(toSmem((&T5[(i32 + 2LL)])));
    }
  }
  #pragma unroll
  for(nvfuser_index_t i33 = 0; i33 < 2; ++i33) {
    if ((((Hopper::electSync(4294967295U) && b11) && b12) && b13)) {
      mbarrier::inval(toSmem((&T5[i33])));
    }
  }
}

I fixed the obvious things, like the compute warp arriving twice and b14 appearing in predicates where it shouldn't, and tried a few more things. Still investigating...

Outlook:

  • I need to remove some more special casing for warp selection in TMA predication when we detect we're in a warp-specialized loop. We should also go ahead and elect a thread with elect-sync at the warp level when all of a warp's async ops are single-thread, instead of re-evaluating that predicate in the inner loop (see the sketch after this list).
  • The current tests don't demonstrate one async warp consuming a buffer from another async warp, but I will add one soon with two async warps chained together.
  • This continues the perspective that the circular buffering attribute of a tv dictates which warp its definition is executed in. That's fine for now, but at some point we could have a synchronous producer of a circular buffer (think stmatrix) that is consumed by an async op from another warp. Currently TMA store uses cp.async groups instead of mbarriers, so this is not a pressing concern, but if it's needed someday we might need to change some of the existing machinery to accommodate it.
  • Mbarrier management should be more explicit. One slot-full and one slot-free mbarrier per async warp makes sense to me; we just need to be careful with initialization, invalidation, etc.
  • Predication should take into account the context where the expression lives. That means we may need to track which warp-specialized Scope, if any, an Expr is in.
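
A minimal sketch of the warp-level election idea from the first bullet, reusing b16, Hopper::electSync, and nvfuser_index_t from the generated kernel above; n and the loop body are placeholders, not the PR's actual lowering:

// Sketch only: elect one thread per warp once, outside the main loop,
// instead of re-evaluating the election predicate every iteration.
// Hopper::electSync picks the same lane on each call within a warp,
// so hoisting it out of the loop preserves the semantics.
const bool elected = b16 && Hopper::electSync(4294967295U);
if (elected) {
  #pragma unroll 1
  for (nvfuser_index_t i = 0; i < n; ++i) {
    // single-thread async work for iteration i,
    // e.g. mbarrier::arriveExpectTX followed by a TMA load
  }
}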

@jacobhinkle jacobhinkle requested a review from rdspring1 August 26, 2025 14:28

Description

  • Enable multi-role warp specialization with multiple async warps
  • Support distinct async warp roles in circular buffering
  • Implement per-warp predicate generation in ITE tree
  • Fix register sharing across multiple async warps

Changes walkthrough 📝

Relevant files

Enhancement

compute_at_map.cpp: Support multiple async warps in sibling ID mapping
csrc/compute_at_map.cpp

  • Updated getAsyncWarpSiblingIds to iterate over all async warps
  • Added validation to ensure stage_slice_position is used by only one async warp
  • Removed early return on first async warp, now processes all warps
  • +20/-16

circular_buffer.cpp: Refactor async warp creation with multi-warp support
csrc/device_lower/analysis/circular_buffer.cpp

  • Introduced isAsyncProducer helper to identify async ops
  • Rewrote createAsyncWarps to handle multiple warps via async_warp ID
  • Validated consistent stage_slice_position within each warp
  • +52/-58

circular_buffer.cpp: Generate ITE tree for multiple async warps
csrc/device_lower/pass/circular_buffer.cpp

  • Introduced AsyncLoopInfo to track per-warp async ops and options
  • Updated insertion map to support multiple warps per loop
  • Implemented ITE tree with per-warp predicates and scopes
  • Added support for register sharing across multiple warps
  • +302/-81

id_model.cpp: Support multi-warp inlining info construction
csrc/id_model/id_model.cpp

  • Updated buildAsyncWarpInliningInfo to process all async warps
  • Removed restriction to single async warp
  • Preserved sibling ID mapping per warp
  • +34/-35

interface_nodes.h: Extend WarpSpecialized for multi-warp support
csrc/ir/interface_nodes.h

  • Added async_warp field to WarpSpecialized struct
  • Documented constraints for multi-warp register sharing
  • Preserved existing num_registers and stage_slice_position fields
  • +8/-0

Tests

test_circular_buffering.cpp: Add tests for multi-async-warp circular buffering
tests/cpp/test_circular_buffering.cpp

  • Added SingleDimTwoAsyncWarps test for basic multi-warp case
  • Added TwoAsyncWarps test with TMA ops in separate warps
  • Verified kernel execution and validation
  • +173/-0

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The function createArrivesForWar includes debug print statements using std::cout that log the mbarrier map. These should not be present in production code, as they can affect performance and clutter logs.

std::cout << "mbarrierMap:\n";
for (const auto& [expr, tv] : GpuLower::current()->mbarrierMap()) {
  std::cout << "  " << expr->toString() << "   ->   " << tv->toString()
            << std::endl;
}

Logic Error

The function getAsyncWarpSiblingIds returns early within a loop over async warps, which may cause incorrect behavior if multiple async warps exist and the first one does not meet the short-circuit conditions. The logic appears to assume only one relevant async warp, but the loop suggests multiple could be processed.

for (const AsyncWarp& async_warp : async_warps) {
  // short-circuit: no sibling relationships to map.
  if (async_warp.tvs.size() == 1) {
    continue;
  }

  // short-circuit: stage_slice_position is not used.
  if (async_warp.stage_slice_position == -1) {
    continue;
  }

  // If there are multiple async warps, check that only one of them uses
  // stage_slice_position.
  // TODO: I don't think we'll ever need stage_slice_position with multiple
  // async warps

  NVF_ERROR(
      async_warps.size() == 1,
      "Multi-role specialization supported only when stage_slice_position is "
      "not used");

  return getSiblingIds(async_warp);
}
return {};

Performance Risk

Register sharing is enabled based solely on the presence of num_registers in the first async warp's options, without validating consistency across all async warps. This could lead to incorrect register allocation if different warps have conflicting register requirements.

bool enable_register_sharing =
    std::get<WarpSpecialized>(warp_info.cb_options.type)
        .num_registers.has_value();
GpuLower::current()->kernel()->manage(
    "enable_register_sharing", enable_register_sharing);

@jacobhinkle (Collaborator, Author) commented:

I think the problem is that we are using the same "slot full" mbarrier for both async warps. If one warp issues its expect_tx and its load completes before the other warp has issued its own expect_tx, the mbarrier is arrived on early, which flips the expected parity bit and leads to the deadlock. Instead I think it's best to have one slot-full barrier and one slot-empty barrier per async warp. Consumers that use buffers from two async warps, as in this example, will just need to wait on both mbarriers before proceeding.
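
A minimal sketch of that scheme, assuming a two-stage pipeline and two async warps; full, empty, w, num_iters, and bytes are illustrative placeholders rather than the PR's actual codegen, while the mbarrier/toSmem helpers are the ones used in the kernel dump above:

// Sketch only: one slot-full and one slot-empty mbarrier per async warp
// and per stage, so the two producers never share an expect_tx target.
// __shared__ uint64_t full[2][2];   // [async_warp][stage]
// __shared__ uint64_t empty[2][2];  // [async_warp][stage]

// Producer loop for async warp w (one elected thread):
for (nvfuser_index_t i = 0; i < num_iters; ++i) {
  nvfuser_index_t stage = i % 2;
  uint32_t parity = __to_uint32((i / 2) % 2);
  mbarrier::waitParity(toSmem(&empty[w][stage]), parity);   // slot freed by consumer
  mbarrier::arriveExpectTX(toSmem(&full[w][stage]), bytes); // this warp's own barrier
  // ... issue Hopper::cpAsyncBulkG2S signaling full[w][stage] ...
}

// Consumer loop: wait on BOTH producers' slot-full barriers before reading.
for (nvfuser_index_t i = 0; i < num_iters; ++i) {
  nvfuser_index_t stage = i % 2;
  uint32_t parity = __to_uint32((i / 2) % 2);
  mbarrier::waitParity(toSmem(&full[0][stage]), parity);
  mbarrier::waitParity(toSmem(&full[1][stage]), parity);
  // ... compute with both buffers for this stage ...
  mbarrier::arrive(toSmem(&empty[0][stage]));
  mbarrier::arrive(toSmem(&empty[1][stage]));
}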

@rdspring1 rdspring1 left a comment

What would you like me to review in this PR? It doesn't seem ready to merge.
