
Conversation

@naoyam (Collaborator) commented Aug 8, 2025

Fixes #4929.

This is a follow-up bug fix (#4742 (comment)).

There are actually two bugs. One is in index overriding: the override replacement failed to take effect, producing an incorrect index:

Wrong: T4[((nvfuser_index_t)threadIdx.x)] = T3[0];
Correct: T4[__to_index(T0[((nvfuser_index_t)threadIdx.x)])] = T3[0];

Another bug is missing RAW syncs. The sync analysis needed to be extended to consider indirect indexing.

@naoyam (Collaborator, Author) commented Aug 8, 2025

!test --diff


github-actions bot commented Aug 8, 2025

Review updated until commit d11803f

Description

  • Fix index override in scatter operations

  • Add RAW synchronization for indirect indexing

  • Correct loop domain mapping in scatter ops

  • Enhance sync analysis for scatter dependencies


Changes walkthrough 📝

Relevant files

Bug fix

sync_information.cpp — Add scatter-aware synchronization analysis
(csrc/device_lower/analysis/sync_information.cpp)

  • Added isConsumedByScatter to detect scatter input dependencies
  • Added isProducedByScatter to identify scatter output domains
  • Updated sync analysis to exclude scatter indexing cases
  • Fixed missing RAW syncs in indirect indexing scenarios
  • +52/-2

indexing.cpp — Fix index map override in tensor indexing
(csrc/id_model/indexing.cpp)

  • Changed index_map insertion to use direct assignment
  • Fixed potential override replacement failure in indexing
  • +1/-1

Tests

test_scatter.cpp — Add test for scatter sync with mapped domains
(tests/cpp/test_scatter.cpp)

  • Added test MappedLogicalAndLoop for issue #4929 (Missing sync when scatter logical domain happens to be mapped with loop domain)
  • Validates pre- and post-scatter block syncs
  • Confirms correct synchronization with mapped domains
  • +85/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The function isProducedByScatter uses logic that may incorrectly identify an IterDomain as produced by scatter if it is not present in the non-scattered ID set. This could lead to missing RAW synchronizations in cases where indirect indexing is involved but not actually scattered, potentially causing race conditions.

    bool isProducedByScatter(TensorView* tv, IterDomain* id) {
      auto scatter = dynamic_cast<ScatterOp*>(tv->definition());
      if (scatter == nullptr) {
        return false;
      }
    
      auto logical_scatter_dim =
          TensorDomain::noReductions(tv->getLogicalDomain()).at(scatter->dim());
    
      std::unordered_set<Val*> non_scatter_logical_ids;
      std::ranges::copy_if(
          tv->getLogicalDomain(),
          std::inserter(non_scatter_logical_ids, non_scatter_logical_ids.end()),
          [&](IterDomain* logical_id) {
            return logical_id != logical_scatter_dim;
          });
    
      auto all_non_scatter_ids = DependencyCheck::getAllValsBetween(
          non_scatter_logical_ids,
          {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
    
      return std::ranges::find(all_non_scatter_ids, id) ==
          all_non_scatter_ids.end();
    }
    Index Override Logic

    The use of assignment (=) instead of emplace in index_info.index_map during index override may unintentionally override existing mappings without checking for duplicates, which could lead to incorrect indexing behavior in complex tensor expressions.

    index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
    Test Coverage

    The new test MappedLogicalAndLoop verifies synchronization around scatter operations, but does not validate the correctness of the actual data movement or indexing in the presence of indirect access patterns, which is critical for ensuring the fix addresses both reported bugs.

    // Repro of #4929. In ScatterOp, the logical domain of the output tensor is not
    // automatically mapped with its loop domain, but it's possible they
    // happen to be mapped. Make sure proper synchronizations are inserted
    // even in that case.
    TEST_F(ScatterTest, MappedLogicalAndLoop) {
      auto fusion_ptr = std::make_unique<Fusion>();
      Fusion& fusion = *fusion_ptr.get();
      FusionGuard fg(&fusion);
    
      const int64_t m = 8;
    
      auto tv0 = makeContigConcreteTensor({m}, DataType::Int);
      fusion.addInput(tv0);
      auto tv1 = makeContigConcreteTensor({m}, DataType::Int);
      fusion.addInput(tv1);
    
      auto tv2 = set(tv1);
      auto tv3 = arange(IrBuilder::create<Val>(8));
      auto tv4 = scatter(tv2, 0, tv0, tv3);
      auto tv5 = set(tv4);
      fusion.addOutput(tv5);
    
      // At this point, tv4's loop ID is not mapped with its sole logical
      // ID but mapped with tv0's logical ID. This means that the loop
      // ID is not mapped with the loop ID of the input of the op,
      // tv2. When parallelized, this difference of the loop IDs should
      // cause the sync analysis to flag a potential RAW race. However, it
      // is possible they happen to be mapped, e.g., by an additional op
      // like below:
      auto tv6 = add(tv0, tv1);
      fusion.addOutput(tv6);
    
      // The binary add op maps the logical domains of tv0 and tv1, which
      // in turn maps the loop domain of tv4 with its logical domain. Make
      // sure that the proper synchronizations are inserted even in cases
      // like this.
    
      for (auto tv : fusion.allTvs()) {
        tv->axis(0)->parallelize(ParallelType::TIDx);
      }
    
      tv2->setMemoryType(MemoryType::Shared);
      tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
      tv4->setMemoryType(MemoryType::Shared);
      tv4->setAllocationDomain(tv4->getLogicalDomain(), true);
    
      auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
      auto t0 = at::randperm(m, options);
      auto t1 = at::randint(0, 100, {m}, options);
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0, t1});
      auto outputs = ke.run({t0, t1});
    
      testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
    
      // There must be a block sync both before and after the scatter
      // op.
      bool pre_scatter_sync_found = false;
      bool scatter_found = false;
      bool post_scatter_sync_found = false;
      for (const auto expr :
           KernelExprVisitor::getAllExprs(ke.compiledKernel()->kernel())) {
        if (auto scatter = dynamic_cast<ScatterOp*>(expr)) {
          EXPECT_TRUE(pre_scatter_sync_found)
              << "Sync before scatter not found: " << scatter->toString();
          scatter_found = true;
        }
        if (auto sync = dynamic_cast<kir::BlockSync*>(expr)) {
          if (!scatter_found) {
            EXPECT_FALSE(pre_scatter_sync_found)
                << "Only one sync before scatter expected: " << sync->toString();
            pre_scatter_sync_found = true;
          } else {
            EXPECT_FALSE(post_scatter_sync_found)
                << "Only one sync after scatter expected: " << sync->toString();
            post_scatter_sync_found = true;
          }
        }
      }
      EXPECT_TRUE(pre_scatter_sync_found) << "Sync before scatter not found";
      EXPECT_TRUE(scatter_found) << "Scatter not found in Kernel";
      EXPECT_TRUE(post_scatter_sync_found) << "Sync after scatter not found";
    }

      expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr));
  for (const auto& [indexed_id, index] : override_index) {
-   index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
+   index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
@naoyam (Collaborator, Author) commented:

This is a trivial bug fix. override_index didn't actually override existing mappings because of the use of emplace.

@naoyam marked this pull request as ready for review August 11, 2025 19:06
@naoyam requested a review from jjsjann123 August 11, 2025 19:11
@naoyam (Collaborator, Author) commented Aug 11, 2025

!test

@jjsjann123 (Collaborator) left a comment:

LGTM

@naoyam (Collaborator, Author) commented Aug 12, 2025

68c955b revealed a bug with isProducedByScatter. It should be fixed by the last commit.

@naoyam (Collaborator, Author) commented Aug 12, 2025

!test

@naoyam merged commit b808ad8 into main Aug 12, 2025 (51 of 53 checks passed)
@naoyam deleted the issue_4929 branch August 12, 2025 03:11
Successfully merging this pull request may close these issues:

  • Missing sync when scatter logical domain happens to be mapped with loop domain (#4929)