Lowering scatter #4742
Merged
39 commits
7f634f5 WIP (naoyam)
c654486 cleanup (naoyam)
8ef8e3e cleanup (naoyam)
549e407 enable codegen of argsort+scatter (naoyam)
ea81050 Use IterDomain::merge instead of manually creating a Merge (naoyam)
b32ccf3 Merge branch 'main' into simplify_flatten (naoyam)
8003a94 Convert indexPutAccumulate to scatter when possible (naoyam)
21054f0 Merge remote-tracking branch 'origin/simplify_flatten' into scatter (naoyam)
2a5379c enable codegen of compute_problem_sizes (naoyam)
bc6020d remove old test (naoyam)
d2d127b scatter with shmem (naoyam)
6a8c74e cleanup (naoyam)
dce9246 Merge remote-tracking branch 'origin/main' into scatter (naoyam)
f725ad9 cleanup (naoyam)
96bb1d0 cleanup (naoyam)
0fc4aff cleanup (naoyam)
d8291fc fix (naoyam)
59d73b2 cleanup (naoyam)
09617ea fix (naoyam)
2fc5184 test fix (naoyam)
ee36099 Moved the change of the loop domain to a scheduling routine (naoyam)
fd2b83b bug fix (naoyam)
6cd1c3b cleanup (naoyam)
504b3fe Merge branch 'main' into scatter (naoyam)
708db3d Merge remote-tracking branch 'origin/main' into scatter (naoyam)
537bced WIP (naoyam)
4193f9b simplify (naoyam)
8f15f70 cleanup (naoyam)
02ae364 cleanup (naoyam)
fd18ff3 cleanup (naoyam)
b63a86f IdModel test (naoyam)
25f1b31 cleanup (naoyam)
fa9b895 cleanup (naoyam)
f08d4ed update (naoyam)
8823a4c Merge remote-tracking branch 'origin/main' into scatter (naoyam)
e293562 merge fix (naoyam)
16c359f cleanup (naoyam)
7d0cc96 Merge branch 'main' into scatter (naoyam)
650bf85 interface change (naoyam)
@@ -0,0 +1,137 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on

#include <device_lower/lower2device.h>
#include <device_lower/pass/inplace_alias.h>
#include <device_lower/utils.h>
#include <ir/builder.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>

namespace nvfuser {

namespace {

// Properties gathered from a given Kernel
struct InplaceAliasInfo {
  // Map to find the Allocate node for each tensor
  std::unordered_map<TensorView*, kir::Allocate*> alloc_map;
  // Disjoint sets to group aliased tensors
  DisjointSets<TensorView*> aliased_tvs;
  // Unique tensor for each alias group, representing its real allocation
  std::unordered_map<DisjointSets<TensorView*>::DisjointSet, TensorView*>
      real_alloc_map;
};

class InplaceAliasInfoBuilder : public kir::IrVisitor {
 public:
  const InplaceAliasInfo& info() {
    return info_;
  }

  using IrVisitor::handle;

  void handle(ScatterOp* sop) final {
    auto in_tv = sop->in()->as<TensorView>();
    auto out_tv = sop->out()->as<TensorView>();

    // Note that in_tv and out_tv are already validated to be safe to
    // alias each other by validateScatter

    NVF_ERROR(
        info_.alloc_map.find(in_tv) != info_.alloc_map.end(),
        "No allocation mapping found for scatter input: ",
        in_tv->toString());
    NVF_ERROR(
        info_.alloc_map.find(out_tv) != info_.alloc_map.end(),
        "No allocation mapping found for scatter output: ",
        out_tv->toString());

    auto in_tv_alias_it = info_.aliased_tvs.find(in_tv);
    if (in_tv_alias_it != info_.aliased_tvs.end()) {
      info_.aliased_tvs.appendToSet(out_tv, in_tv_alias_it->second);
    } else {
      info_.aliased_tvs.mapEntries(in_tv, out_tv);
      in_tv_alias_it = info_.aliased_tvs.find(in_tv);
      // Pick the input as the actual allocation of this tensor group
      info_.real_alloc_map.emplace(in_tv_alias_it->second, in_tv);
    }

    // If the output is also a fusion output, use it as the real
    // allocation.
    if (out_tv->isFusionOutput()) {
      info_.real_alloc_map[in_tv_alias_it->second] = out_tv;
    }
  }

  void handle(kir::Allocate* alloc) final {
    // Keep track of tensor allocations. Skip allocations that already
    // alias another allocation.
    if (auto alloc_tv = dynamic_cast<TensorView*>(alloc->buffer());
        alloc_tv != nullptr && alloc->alias() == nullptr) {
      NVF_ERROR(info_.alloc_map.emplace(alloc_tv, alloc).second);
    }
  }

 private:
  InplaceAliasInfo info_;
};

class InplaceAliasMutator : public kir::ExprMutator {
 public:
  InplaceAliasMutator(const InplaceAliasInfo& info) : info_(info) {}

 protected:
  using ExprMutator::handle;

  void handle(kir::Allocate* alloc) final {
    auto tv = dynamic_cast<TensorView*>(alloc->buffer());
    if (tv == nullptr) {
      // Ignore non-tensor allocation
      return;
    }

    auto alias_it = info_.aliased_tvs.find(tv);
    if (alias_it == info_.aliased_tvs.end()) {
      // Not aliased
      return;
    }

    auto real_alloc_tv = info_.real_alloc_map.at(alias_it->second);
    if (tv == real_alloc_tv) {
      // This tensor is the actual allocation
      return;
    }

    auto real_alloc = info_.alloc_map.at(real_alloc_tv);

    auto new_alloc = IrBuilder::create<kir::Allocate>(
        alloc->buffer(),
        alloc->memoryType(),
        alloc->shape(),
        alloc->zeroInit(),
        /*reset_to_zero=*/false,
        real_alloc);

    registerReplace(alloc, new_alloc);
  }

 private:
  const InplaceAliasInfo& info_;
};

} // namespace

std::vector<Expr*> setInplaceAlias(const std::vector<Expr*>& exprs) {
  InplaceAliasInfoBuilder builder;
  builder.handle(exprs);
  return InplaceAliasMutator(builder.info()).traverseAndInsert(exprs);
}

} // namespace nvfuser
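
Reviewer-style note: the two-phase structure of the file above (gather alias info, then rewrite allocations) can be tried out in isolation. The following is a minimal sketch, not the nvFuser implementation. `Tensor`, `Alloc`, `AliasInfo`, `gatherAliasInfo` and `applyAliases` are hypothetical stand-ins for `TensorView`, `kir::Allocate`, `InplaceAliasInfo` and the two visitor classes, and integer group ids replace `DisjointSets`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for TensorView.
struct Tensor {
  std::string name;
  bool is_fusion_output = false;
};

// Hypothetical stand-in for kir::Allocate.
struct Alloc {
  const Tensor* buffer = nullptr;
  const Alloc* alias = nullptr; // nullptr: owns its own storage
};

// Phase 1 result: alias group of each tensor and the tensor chosen as
// the real allocation of each group.
struct AliasInfo {
  std::unordered_map<const Tensor*, int> group_of;
  std::vector<const Tensor*> real_tensor; // indexed by group id
};

using ScatterPair = std::pair<const Tensor*, const Tensor*>; // (in, out)

// Phase 1: visit scatter ops in program order and group in/out tensors.
AliasInfo gatherAliasInfo(const std::vector<ScatterPair>& scatters) {
  AliasInfo info;
  for (const ScatterPair& s : scatters) {
    const Tensor* in = s.first;
    const Tensor* out = s.second;
    int group = 0;
    auto it = info.group_of.find(in);
    if (it != info.group_of.end()) {
      // Input already belongs to a group: the output joins it.
      group = it->second;
    } else {
      // New group: the input is the default real allocation.
      group = static_cast<int>(info.real_tensor.size());
      info.group_of.emplace(in, group);
      info.real_tensor.push_back(in);
    }
    info.group_of[out] = group;
    // A fusion output takes over as the real allocation.
    if (out->is_fusion_output) {
      info.real_tensor[group] = out;
    }
  }
  return info;
}

// Phase 2: point every non-representative allocation at the
// allocation of its group's real tensor.
void applyAliases(const AliasInfo& info, std::vector<Alloc>& allocs) {
  std::unordered_map<const Tensor*, const Alloc*> alloc_of;
  for (const Alloc& a : allocs) {
    alloc_of.emplace(a.buffer, &a);
  }
  for (Alloc& a : allocs) {
    auto it = info.group_of.find(a.buffer);
    if (it == info.group_of.end()) {
      continue; // not aliased
    }
    const Tensor* real = info.real_tensor[it->second];
    if (a.buffer != real) {
      a.alias = alloc_of.at(real);
    }
  }
}
```

With a chain of two scatters, t0 -> t1 -> t2, where only t2 is a fusion output, the allocations of t0 and t1 end up aliasing the allocation of t2, and an unrelated tensor is left untouched.
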
@@ -0,0 +1,32 @@
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

#include <vector>

namespace nvfuser {

class Expr;

// This lowering pass creates aliasing relationships between
// kir::Allocate nodes. Currently, this is only designed for the
// inplace scatter op.
//
// More specifically, it first gathers all the information by scanning
// the given Kernel, including the mapping of tensors to their
// kir::Allocate nodes and the grouping of tensors that alias each
// other. This information is then used to mutate the Kernel such that
// all tensors in the same alias group point to the same actual
// allocation. This is done by updating kir::Allocate's alias attribute.
//
// The selection of the actual allocation for each tensor group is
// done by just picking the first tensor, except when the group has a
// fusion output, in which case the output is selected.
std::vector<Expr*> setInplaceAlias(const std::vector<Expr*>& exprs);

} // namespace nvfuser
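
Reviewer-style note: the header comment says the pass targets the inplace scatter op. As a rough illustration of why letting the output share the input's storage preserves the result, here is a minimal sketch that assumes one-dimensional, PyTorch-style scatter semantics (`out[index[i]] = src[i]`). That semantics and the helper names `scatterCopy` and `scatterInplace` are assumptions made for illustration and are not taken from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Out-of-place scatter: the output gets its own storage, initialized
// from the input, and the selected positions are then overwritten.
std::vector<float> scatterCopy(
    const std::vector<float>& in,
    const std::vector<std::size_t>& index,
    const std::vector<float>& src) {
  assert(index.size() <= src.size());
  std::vector<float> out = in; // separate allocation plus a full copy
  for (std::size_t i = 0; i < index.size(); ++i) {
    out.at(index[i]) = src[i];
  }
  return out;
}

// In-place scatter: the output aliases the input buffer, so only the
// selected positions are written and no copy is needed.
void scatterInplace(
    std::vector<float>& buf,
    const std::vector<std::size_t>& index,
    const std::vector<float>& src) {
  assert(index.size() <= src.size());
  for (std::size_t i = 0; i < index.size(); ++i) {
    buf.at(index[i]) = src[i];
  }
}
```

Both variants produce the same values. The in-place form is only safe when nothing reads the original input contents afterwards, which is presumably the kind of condition the `validateScatter` check mentioned in the .cpp file is responsible for.
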