diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 0af79e9ad85..9b365c9d59a 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -28,9 +28,11 @@ namespace { // tensor memory as special type of shared memory. In this file, we use // the term "shared memory", "smem" to refer to both shared and tensor // memories. -bool isSharedMemory(TensorView* tv) { - return tv->getMemoryType() == MemoryType::Shared || - tv->getMemoryType() == MemoryType::Tensor; +bool isSharedMemoryTensor(Val* val) { + auto tv = dynamic_cast<TensorView*>(val); + return tv != nullptr && + (tv->getMemoryType() == MemoryType::Shared || + tv->getMemoryType() == MemoryType::Tensor); } // Warp primitives are currently limited to un-predicated usage, @@ -133,21 +135,73 @@ bool isExactParallelSharedMemAccess(TensorView* tv) { return true; } -// Check for conditions where the predicate cannot be removed -// when either producer or consumer is in shared memory. -bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) { - // Indexing is based on consumer loop ids so check the consumer. +bool needsPredicateSharedMemAccess(const Expr* expr) { + DEBUG_PRINT_SCOPE(expr); - // If consumer schedule contains in-exact thread parallel - // dimensions, need to predicate against out of bound - // shared memory access by out of bound threads. - if (!isExactParallelSharedMemAccess(consumer)) { - return true; + // This is initial step to gradually remove predicates around + // sharedmem access in suitable situations. + // Using an additional variable to track the predicate-on reasons + // when the predicate around shared mem cannot be removed.
+ for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) { + // If consumer schedule contains in-exact thread parallel + // dimensions, need to predicate against out of bound + // shared memory access by out of bound threads. + // + // TODO: This condition should not need + // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but + // it's there to keep the existing behavior as of PR #5107. + // Specifically, if not used, + // HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs fails + // due to insufficient shared memory, which happens because T12 is + // not aliased with T9 if the predicate of T6 is removed. This + // seems to rely on an unintended effect. In this particular case, + // T6 is an output of a TMA store, so its input, T9, should not + // need to be initialized, and thus I think we could remove the + // predicate but still allow aliasing. + if (isSharedMemoryTensor(consumer) || + std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) { + if (!isExactParallelSharedMemAccess(consumer)) { + RECORD_AND_RETURN(true); + } + } + + // TODO: This condition should not need + // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but + // it's there to keep the existing behavior as of PR #5107. + if (isSharedMemoryTensor(consumer) || + std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) { + for (auto id : consumer->getLoopDomain()) { + // TODO: (Enable in a follow up) + // smem predicate removal with init would break unroll and unswitch, + // eg. as in issue 1133, so disabling this removal pattern for now. + if (id->getParallelType() == ParallelType::Unroll || + id->getParallelType() == ParallelType::Unswitch) { + RECORD_AND_RETURN(true); + } + } + } + + // TODO: (Enable in a follow up) + // This cannot yet be removed since smem initialization needs to be + // handled specially, e.g. as in smem_reduce test. Will be able to + // lift this one once the generic pred removal pass with fusion + // traversal is ready.
+ if (ir_utils::isReductionOp(consumer->definition()) && + std::ranges::any_of(expr->inputs(), [](auto val) { + return isSharedMemoryTensor(val); + })) { + RECORD_AND_RETURN(true); + } } + // Disable shared memory producers that is a consumer + // of another shared memory tensor. The initialization would + // break potential opportunity to re-use shared mem buffer. + // // TODO: This is directed WAR on FusionPersistentNormLocalShared. - // This use case along with other previous issues motivate a - // joint optimization of predicate removal and buffer reuse. + // + // This use case along with other previous issues motivate a + // joint optimization of predicate removal and buffer reuse. // In this particular case: // __shared__ T0 [10], T1[10] // for i in ... @@ -156,77 +210,28 @@ bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) { // T2 = 0; // init for exp1 // if(pred) // T2 = T1 ... // exp1 - // If we remove pred around expr1, as the way the pred removal - // pass is set up, the init for expr will be pushed up to - // initialize T1 instead. - // However if we initialize T1, the code will look like: + // If we remove pred around expr1, as the way the pred removal + // pass is set up, the init for expr will be pushed up to + // initialize T1 instead. + // However if we initialize T1, the code will look like: // for i in ... // T1[i] = 0; // for i in ... // if(pred) // T1[i] = T0[i] + ... - // Note that we'd be able to reuse buffer of T0 for T1 but - // if we initialze T1 we cannot do that and thus the - // kernel would not fit in smaller devices. 
- if (producer->getMemoryType() == MemoryType::Shared) { - if (auto producer_def = producer->definition()) { - if (std::any_of( - producer_def->inputs().begin(), - producer_def->inputs().end(), - [](Val* val) { - if (auto tv = ir_utils::getTv(val)) { - return tv->getMemoryType() == MemoryType::Shared; - } - return false; - })) { - // Disable shared memory producers that is a consumer - // of another shared memory tensor. The initialization would - // break potential opportunity to re-use shared mem buffer. - return true; - } + // Note that we'd be able to reuse buffer of T0 for T1 but + // if we initialize T1 we cannot do that and thus the + // kernel would not fit in smaller devices. + for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) { + if (auto producer_def = producer->definition(); producer_def != nullptr && + isSharedMemoryTensor(producer) && + std::any_of(producer_def->inputs().begin(), + producer_def->inputs().end(), + [](Val* val) { return isSharedMemoryTensor(val); })) { + RECORD_AND_RETURN(true); + } } - for (auto id : consumer->getLoopDomain()) { - // TODO: (Enable in a follow up) - // smem predicate removal with init would break unroll and unswitch, - // eg. as in issue 1133, so disabling this removal pattern for now. - if (id->getParallelType() == ParallelType::Unroll || - id->getParallelType() == ParallelType::Unswitch) { - return true; - } - } - - // TODO: (Enable in a follow up) - // This cannot yet be removed since smem initialization needs to be - // handled specially, e.g. as in smem_reduce test. Will be able to - // lift this one once the generic pred removal pass with fusion - // traversal is ready.
- auto consumer_def = consumer->definition(); - if (ir_utils::isReductionOp(consumer_def)) { - if (producer->getMemoryType() == MemoryType::Shared) { - return true; - } - } - - return false; -} - -bool needsPredicateSharedMemAccess(const Expr* expr) { - DEBUG_PRINT_SCOPE(expr); - // This is initial step to gradually remove predicates around - // sharedmem access in suitable situations. - // Using an additional variable to track the predicate-on reasons - // when the predicate around shared mem cannot be removed. - for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) { - for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) { - if (isSharedMemory(producer) || isSharedMemory(consumer)) { - if (needSharedMemPredicate(producer, consumer)) { - RECORD_AND_RETURN(true); - } - } - } - } RECORD_AND_RETURN(false); } diff --git a/tests/cpp/test_predicate_elimination.cpp b/tests/cpp/test_predicate_elimination.cpp index 936cd95c22b..9d8995a381d 100644 --- a/tests/cpp/test_predicate_elimination.cpp +++ b/tests/cpp/test_predicate_elimination.cpp @@ -481,4 +481,40 @@ TEST_F(PredicateEliminationTest, ExtentEqualToMaxParallelTypeExtent) { testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); } +TEST_F(PredicateEliminationTest, FullSharedMem) { + auto fusion_ptr = std::make_unique<Fusion>(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({128}); + fusion.addInput(tv0); + + // Force launching 128 threads at minimum + auto tv1 = add(tv0, fusion.oneVal()); + fusion.addOutput(tv1); + + auto tv2 = zeros({IrBuilder::create<Val>(8)}, DataType::Float); + auto tv3 = add(tv2, fusion.oneVal()); + fusion.addOutput(tv3); + + tv2->setMemoryType(MemoryType::Shared); + + for (auto tv : fusion.allTvs()) { + tv->axis(0)->parallelize(ParallelType::TIDx); + } + + GpuLower gpulw(&fusion); + gpulw.run(); + // tv2 expectation: should be predicated even though there's no + // producer tensor + 
EXPECT_TRUE(PredicatedChecker::isPredicated(tv2, gpulw)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({128}, options); + KernelExecutor ke; + ke.compile(&fusion, {t0}); + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); +} + } // namespace nvfuser