Conversation

@naoyam naoyam commented Sep 3, 2025

The shared memory predicate elimination check is skipped when an expression has no producer tensor, which is wrong. This happens, for example, with factory ops. Please see the new test.

Here's the generated kernel with the repro using ToT:

__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T3) {
  alignas(16) extern __shared__ char array[];
  const unsigned smem_offset = 0;
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0);
  float f0;
  f0 = __to_float(1);
  float f1;
  f1 = __to_float(1);
  T2[((nvfuser_index_t)threadIdx.x)] = 0.000000000e+00f;
  if ((((nvfuser_index_t)threadIdx.x) < 8)) {
    T3[((nvfuser_index_t)threadIdx.x)]
      = T2[((nvfuser_index_t)threadIdx.x)]
      + f0;
  }
  T1[((nvfuser_index_t)threadIdx.x)]
    = T0[((nvfuser_index_t)threadIdx.x)]
    + f1;
}

Notice that the predicate of the zero assignment to T2 is eliminated, which is wrong: the thread-parallel dimension is not exact for T2 (see the threadIdx.x < 8 guard on the T3 expression), so out-of-bound threads write past T2's shared memory allocation.

TODO: The updated check assumes that the allocation of a shared memory tensor is always based on its loop domain, which may not be true.
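For reference, the repro fusion probably looks roughly like the sketch below. This is not the actual FullSharedMem test, just an approximation written against the nvFuser C++ API from memory: the helper makeContigTensor, the factory op full, the extent of 8, and the scheduling are all assumptions, and the test fixture, includes, and validation are omitted.

// Rough sketch only (assumptions, not the actual FullSharedMem test): a
// factory op creates a shared memory tensor that has no producer tensors,
// which is the case the old producer-based check missed.
Fusion fusion;
FusionGuard fg(&fusion);

// Regular input/output pair; its extent determines blockDim.x, which can be
// larger than the factory-op tensor, so TIDx is inexact for the smem tensor.
auto tv0 = makeContigTensor(1);
fusion.addInput(tv0);
auto tv1 = add(tv0, IrBuilder::create<Val>(1.0));
fusion.addOutput(tv1);

// Factory op with no producers, placed in shared memory (T2 in the kernel).
auto tv2 = full(
    {IrBuilder::create<Val>(8L)}, IrBuilder::create<Val>(0.0), DataType::Float);
auto tv3 = add(tv2, IrBuilder::create<Val>(1.0));
fusion.addOutput(tv3);

tv2->setMemoryType(MemoryType::Shared);

// Bind everything to threadIdx.x; the write to tv2 must stay predicated so
// that threads beyond its extent do not write out of bounds.
for (auto tv : {tv1, tv2, tv3}) {
  tv->axis(0)->parallelize(ParallelType::TIDx);
}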

@naoyam naoyam commented Sep 3, 2025

!test --diff

github-actions bot commented Sep 3, 2025

Review updated until commit 5eb6ad9

Description

  • Fix predicate elimination for shared memory tensors

  • Handle cases without producer tensors correctly

  • Improve shared memory access safety checks

  • Add test for full shared memory predicate case


Changes walkthrough 📝

Relevant files

Bug fix: predicate_elimination.cpp (+82/-77)
Refactor and fix shared memory predicate logic
csrc/device_lower/analysis/predicate_elimination.cpp

  • Renamed isSharedMemory to isSharedMemoryTensor and updated the signature
    to accept Val* (see the sketch after this walkthrough)
  • Converted needSharedMemPredicate logic into needsPredicateSharedMemAccess
    within a single function
  • Expanded checks to consider all inputs and outputs using
    isSharedMemoryTensor
  • Added safeguards for reduction ops, unroll/unswitch loops, and buffer
    reuse conflicts
  • Fixed logic that incorrectly skipped checks when no producer existed

Tests: test_predicate_elimination.cpp (+36/-0)
Add test for shared memory predicate
tests/cpp/test_predicate_elimination.cpp

  • Added new test FullSharedMem to verify predicate behavior
  • Tests zero-initialization of a shared memory tensor with no producer
  • Confirms the predicate is preserved even without a producer tensor
  • Validates kernel execution and correctness
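
A note on the renamed helper: its body isn't reproduced on this page, but a Val*-accepting predicate that can be passed directly to the std::ranges::any_of(expr->inputs(), ...) calls quoted in the reviewer guide below would look roughly like the following sketch (an assumption, not the actual implementation).

// Sketch only: a Val*-based shared memory check usable as a ranges predicate
// over expr->inputs(); the real helper in predicate_elimination.cpp may differ.
bool isSharedMemoryTensor(Val* val) {
  // Only TensorViews have a memory type; any other Val cannot be a shared
  // memory tensor.
  auto tv = dynamic_cast<TensorView*>(val);
  return tv != nullptr && tv->getMemoryType() == MemoryType::Shared;
}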

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Possible Issue

The condition std::ranges::any_of(expr->inputs(), isSharedMemoryTensor) in needsPredicateSharedMemAccess is retained to preserve existing behavior, but the comment indicates it may rely on an unintended effect. This could mask a deeper issue in shared memory tensor handling, particularly around predicate elimination and buffer aliasing, and should be validated to ensure correctness beyond just passing tests.

    // TODO: This condition should not need
    // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
    // it's there to keep the existing behavior as of PR #5107.
    // Specifically, if not used,
    // HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs fails
    // due to insufficient shared memory, which happens because T12 is
    // not aliased with T9 if the predicate of T6 is removed. This
    // seems to rely on an unintended effect. In this particular case,
    // T6 is an output of a TMA store, so its input, T9, should not
    // need to be initialized, and thus I think we could remove the
    // predicate but still allow aliasing.
Performance Concern

The TODO comment suggests that the current logic prevents predicate removal for unroll and unswitch loops, which could limit optimization opportunities. This restriction may hinder performance in cases where safe predicate removal is possible, and should be evaluated for future enablement.

    // TODO: (Enable in a follow up)
    //  smem predicate removal with init would break unroll and unswitch,
    //  eg. as in issue 1133, so disabling this removal pattern for now.
    if (id->getParallelType() == ParallelType::Unroll ||
        id->getParallelType() == ParallelType::Unswitch) {
      RECORD_AND_RETURN(true);
    }
Design Limitation

The function assumes shared memory tensor allocation is always based on loop domains, which the PR author notes may not be true. This assumption could lead to incorrect predicate elimination in more complex scenarios and should be addressed to ensure robustness.

    bool needsPredicateSharedMemAccess(const Expr* expr) {
      DEBUG_PRINT_SCOPE(expr);
    
      // This is initial step to gradually remove predicates around
      //  sharedmem access in suitable situations.
      // Using an additional variable to track the predicate-on reasons
      //  when the predicate around shared mem cannot be removed.
      for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
        // If consumer schedule contains in-exact thread parallel
        //  dimensions, need to predicate against out of bound
        //  shared memory access by out of bound threads.
        //
        // TODO: This condition should not need
        // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
        // it's there to keep the existing behavior as of PR #5107.
        // Specifically, if not used,
        // HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs fails
        // due to insufficient shared memory, which happens because T12 is
        // not aliased with T9 if the predicate of T6 is removed. This
        // seems to rely on an unintended effect. In this particular case,
        // T6 is an output of a TMA store, so its input, T9, should not
        // need to be initialized, and thus I think we could remove the
        // predicate but still allow aliasing.
        if (isSharedMemoryTensor(consumer) ||
            std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
          if (!isExactParallelSharedMemAccess(consumer)) {
            RECORD_AND_RETURN(true);
          }
        }
    
        // TODO: This condition should not need
        // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
        // it's there to keep the existing behavior as of PR #5107.
        if (isSharedMemoryTensor(consumer) ||
            std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
          for (auto id : consumer->getLoopDomain()) {
            // TODO: (Enable in a follow up)
            //  smem predicate removal with init would break unroll and unswitch,
            //  eg. as in issue 1133, so disabling this removal pattern for now.
            if (id->getParallelType() == ParallelType::Unroll ||
                id->getParallelType() == ParallelType::Unswitch) {
              RECORD_AND_RETURN(true);
            }
          }
        }
    
        // TODO: (Enable in a follow up)
        //  This cannot yet be removed since smem initialization needs to be
        //  handled specially, e.g. as in smem_reduce test. Will be able to
        //  lift this one once the generic pred removal pass with fusion
        //  traversal is ready.
        if (ir_utils::isReductionOp(consumer->definition()) &&
            std::ranges::any_of(expr->inputs(), [](auto val) {
              return isSharedMemoryTensor(val);
            })) {
          RECORD_AND_RETURN(true);
        }
      }

for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
  for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
    if (isSharedMemory(producer) || isSharedMemory(consumer)) {
      if (needSharedMemPredicate(producer, consumer)) {
@naoyam commented:

    This check is not done when there's no producer.
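
To make the failure mode concrete, here is a minimal standalone C++ sketch (illustration only, not nvFuser code; all names are made up). With a factory op there are no TensorView producers, so the inner loop body never runs, the shared memory consumer is never inspected, and the function falls through to return false:

#include <iostream>
#include <vector>

// Hypothetical stand-in for a tensor; only tracks whether it lives in smem.
struct FakeTensor {
  bool is_shared = false;
};

// Mirrors the shape of the old check: iterate producer/consumer pairs and
// only then ask whether a predicate is needed.
bool needsPredicateOldStyle(
    const std::vector<FakeTensor>& producers,
    const std::vector<FakeTensor>& consumers) {
  for (const auto& consumer : consumers) {
    for (const auto& producer : producers) {
      if (producer.is_shared || consumer.is_shared) {
        return true; // stand-in for needSharedMemPredicate(producer, consumer)
      }
    }
  }
  // With zero producers, the inner loop never executes, so a shared memory
  // consumer is silently treated as not needing a predicate.
  return false;
}

int main() {
  std::vector<FakeTensor> no_producers;                      // factory-op-like case
  std::vector<FakeTensor> consumers = {FakeTensor{true}};    // one smem consumer
  std::cout << std::boolalpha
            << needsPredicateOldStyle(no_producers, consumers) << "\n"; // false
  return 0;
}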

    return false;
    }

    bool needsPredicateSharedMemAccess(const Expr* expr) {
@naoyam naoyam commented Sep 3, 2025

    Changes here are just making sure all checks are performed even when there's no producer.

    It turned out this resulted in a lot of code changes, which I don't feel like going through right now. For now, I tried to keep the current behavior as is, even though in some cases it doesn't seem to make sense, as long as the correctness is preserved.

@naoyam naoyam commented Sep 3, 2025

!test --diff

@naoyam naoyam commented Sep 3, 2025

!test --diff

@naoyam naoyam commented Sep 4, 2025

!test --diff

@naoyam naoyam commented Sep 4, 2025

!test --diff

    // dimensions, need to predicate against out of bound
    // shared memory access by out of bound threads.
    //
    // TODO: This condition should not need
@naoyam commented:

Here, I thought we shouldn't need to look at producers, but dropping that check resulted in some unexpected changes, including missed shared memory aliasing.

    }
    }

    // TODO: This condition should not need
@naoyam commented:

    Similarly here, this condition is to keep the current behavior for now.

@naoyam naoyam marked this pull request as ready for review September 4, 2025 21:14
@naoyam naoyam requested a review from jacobhinkle September 4, 2025 21:14