diff --git a/csrc/device_lower/analysis/sync_information.cpp b/csrc/device_lower/analysis/sync_information.cpp
index 088c52a295f..73553913d95 100644
--- a/csrc/device_lower/analysis/sync_information.cpp
+++ b/csrc/device_lower/analysis/sync_information.cpp
@@ -91,6 +91,53 @@ bool allowIncoherentDependency(
   return false;
 }
 
+// Check if an iter domain of a tensor is the subject of a scatter
+// op. Specifically, if the given expr is a scatter op using the given
+// tensor as its input, returns true if the given iter domain is
+// derived from the scattered logical iter domain.
+bool isConsumedByScatter(TensorView* tv, IterDomain* id, Expr* consumer_expr) {
+  auto scatter = dynamic_cast<ScatterOp*>(consumer_expr);
+  if (scatter == nullptr || scatter->in() != tv) {
+    return false;
+  }
+
+  auto logical_scatter_dim =
+      TensorDomain::noReductions(tv->getLogicalDomain()).at(scatter->dim());
+  return DependencyCheck::isDependencyOf(logical_scatter_dim, id);
+}
+
+// Check if an iter domain of a tensor is an output of a scatter
+// op. All non-scattered IDs should be derived from the non-scattered
+// logical IDs. If the given ID is not found in the non-scattered ID
+// set, it must be produced by the scatter. Note that we can't just do
+// isDependencyOf like isConsumedByScatter since the given ID has no
+// dependency with any of the logical IDs of the given tensor since
+// the loop domain is set by the index tensor.
+bool isProducedByScatter(TensorView* tv, IterDomain* id) {
+  auto scatter = dynamic_cast<ScatterOp*>(tv->definition());
+  if (scatter == nullptr) {
+    return false;
+  }
+
+  auto logical_scatter_dim =
+      TensorDomain::noReductions(tv->getLogicalDomain()).at(scatter->dim());
+
+  std::unordered_set<Val*> non_scatter_logical_ids;
+  std::ranges::copy_if(
+      tv->getLogicalDomain(),
+      std::inserter(non_scatter_logical_ids, non_scatter_logical_ids.end()),
+      [&](IterDomain* logical_id) {
+        return logical_id != logical_scatter_dim;
+      });
+
+  auto all_non_scatter_ids = DependencyCheck::getAllValsBetween(
+      non_scatter_logical_ids,
+      {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
+
+  return std::ranges::find(all_non_scatter_ids, id) ==
+      all_non_scatter_ids.end();
+}
+
 } // namespace
 
 SyncMap::SyncMap(Fusion* fusion) {
@@ -370,14 +417,17 @@ SyncMap::SyncMap(Fusion* fusion) {
             // are mapped by the best effort replay. (See
             // NVFuserTest.RAWSync for a concrete repro).
 
-            // Case 1
+            // Case 1. Note that indexing through scatter needs to be
+            // excluded due to its indirect indexing.
             const auto& id_model = GpuLower::current()->idModel();
             auto producer_loop_id = getLoopPromotion(p_id, id_model);
             auto consumer_loop_id = getLoopPromotion(c_id, id_model);
             const auto& indexing_traveral_graph =
                 id_model.idGraph(TensorIndexer::traversalGraphType());
             if (indexing_traveral_graph.disjointValSets().strictAreMapped(
-                    producer_loop_id, consumer_loop_id)) {
+                    producer_loop_id, consumer_loop_id) &&
+                !isConsumedByScatter(producer, p_id, expr) &&
+                !isProducedByScatter(producer, p_id)) {
               continue;
             }
 
diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp
index 409a656f29f..468375b4e41 100644
--- a/csrc/id_model/indexing.cpp
+++ b/csrc/id_model/indexing.cpp
@@ -971,7 +971,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
   auto index_info = computeIndex(
       expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr));
   for (const auto& [indexed_id, index] : override_index) {
-    index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
+    index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
   }
   const auto& index_map = index_info.index_map;
   auto replacement_map = getIndexReplacementMap(
diff --git a/tests/cpp/test_scatter.cpp b/tests/cpp/test_scatter.cpp
index 0fe1f894497..904eaac3ded 100644
--- a/tests/cpp/test_scatter.cpp
+++ b/tests/cpp/test_scatter.cpp
@@ -269,4 +269,89 @@ TEST_F(ScatterTest, CacheAfter) {
   testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
 }
 
+// Repro of #4929. In ScatterOp, the logical domain of the output tensor is not
+// automatically mapped with its loop domain, but it's possible they
+// happen to be mapped. Make sure proper synchronizations are inserted
+// even in that case.
+TEST_F(ScatterTest, MappedLogicalAndLoop) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  const int64_t m = 8;
+
+  auto tv0 = makeContigConcreteTensor({m}, DataType::Int);
+  fusion.addInput(tv0);
+  auto tv1 = makeContigConcreteTensor({m}, DataType::Int);
+  fusion.addInput(tv1);
+
+  auto tv2 = set(tv1);
+  auto tv3 = arange(IrBuilder::create<Val>(8));
+  auto tv4 = scatter(tv2, 0, tv0, tv3);
+  auto tv5 = set(tv4);
+  fusion.addOutput(tv5);
+
+  // At this point, tv4's loop ID is not mapped with its sole logical
+  // ID but mapped with tv0's logical ID. This means that the loop
+  // ID is not mapped with the loop ID of the input of the op,
+  // tv2. When parallelized, this difference of the loop IDs should
+  // cause the sync analysis to flag a potential RAW race. However, it
+  // is possible they happen to be mapped, e.g., by an additional op
+  // like below:
+  auto tv6 = add(tv0, tv1);
+  fusion.addOutput(tv6);
+
+  // The binary add op maps the logical domains of tv0 and tv1, which
+  // in turn maps the loop domain of tv4 with its logical domain. Make
+  // sure that the proper synchronizations are inserted even in cases
+  // like this.
+
+  for (auto tv : fusion.allTvs()) {
+    tv->axis(0)->parallelize(ParallelType::TIDx);
+  }
+
+  tv2->setMemoryType(MemoryType::Shared);
+  tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
+  tv4->setMemoryType(MemoryType::Shared);
+  tv4->setAllocationDomain(tv4->getLogicalDomain(), true);
+
+  auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
+  auto t0 = at::randperm(m, options);
+  auto t1 = at::randint(0, 100, {m}, options);
+
+  KernelExecutor ke;
+  ke.compile(&fusion, {t0, t1});
+  auto outputs = ke.run({t0, t1});
+
+  testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
+
+  // There must be a block sync both before and after the scatter
+  // op.
+  bool pre_scatter_sync_found = false;
+  bool scatter_found = false;
+  bool post_scatter_sync_found = false;
+  for (const auto expr :
+       KernelExprVisitor::getAllExprs(ke.compiledKernel()->kernel())) {
+    if (auto scatter = dynamic_cast<ScatterOp*>(expr)) {
+      EXPECT_TRUE(pre_scatter_sync_found)
+          << "Sync before scatter not found: " << scatter->toString();
+      scatter_found = true;
+    }
+    if (auto sync = dynamic_cast<kir::BlockSync*>(expr)) {
+      if (!scatter_found) {
+        EXPECT_FALSE(pre_scatter_sync_found)
+            << "Only one sync before scatter expected: " << sync->toString();
+        pre_scatter_sync_found = true;
+      } else {
+        EXPECT_FALSE(post_scatter_sync_found)
+            << "Only one sync after scatter expected: " << sync->toString();
+        post_scatter_sync_found = true;
+      }
+    }
+  }
+  EXPECT_TRUE(pre_scatter_sync_found) << "Sync before scatter not found";
+  EXPECT_TRUE(scatter_found) << "Scatter not found in Kernel";
+  EXPECT_TRUE(post_scatter_sync_found) << "Sync after scatter not found";
+}
+
 } // namespace nvfuser
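For reference, a minimal standalone sketch of the computation the new test builds, written against plain ATen rather than nvFuser (the t0..t6 names simply mirror tv0..tv6 above and are illustrative only; it runs on CPU for simplicity). It is not part of the diff; it just illustrates why a block sync is needed around the scatter when every tensor is parallelized on TIDx.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  const int64_t m = 8;
  auto opts = at::TensorOptions().dtype(at::kLong);

  auto t0 = at::randperm(m, opts);           // index tensor (tv0), a permutation of 0..m-1
  auto t1 = at::randint(0, 100, {m}, opts);  // data tensor (tv1)

  auto t2 = t1.clone();                      // tv2 = set(tv1)
  auto t3 = at::arange(m, opts);             // tv3 = arange(8)
  auto t4 = t2.scatter(0, t0, t3);           // tv4 = scatter(tv2, 0, tv0, tv3): t4[t0[i]] = t3[i]
  auto t5 = t4.clone();                      // tv5 = set(tv4)
  auto t6 = t0 + t1;                         // tv6 = add(tv0, tv1)

  // With TIDx parallelization, thread i writes t4[t0[i]] in the scatter but
  // reads t4[i] to produce t5, so without a block sync around the scatter a
  // thread could read an element another thread has not yet written (RAW race).
  std::cout << t5 << "\n" << t6 << "\n";
  return 0;
}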