Fix predicate elimination with shared memory tensors #5107
@@ -28,9 +28,11 @@ namespace {
 // tensor memory as special type of shared memory. In this file, we use
 // the term "shared memory", "smem" to refer to both shared and tensor
 // memories.
-bool isSharedMemory(TensorView* tv) {
-  return tv->getMemoryType() == MemoryType::Shared ||
-      tv->getMemoryType() == MemoryType::Tensor;
+bool isSharedMemoryTensor(Val* val) {
+  auto tv = dynamic_cast<TensorView*>(val);
+  return tv != nullptr &&
+      (tv->getMemoryType() == MemoryType::Shared ||
+       tv->getMemoryType() == MemoryType::Tensor);
 }
 
 // Warp primitives are currently limited to un-predicated usage,
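A side note, not part of the diff: one plausible reason for taking a Val* instead of a TensorView* is that the helper can then be passed directly as a predicate over an expression's inputs, which the hunks below do with std::ranges::any_of; non-tensor inputs such as scalars simply yield false. The wrapper name anySharedMemoryInput below is hypothetical and only illustrates that composition.

// Hypothetical helper: shows how the Val*-based predicate composes with a
// range of Val* inputs without any per-element cast to TensorView*.
bool anySharedMemoryInput(const Expr* expr) {
  return std::ranges::any_of(expr->inputs(), isSharedMemoryTensor);
}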
@@ -133,21 +135,73 @@ bool isExactParallelSharedMemAccess(TensorView* tv) {
   return true;
 }
 
-// Check for conditions where the predicate cannot be removed
-// when either producer or consumer is in shared memory.
-bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) {
-  // Indexing is based on consumer loop ids so check the consumer.
+bool needsPredicateSharedMemAccess(const Expr* expr) {
+  DEBUG_PRINT_SCOPE(expr);
 
-  // If consumer schedule contains in-exact thread parallel
-  // dimensions, need to predicate against out of bound
-  // shared memory access by out of bound threads.
-  if (!isExactParallelSharedMemAccess(consumer)) {
-    return true;
+  // This is initial step to gradually remove predicates around
+  // sharedmem access in suitable situations.
+  // Using an additional variable to track the predicate-on reasons
+  // when the predicate around shared mem cannot be removed.
+  for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
+    // If consumer schedule contains in-exact thread parallel
+    // dimensions, need to predicate against out of bound
+    // shared memory access by out of bound threads.
+    //
+    // TODO: This condition should not need
[Review comment] Here, I thought we shouldn't need to look at producers, but that resulted in some unexpected changes, including missed shared memory aliasing.
+    // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
+    // it's there to keep the existing behavior as of PR #5107.
+    // Specifically, if not used,
+    // HopperMatmulTest.EpilogueBiasPersistentBroadcastInputs fails
+    // due to insufficient shared memory, which happens because T12 is
+    // not aliased with T9 if the predicate of T6 is removed. This
+    // seems to rely on an unintended effect. In this particular case,
+    // T6 is an output of a TMA store, so its input, T9, should not
+    // need to be initialized, and thus I think we could remove the
+    // predicate but still allow aliasing.
[Review comment on lines +153 to +160] I wonder why the predicate for T6 interferes with aliasing of T9 and T12. Is it because it changes the loop structure? I would've thought the predicate would just be on the TMA store, i.e. at the innermost scope, so that we'd still be able to do an alias.

[Reply] When eliminating a predicate of an expression, we also make sure the input tensors are at least initialized. This may not be necessary, since the result of non-predicated execution, even one that uses uninitialized values, should not affect anything in the final result on global memory, but I'm not sure whether it's considered legal to, for example, read an uninitialized shared memory value. So, suppose we keep initializing them; then this could interfere with the aliasing logic. In this particular case, it looks like what happened was that when the predicate of T6 was removed, an initialization of T9 was inserted that overlaps with the lifetime of T12, which prevented the aliasing between T12 and T9. It seems that's also described here as well; looks like that part was added in #3098. In this particular case, since the expression of T6 is a TMA store, we should be able to omit the initialization of T9 safely because the hardware takes care of predicating out-of-bounds accesses, which should allow omitting the predicate while still making sure T9 and T12 are aliased. But in general, I feel that predicate elimination should be aware of aliasing, at least for shared memory, because the capacity of shared memory is likely a more important factor than omitting predicates.
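To make the mechanism in the reply above concrete, here is a minimal CUDA-style sketch. It is not nvFuser-generated code: the kernels, sizes, and predicate are hypothetical, and only the tensor names T9, T12, and T6 follow the discussion (t6_gmem is assumed to have blockDim.x elements; in the real case a TMA store masks out-of-bounds lanes). The point is that dropping the predicate forces an unconditional initialization of the shared-memory input, its live range grows, and the aliasing pass can no longer let T12 reuse T9's buffer.

// Predicate kept: T9 is written and consumed only under `pred`, so its live
// range ends before T12 is produced and both can share one smem buffer.
__global__ void predicateKept(float* t6_gmem, int extent) {
  extern __shared__ float smem[];
  float* T9 = smem;   // T9 and T12 alias the same storage
  float* T12 = smem;
  bool pred = threadIdx.x < extent;
  if (pred) {
    T9[threadIdx.x] = 1.0f;                  // produce T9
    t6_gmem[threadIdx.x] = T9[threadIdx.x];  // "T6" store consumes T9
  }
  __syncthreads();
  if (pred) {
    T12[threadIdx.x] = 2.0f;                 // T12 reuses T9's storage
  }
}

// Predicate removed: lowering inserts an init of T9 so the unpredicated
// consumer never reads garbage; T9 now stays live across the write of T12,
// so the two get separate buffers and shared memory usage grows.
__global__ void predicateRemoved(float* t6_gmem, int extent) {
  extern __shared__ float smem[];
  float* T9 = smem;
  float* T12 = smem + blockDim.x;            // no longer aliased with T9
  T9[threadIdx.x] = 0.0f;                    // inserted initialization
  if (threadIdx.x < extent) {
    T9[threadIdx.x] = 1.0f;
  }
  t6_gmem[threadIdx.x] = T9[threadIdx.x];    // unpredicated "T6" store
  __syncthreads();
  T12[threadIdx.x] = 2.0f;
}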
+    if (isSharedMemoryTensor(consumer) ||
+        std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
+      if (!isExactParallelSharedMemAccess(consumer)) {
+        RECORD_AND_RETURN(true);
+      }
+    }
+
+    // TODO: This condition should not need
[Review comment] Similarly here, this condition is to keep the current behavior for now.
+    // std::ranges::any_of(expr->inputs(), isSharedMemoryTensor), but
+    // it's there to keep the existing behavior as of PR #5107.
+    if (isSharedMemoryTensor(consumer) ||
+        std::ranges::any_of(expr->inputs(), isSharedMemoryTensor)) {
+      for (auto id : consumer->getLoopDomain()) {
+        // TODO: (Enable in a follow up)
+        // smem predicate removal with init would break unroll and unswitch,
+        // eg. as in issue 1133, so disabling this removal pattern for now.
+        if (id->getParallelType() == ParallelType::Unroll ||
+            id->getParallelType() == ParallelType::Unswitch) {
+          RECORD_AND_RETURN(true);
+        }
+      }
+    }
+
+    // TODO: (Enable in a follow up)
+    // This cannot yet be removed since smem initialization needs to be
+    // handled specially, e.g. as in smem_reduce test. Will be able to
+    // lift this one once the generic pred removal pass with fusion
+    // traversal is ready.
+    if (ir_utils::isReductionOp(consumer->definition()) &&
+        std::ranges::any_of(expr->inputs(), [](auto val) {
+          return isSharedMemoryTensor(val);
+        })) {
+      RECORD_AND_RETURN(true);
+    }
   }
 
+  // Disable shared memory producers that is a consumer
+  // of another shared memory tensor. The initialization would
+  // break potential opportunity to re-use shared mem buffer.
+  //
   // TODO: This is directed WAR on FusionPersistentNormLocalShared.
-  // This use case along with other previous issues motivate a
-  // joint optimization of predicate removal and buffer reuse.
+  //
+  // This use case along with other previous issues motivate a
+  // joint optimization of predicate removal and buffer reuse.
   // In this particular case:
   // __shared__ T0 [10], T1[10]
   // for i in ...
@@ -156,77 +210,28 @@ bool needSharedMemPredicate(TensorView* producer, TensorView* consumer) {
   //    T2 = 0;              // init for exp1
   //    if(pred)
   //      T2 = T1 ...          // exp1
-  // If we remove pred around expr1, as the way the pred removal
-  // pass is set up, the init for expr will be pushed up to
-  // initialize T1 instead.
-  // However if we initialize T1, the code will look like:
+  // If we remove pred around expr1, as the way the pred removal
+  // pass is set up, the init for expr will be pushed up to
+  // initialize T1 instead.
+  // However if we initialize T1, the code will look like:
   // for i in ...
   //   T1[i] = 0;
   // for i in ...
   //   if(pred)
   //     T1[i] = T0[i] + ...
-  // Note that we'd be able to reuse buffer of T0 for T1 but
-  // if we initialze T1 we cannot do that and thus the
-  // kernel would not fit in smaller devices.
-  if (producer->getMemoryType() == MemoryType::Shared) {
-    if (auto producer_def = producer->definition()) {
-      if (std::any_of(
-              producer_def->inputs().begin(),
-              producer_def->inputs().end(),
-              [](Val* val) {
-                if (auto tv = ir_utils::getTv(val)) {
-                  return tv->getMemoryType() == MemoryType::Shared;
-                }
-                return false;
-              })) {
-        // Disable shared memory producers that is a consumer
-        // of another shared memory tensor. The initialization would
-        // break potential opportunity to re-use shared mem buffer.
-        return true;
-      }
+  // Note that we'd be able to reuse buffer of T0 for T1 but
+  // if we initialize T1 we cannot do that and thus the
+  // kernel would not fit in smaller devices.
+  for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
+    if (auto producer_def = producer->definition(); producer_def != nullptr &&
+        isSharedMemoryTensor(producer) &&
+        std::any_of(producer_def->inputs().begin(),
+                    producer_def->inputs().end(),
+                    [](Val* val) { return isSharedMemoryTensor(val); })) {
+      RECORD_AND_RETURN(true);
     }
   }
 
-  for (auto id : consumer->getLoopDomain()) {
-    // TODO: (Enable in a follow up)
-    // smem predicate removal with init would break unroll and unswitch,
-    // eg. as in issue 1133, so disabling this removal pattern for now.
-    if (id->getParallelType() == ParallelType::Unroll ||
-        id->getParallelType() == ParallelType::Unswitch) {
-      return true;
-    }
-  }
-
-  // TODO: (Enable in a follow up)
-  // This cannot yet be removed since smem initialization needs to be
-  // handled specially, e.g. as in smem_reduce test. Will be able to
-  // lift this one once the generic pred removal pass with fusion
-  // traversal is ready.
-  auto consumer_def = consumer->definition();
-  if (ir_utils::isReductionOp(consumer_def)) {
-    if (producer->getMemoryType() == MemoryType::Shared) {
-      return true;
-    }
-  }
-
-  return false;
-}
-
-bool needsPredicateSharedMemAccess(const Expr* expr) {
-  DEBUG_PRINT_SCOPE(expr);
-  // This is initial step to gradually remove predicates around
-  // sharedmem access in suitable situations.
-  // Using an additional variable to track the predicate-on reasons
-  // when the predicate around shared mem cannot be removed.
-  for (auto consumer : ir_utils::filterByType<TensorView>(expr->outputs())) {
-    for (auto producer : ir_utils::filterByType<TensorView>(expr->inputs())) {
-      if (isSharedMemory(producer) || isSharedMemory(consumer)) {
-        if (needSharedMemPredicate(producer, consumer)) {
[Review comment] This check is not done when there's no producer.
-          RECORD_AND_RETURN(true);
-        }
-      }
-    }
-  }
   RECORD_AND_RETURN(false);
 }
 
[Review comment] Changes here are just making sure all checks are performed even when there's no producer. It turned out this resulted in a lot of code changes, which I don't feel like going through right now. For now, I tried to keep the current behavior as is, even though in some cases it doesn't seem to make sense, as long as correctness is preserved.
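As a reading aid, here is a standalone sketch of the control-flow point made above, using dummy types rather than nvFuser's IR: in the old structure the shared-memory checks sat inside a loop over producer tensors, so an expression with no TensorView inputs skipped them entirely, while the restructured function runs the per-consumer checks regardless of whether a producer exists. Names and types here are illustrative only.

#include <vector>

// Dummy stand-in for a tensor view; only records whether it is in shared memory.
struct Tensor {
  bool smem = false;
};

// Old shape of the logic: the checks are only reachable from inside the
// producer loop, so an empty producer list means no check ever runs.
bool needsPredicateOld(const std::vector<Tensor*>& producers,
                       const std::vector<Tensor*>& consumers) {
  for (Tensor* consumer : consumers) {
    for (Tensor* producer : producers) {
      if (producer->smem || consumer->smem) {
        return true;  // stand-in for the detailed shared-memory checks
      }
    }
  }
  return false;
}

// New shape of the logic: the consumer-side checks run even when the
// expression has no tensor producers.
bool needsPredicateNew(const std::vector<Tensor*>& producers,
                       const std::vector<Tensor*>& consumers) {
  for (Tensor* consumer : consumers) {
    bool any_smem_input = false;
    for (Tensor* producer : producers) {
      any_smem_input = any_smem_input || producer->smem;
    }
    if (consumer->smem || any_smem_input) {
      return true;  // stand-in for the detailed shared-memory checks
    }
  }
  return false;
}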