csrc/device_lower/analysis/predicate_elimination.cpp: 159 changes (82 additions, 77 deletions)

@@ -28,9 +28,11 @@ namespace {
// tensor memory as special type of shared memory. In this file, we use
// the term "shared memory", "smem" to refer to both shared and tensor
// memories.
-bool isSharedMemory(TensorView* tv) {
-  return tv->getMemoryType() == MemoryType::Shared ||
-      tv->getMemoryType() == MemoryType::Tensor;
+bool isSharedMemoryTensor(Val* val) {
+  auto tv = dynamic_cast<TensorView*>(val);
+  return tv != nullptr &&
+      (tv->getMemoryType() == MemoryType::Shared ||
+       tv->getMemoryType() == MemoryType::Tensor);
}

// Warp primitives are currently limited to un-predicated usage,
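
A plausible motivation for the signature change, though not stated in the diff: taking Val* instead of TensorView* lets the helper be passed directly as a predicate over expr->inputs(), whose elements are Val*, as done in the checks further down:

    // expr->inputs() yields Val*, so the helper composes without a cast
    // or wrapper lambda:
    if (std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
      RECORD_AND_RETURN(true);
    }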
@@ -133,21 +135,73 @@ bool isExactParallelSharedMemAccess(TensorView* tv) {
  return true;
}

-// Check for conditions where the predicate cannot be removed
-// when either producer or consumer is in shared memory.
-bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) {
-  // Indexing is based on consumer loop ids so check the consumer.
+bool needsPredicateSharedMemAccess(const Expr* expr) {
[Review comment: @naoyam (Collaborator, Author), Sep 3, 2025]
Changes here are just making sure all checks are performed even when there's no producer.

It turned out this resulted in a lot of code changes, which I don't feel like going through right now. For now, I tried to keep the current behavior as is, even though in some cases it doesn't seem to make sense, as long as correctness is preserved.

+  DEBUG_PRINT_SCOPE(expr);

-  // If consumer schedule contains in-exact thread parallel
-  // dimensions, need to predicate against out of bound
-  // shared memory access by out of bound threads.
-  if (!isExactParallelSharedMemAccess(consumer)) {
-    return true;
+  // This is initial step to gradually remove predicates around
+  // sharedmem access in suitable situations.
+  // Using an additional variable to track the predicate-on reasons
+  // when the predicate around shared mem cannot be removed.
+  for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
+    // If consumer schedule contains in-exact thread parallel
+    // dimensions, need to predicate against out of bound
+    // shared memory access by out of bound threads.
+    //
+    // TODO: This condition should not need
[Review comment: @naoyam (Collaborator, Author)]
Here, I thought we shouldn't need to look at producers, but that resulted in some unexpected changes, including missed shared memory aliasing.

+    // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
+    // it's there to keep the existing behavior as of PR #5107.
+    // Specifically, if not used,
+    // HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs fails
+    // due to insufficient shared memory, which happens because T12 is
+    // not aliased with T9 if the predicate of T6 is removed. This
+    // seems to rely on an unintended effect. In this particular case,
+    // T6 is an output of a TMA store, so its input, T9, should not
+    // need to be initialized, and thus I think we could remove the
+    // predicate but still allow aliasing.
[Review comment: Collaborator, on lines +153 to +160]
I wonder why the predicate for T6 interferes with aliasing of T9 and T12. Is it because it changes the loop structure? I would've thought the predicate would just be on the TMA store, i.e. at the innermost scope, so that we'd still be able to do an alias.

[Reply: @naoyam (Collaborator, Author)]
When eliminating the predicate of an expression, we also make sure the input tensors are at least initialized. This may not be necessary, since non-predicated execution should not affect anything in the final result on global memory even if it uses uninitialized values, but I'm not sure whether it's considered legal to, for example, read an uninitialized shared memory value. So, suppose we keep initializing them; then this could interfere with the aliasing logic. In this particular case, it looks like what happened was that when the predicate of T6 was removed, an initialization of T9 was inserted that overlaps with the lifetime of T12, which prevented the aliasing between T12 and T9. That's also described here:

https://github.com/NVIDIA/Fuser/blob/main/csrc/device_lower/analysis/predicate_elimination.cpp#L148-L170

Looks like that part was added in #3098.

In this particular case, since the expression of T6 is a TMA store, we should be able to omit the initialization of T9 safely because the hardware takes care of predicating out-of-bounds accesses, which should allow omitting the predicate while still making sure T9 and T12 are aliased. But in general, I feel that predicate elimination should be aware of aliasing, at least for shared memory, because shared memory capacity is likely a more important factor than omitting predicates.
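
To make the interference concrete, here is a minimal sketch of the scenario described above, reusing the tensor names from this thread (the surrounding ops and buffer layout are made up for illustration):

    // With T6 predicated, T9 needs no initialization, so its live range
    // starts just before the store and T9 can reuse T12's buffer:
    //   ... T12 is produced and consumed ...  // T12's lifetime ends
    //   if (pred) {
    //     T6 = tma_store(T9);                 // T9 aliases T12's smem
    //   }
    //
    // With T6's predicate removed, the pass inserts an init of T9, which
    // starts T9's lifetime while T12 is still alive, so aliasing fails:
    //   T9[...] = 0;                          // overlaps T12's lifetime
    //   ... T12 is produced and consumed ...
    //   T6 = tma_store(T9);                   // unpredicated TMA store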

+    if (isSharedMemoryTensor(consumer) ||
+        std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
+      if (!isExactParallelSharedMemAccess(consumer)) {
+        RECORD_AND_RETURN(true);
+      }
+    }

+    // TODO: This condition should not need
[Review comment: @naoyam (Collaborator, Author)]
Similarly here, this condition is to keep the current behavior for now.

+    // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
+    // it's there to keep the existing behavior as of PR #5107.
+    if (isSharedMemoryTensor(consumer) ||
+        std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
+      for (auto id : consumer->getLoopDomain()) {
+        // TODO: (Enable in a follow up)
+        // smem predicate removal with init would break unroll and unswitch,
+        // eg. as in issue 1133, so disabling this removal pattern for now.
+        if (id->getParallelType() == ParallelType::Unroll ||
+            id->getParallelType() == ParallelType::Unswitch) {
+          RECORD_AND_RETURN(true);
+        }
+      }
+    }

+    // TODO: (Enable in a follow up)
+    // This cannot yet be removed since smem initialization needs to be
+    // handled specially, e.g. as in smem_reduce test. Will be able to
+    // lift this one once the generic pred removal pass with fusion
+    // traversal is ready.
+    if (ir_utils::isReductionOp(consumer->definition()) &&
+        std::ranges::any_of(expr->inputs(), [](auto val) {
+          return isSharedMemoryTensor(val);
+        })) {
+      RECORD_AND_RETURN(true);
+    }
+  }

+  // Disable shared memory producers that is a consumer
+  // of another shared memory tensor. The initialization would
+  // break potential opportunity to re-use shared mem buffer.
+  //
  // TODO: This is directed WAR on FusionPersistentNormLocalShared.
  // This use case along with other previous issues motivate a
  // joint optimization of predicate removal and buffer reuse.
  // In this particular case:
  // __shared__ T0 [10], T1[10]
  // for i in ...
@@ -156,77 +210,28 @@ bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) {
  //  T2 = 0; // init for exp1
  //   if(pred)
  //    T2 = T1 ... // exp1
  // If we remove pred around expr1, as the way the pred removal
  // pass is set up, the init for expr will be pushed up to
  // initialize T1 instead.
  // However if we initialize T1, the code will look like:
  // for i in ...
  //   T1[i] = 0;
  // for i in ...
  //   if(pred)
  //    T1[i] = T0[i] + ...
-  // Note that we'd be able to reuse buffer of T0 for T1 but
-  // if we initialze T1 we cannot do that and thus the
-  // kernel would not fit in smaller devices.
-  if (producer->getMemoryType() == MemoryType::Shared) {
-    if (auto producer_def = producer->definition()) {
-      if (std::any_of(
-              producer_def->inputs().begin(),
-              producer_def->inputs().end(),
-              [](Val* val) {
-                if (auto tv = ir_utils::getTv(val)) {
-                  return tv->getMemoryType() == MemoryType::Shared;
-                }
-                return false;
-              })) {
-        // Disable shared memory producers that is a consumer
-        // of another shared memory tensor. The initialization would
-        // break potential opportunity to re-use shared mem buffer.
-        return true;
-      }
+  // Note that we'd be able to reuse buffer of T0 for T1 but
+  // if we initialize T1 we cannot do that and thus the
+  // kernel would not fit in smaller devices.
+  for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
+    if (auto producer_def = producer->definition(); producer_def != nullptr &&
+            isSharedMemoryTensor(producer) &&
+            std::any_of(producer_def->inputs().begin(),
+                        producer_def->inputs().end(),
+                        [](Val* val) { return isSharedMemoryTensor(val); })) {
+      RECORD_AND_RETURN(true);
    }
  }

-  for (auto id : consumer->getLoopDomain()) {
-    // TODO: (Enable in a follow up)
-    // smem predicate removal with init would break unroll and unswitch,
-    // eg. as in issue 1133, so disabling this removal pattern for now.
-    if (id->getParallelType() == ParallelType::Unroll ||
-        id->getParallelType() == ParallelType::Unswitch) {
-      return true;
-    }
-  }
-
-  // TODO: (Enable in a follow up)
-  // This cannot yet be removed since smem initialization needs to be
-  // handled specially, e.g. as in smem_reduce test. Will be able to
-  // lift this one once the generic pred removal pass with fusion
-  // traversal is ready.
-  auto consumer_def = consumer->definition();
-  if (ir_utils::isReductionOp(consumer_def)) {
-    if (producer->getMemoryType() == MemoryType::Shared) {
-      return true;
-    }
-  }
-
-  return false;
-}

-bool needsPredicateSharedMemAccess(const Expr* expr) {
-  DEBUG_PRINT_SCOPE(expr);
-  // This is initial step to gradually remove predicates around
-  // sharedmem access in suitable situations.
-  // Using an additional variable to track the predicate-on reasons
-  // when the predicate around shared mem cannot be removed.
-  for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
-    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
-      if (isSharedMemory(producer) || isSharedMemory(consumer)) {
-        if (needSharedMemPredicate(producer, consumer)) {
[Review comment: @naoyam (Collaborator, Author)]
This check is not done when there's no producer.
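
For illustration, a minimal sketch of such a producer-less expression (mirroring the FullSharedMem test added below; the extent is arbitrary):

    // tv2 is defined by a full-like op, so expr->inputs() contains no
    // TensorView: the nested producer loop here iterates zero times and
    // needSharedMemPredicate() is never called, skipping the shared
    // memory checks for tv2.
    auto tv2 = zeros({IrBuilder::create<Val>(8)}, DataType::Float);
    tv2->setMemoryType(MemoryType::Shared);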

-          RECORD_AND_RETURN(true);
-        }
-      }
-    }
-  }
RECORD_AND_RETURN(false);
}

tests/cpp/test_predicate_elimination.cpp: 36 changes (36 additions, 0 deletions)

@@ -481,4 +481,40 @@ TEST_F(PredicateEliminationTest, ExtentEqualToMaxParallelTypeExtent) {
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

+TEST_F(PredicateEliminationTest, FullSharedMem) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr.get();
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeContigConcreteTensor({128});
+  fusion.addInput(tv0);
+
+  // Force launching 128 threads at minimum
+  auto tv1 = add(tv0, fusion.oneVal());
+  fusion.addOutput(tv1);
+
+  auto tv2 = zeros({IrBuilder::create<Val>(8)}, DataType::Float);
+  auto tv3 = add(tv2, fusion.oneVal());
+  fusion.addOutput(tv3);
+
+  tv2->setMemoryType(MemoryType::Shared);
+
+  for (auto tv : fusion.allTvs()) {
+    tv->axis(0)->parallelize(ParallelType::TIDx);
+  }
+
+  GpuLower gpulw(&fusion);
+  gpulw.run();
+  // tv2 expectation: should be predicated even though there's no
+  // producer tensor
+  EXPECT_TRUE(PredicatedChecker::isPredicated(tv2, gpulw));
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 = at::randn({128}, options);
+  KernelExecutor ke;
+  ke.compile(&fusion, {t0});
+  auto outputs = ke.run({t0});
+  testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
+}

} // namespace nvfuser