Merged

47 commits
7f634f5  WIP (naoyam, Jul 7, 2025)
c654486  cleanup (naoyam, Jul 7, 2025)
8ef8e3e  cleanup (naoyam, Jul 7, 2025)
549e407  enable codegen of argsort+scatter (naoyam, Jul 7, 2025)
ea81050  Use IterDomain::merge instead of manually creating a Merge (naoyam, Jul 8, 2025)
b32ccf3  Merge branch 'main' into simplify_flatten (naoyam, Jul 8, 2025)
8003a94  Convert indexPutAccumulate to scatter when possible (naoyam, Jul 8, 2025)
21054f0  Merge remote-tracking branch 'origin/simplify_flatten' into scatter (naoyam, Jul 8, 2025)
2a5379c  enable codegen of compute_problem_sizes (naoyam, Jul 8, 2025)
bc6020d  remove old test (naoyam, Jul 8, 2025)
d2d127b  scatter with shmem (naoyam, Jul 8, 2025)
6a8c74e  cleanup (naoyam, Jul 9, 2025)
dce9246  Merge remote-tracking branch 'origin/main' into scatter (naoyam, Jul 9, 2025)
f725ad9  cleanup (naoyam, Jul 9, 2025)
96bb1d0  cleanup (naoyam, Jul 9, 2025)
0fc4aff  cleanup (naoyam, Jul 9, 2025)
d8291fc  fix (naoyam, Jul 9, 2025)
59d73b2  cleanup (naoyam, Jul 9, 2025)
09617ea  fix (naoyam, Jul 9, 2025)
2fc5184  test fix (naoyam, Jul 9, 2025)
ee36099  Moved the change of the loop domain to a scheduling routine (naoyam, Jul 9, 2025)
fd2b83b  bug fix (naoyam, Jul 10, 2025)
6cd1c3b  cleanup (naoyam, Jul 10, 2025)
504b3fe  Merge branch 'main' into scatter (naoyam, Jul 25, 2025)
708db3d  Merge remote-tracking branch 'origin/main' into scatter (naoyam, Jul 28, 2025)
537bced  WIP (naoyam, Jul 28, 2025)
4193f9b  simplify (naoyam, Jul 29, 2025)
8f15f70  cleanup (naoyam, Jul 29, 2025)
02ae364  cleanup (naoyam, Jul 29, 2025)
fd18ff3  cleanup (naoyam, Jul 29, 2025)
b63a86f  IdModel test (naoyam, Jul 29, 2025)
25f1b31  cleanup (naoyam, Jul 30, 2025)
fa9b895  cleanup (naoyam, Jul 30, 2025)
f08d4ed  update (naoyam, Jul 30, 2025)
8823a4c  Merge remote-tracking branch 'origin/main' into scatter (naoyam, Jul 30, 2025)
e293562  merge fix (naoyam, Jul 30, 2025)
16c359f  cleanup (naoyam, Aug 7, 2025)
7d0cc96  Merge branch 'main' into scatter (naoyam, Aug 7, 2025)
650bf85  interface change (naoyam, Aug 7, 2025)
95f361e  override indexing fix (naoyam, Aug 8, 2025)
333f90b  Merge remote-tracking branch 'origin/main' into issue_4929 (naoyam, Aug 8, 2025)
aa3005a  cleanup (naoyam, Aug 11, 2025)
0c0bf12  Merge remote-tracking branch 'origin/main' into issue_4929 (naoyam, Aug 11, 2025)
36a78c2  comments (naoyam, Aug 11, 2025)
68c955b  PR feedback (naoyam, Aug 11, 2025)
bc2aedd  simplify the repro (naoyam, Aug 12, 2025)
d11803f  fix (naoyam, Aug 12, 2025)
54 changes: 52 additions & 2 deletions csrc/device_lower/analysis/sync_information.cpp
@@ -91,6 +91,53 @@ bool allowIncoherentDependency(
return false;
}

// Check if an iter domain of a tensor is subject to a scatter
// op. Specifically, if the given expr is a scatter op that uses the given
// tensor as its input, returns true if the given iter domain is
// derived from the scattered logical iter domain.
bool isConsumedByScatter(TensorView* tv, IterDomain* id, Expr* consumer_expr) {
auto scatter = dynamic_cast<ScatterOp*>(consumer_expr);
if (scatter == nullptr || scatter->in() != tv) {
return false;
}

auto logical_scatter_dim =
TensorDomain::noReductions(tv->getLogicalDomain()).at(scatter->dim());
return DependencyCheck::isDependencyOf(logical_scatter_dim, id);
}

// Check if an iter domain of a tensor is an output of a scatter
// op. All non-scattered IDs should be derived from the non-scattered
// logical IDs, so if the given ID is not found in the non-scattered ID
// set, it must be produced by the scatter. Note that we can't just use
// isDependencyOf as in isConsumedByScatter: because the loop domain is
// set by the index tensor, the given ID has no dependency on any of the
// logical IDs of the given tensor.
bool isProducedByScatter(TensorView* tv, IterDomain* id) {
auto scatter = dynamic_cast<ScatterOp*>(tv->definition());
if (scatter == nullptr) {
return false;
}

auto logical_scatter_dim =
TensorDomain::noReductions(tv->getLogicalDomain()).at(scatter->dim());

std::unordered_set<Val*> non_scatter_logical_ids;
std::ranges::copy_if(
tv->getLogicalDomain(),
std::inserter(non_scatter_logical_ids, non_scatter_logical_ids.end()),
[&](IterDomain* logical_id) {
return logical_id != logical_scatter_dim;
});

auto all_non_scatter_ids = DependencyCheck::getAllValsBetween(
non_scatter_logical_ids,
{tv->getLoopDomain().begin(), tv->getLoopDomain().end()});

return std::ranges::find(all_non_scatter_ids, id) ==
all_non_scatter_ids.end();
}

} // namespace

SyncMap::SyncMap(Fusion* fusion) {
@@ -370,14 +417,17 @@
// are mapped by the best effort replay. (See
// NVFuserTest.RAWSync for a concrete repro).

-          // Case 1
+          // Case 1. Note that indexing through scatter needs to be
+          // excluded due to its indirect indexing.
const auto& id_model = GpuLower::current()->idModel();
auto producer_loop_id = getLoopPromotion(p_id, id_model);
auto consumer_loop_id = getLoopPromotion(c_id, id_model);
const auto& indexing_traveral_graph =
id_model.idGraph(TensorIndexer::traversalGraphType());
           if (indexing_traveral_graph.disjointValSets().strictAreMapped(
-                  producer_loop_id, consumer_loop_id)) {
+                  producer_loop_id, consumer_loop_id) &&
+              !isConsumedByScatter(producer, p_id, expr) &&
+              !isProducedByScatter(producer, p_id)) {
continue;
}

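To see why the scatter exclusion in Case 1 above is needed, consider a hand-written CUDA sketch of the hazard (illustrative only, not nvFuser-generated code; the kernel, the smem size, and all names are made up for the example). Even when the producer and consumer loops are both parallelized over the same threadIdx.x, the scatter writes through an index tensor, so thread i's read of element i may depend on a write made by a different thread; mapped loop IDs alone do not prove the absence of a RAW race, and a block sync is required.

__global__ void scatter_then_read(const int* idx, const int* in, int* out, int n) {
  __shared__ int smem[256];
  const int i = threadIdx.x;
  if (i < n) {
    // Indirect write: the destination element is chosen by idx[i], so it
    // generally belongs to a different thread than thread i.
    smem[idx[i]] = in[i];
  }
  // Required: without this sync, the read below may observe a stale value,
  // since mapped producer/consumer loop IDs say nothing about which thread
  // actually wrote smem[i].
  __syncthreads();
  if (i < n) {
    out[i] = smem[i];
  }
}
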
2 changes: 1 addition & 1 deletion csrc/id_model/indexing.cpp
@@ -971,7 +971,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
auto index_info = computeIndex(
expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr));
for (const auto& [indexed_id, index] : override_index) {
-    index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
+    index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;

naoyam (Collaborator, Author) commented:

This is a trivial bug fix. override_index didn't actually override existing mappings because of the use of emplace.

}
const auto& index_map = index_info.index_map;
auto replacement_map = getIndexReplacementMap(
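The pitfall behind this fix can be shown with a minimal standalone C++ sketch (plain std::unordered_map, unrelated to the nvFuser types): emplace is a no-op when the key already exists, while operator[] assignment overwrites unconditionally.

#include <cassert>
#include <unordered_map>

int main() {
  std::unordered_map<int, int> index_map;
  index_map.emplace(0, 10);
  // emplace does nothing when the key is already present, so the intended
  // override silently keeps the old value.
  index_map.emplace(0, 20);
  assert(index_map.at(0) == 10);
  // operator[] assignment overwrites unconditionally, which is the behavior
  // the fix above relies on.
  index_map[0] = 20;
  assert(index_map.at(0) == 20);
  return 0;
}
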
85 changes: 85 additions & 0 deletions tests/cpp/test_scatter.cpp
@@ -269,4 +269,89 @@ TEST_F(ScatterTest, CacheAfter) {
testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);
}

// Repro of #4929. In ScatterOp, the logical domain of the output tensor is not
// automatically mapped with its loop domain, but it's possible they
// happen to be mapped. Make sure proper synchronizations are inserted
// even in that case.
TEST_F(ScatterTest, MappedLogicalAndLoop) {
auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);

const int64_t m = 8;

auto tv0 = makeContigConcreteTensor({m}, DataType::Int);
fusion.addInput(tv0);
auto tv1 = makeContigConcreteTensor({m}, DataType::Int);
fusion.addInput(tv1);

auto tv2 = set(tv1);
auto tv3 = arange(IrBuilder::create<Val>(8));
auto tv4 = scatter(tv2, 0, tv0, tv3);
auto tv5 = set(tv4);
fusion.addOutput(tv5);

// At this point, tv4's loop ID is not mapped with its sole logical
// ID but is mapped with tv0's logical ID. This means that the loop
// ID is not mapped with the loop ID of the input of the op,
// tv2. When parallelized, this difference between the loop IDs should
// cause the sync analysis to flag a potential RAW race. However, it
// is possible that they happen to be mapped, e.g., by an additional op
// like the one below:
auto tv6 = add(tv0, tv1);
fusion.addOutput(tv6);

// The binary add op maps the logical domains of tv0 and tv1, which
// in turn maps the loop domain of tv4 with its logical domain. Make
// sure that the proper synchronizations are inserted even in cases
// like this.

for (auto tv : fusion.allTvs()) {
tv->axis(0)->parallelize(ParallelType::TIDx);
}

tv2->setMemoryType(MemoryType::Shared);
tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
tv4->setMemoryType(MemoryType::Shared);
tv4->setAllocationDomain(tv4->getLogicalDomain(), true);

auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
auto t0 = at::randperm(m, options);
auto t1 = at::randint(0, 100, {m}, options);

KernelExecutor ke;
ke.compile(&fusion, {t0, t1});
auto outputs = ke.run({t0, t1});

testValidate(&fusion, outputs, {t0, t1}, __LINE__, __FILE__);

// There must be a block sync both before and after the scatter
// op.
bool pre_scatter_sync_found = false;
bool scatter_found = false;
bool post_scatter_sync_found = false;
for (const auto expr :
KernelExprVisitor::getAllExprs(ke.compiledKernel()->kernel())) {
if (auto scatter = dynamic_cast<ScatterOp*>(expr)) {
EXPECT_TRUE(pre_scatter_sync_found)
<< "Sync before scatter not found: " << scatter->toString();
scatter_found = true;
}
if (auto sync = dynamic_cast<kir::BlockSync*>(expr)) {
if (!scatter_found) {
EXPECT_FALSE(pre_scatter_sync_found)
<< "Only one sync before scatter expected: " << sync->toString();
pre_scatter_sync_found = true;
} else {
EXPECT_FALSE(post_scatter_sync_found)
<< "Only one sync after scatter expected: " << sync->toString();
post_scatter_sync_found = true;
}
}
}
EXPECT_TRUE(pre_scatter_sync_found) << "Sync before scatter not found";
EXPECT_TRUE(scatter_found) << "Scatter not found in Kernel";
EXPECT_TRUE(post_scatter_sync_found) << "Sync after scatter not found";
}

} // namespace nvfuser
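
For reference, a sketch of what the MappedLogicalAndLoop fusion computes, written in plain ATen (an assumed equivalent for readability, not code from the test; options, t0, and t1 are as defined in the test above):

// t0 is a permutation of 0..m-1 (the scatter index); t1 holds random values.
at::Tensor t2 = t1.clone();              // set(tv1)
at::Tensor t3 = at::arange(8, options);  // arange(8), i.e., [0, 1, ..., 7]
// Scatter along dim 0: t4[t0[i]] = t3[i] for each i.
at::Tensor t4 = t2.scatter(0, t0, t3);
at::Tensor t5 = t4.clone();              // set(tv4), first fusion output
at::Tensor t6 = t0 + t1;                 // second fusion output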