Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7f634f5
WIP
naoyam Jul 7, 2025
c654486
cleanup
naoyam Jul 7, 2025
8ef8e3e
cleanup
naoyam Jul 7, 2025
549e407
enable codegen of argsort+scatter
naoyam Jul 7, 2025
ea81050
Use IterDomain::merge instead of manually creating a Merge
naoyam Jul 8, 2025
b32ccf3
Merge branch 'main' into simplify_flatten
naoyam Jul 8, 2025
8003a94
Convert indexPutAccumulate to scatter when possible
naoyam Jul 8, 2025
21054f0
Merge remote-tracking branch 'origin/simplify_flatten' into scatter
naoyam Jul 8, 2025
2a5379c
enable codegen of compute_problem_sizes
naoyam Jul 8, 2025
bc6020d
remove old test
naoyam Jul 8, 2025
d2d127b
scatter with shmem
naoyam Jul 8, 2025
6a8c74e
cleanup
naoyam Jul 9, 2025
dce9246
Merge remote-tracking branch 'origin/main' into scatter
naoyam Jul 9, 2025
f725ad9
cleanup
naoyam Jul 9, 2025
96bb1d0
cleanup
naoyam Jul 9, 2025
0fc4aff
cleanup
naoyam Jul 9, 2025
d8291fc
fix
naoyam Jul 9, 2025
59d73b2
cleanup
naoyam Jul 9, 2025
09617ea
fix
naoyam Jul 9, 2025
2fc5184
test fix
naoyam Jul 9, 2025
ee36099
Moved the change of the loop domain to a scheduling routine
naoyam Jul 9, 2025
fd2b83b
bug fix
naoyam Jul 10, 2025
6cd1c3b
cleanup
naoyam Jul 10, 2025
504b3fe
Merge branch 'main' into scatter
naoyam Jul 25, 2025
708db3d
Merge remote-tracking branch 'origin/main' into scatter
naoyam Jul 28, 2025
537bced
WIP
naoyam Jul 28, 2025
4193f9b
simplify
naoyam Jul 29, 2025
8f15f70
cleanup
naoyam Jul 29, 2025
02ae364
cleanup
naoyam Jul 29, 2025
fd18ff3
cleanup
naoyam Jul 29, 2025
b63a86f
IdModel test
naoyam Jul 29, 2025
25f1b31
cleanup
naoyam Jul 30, 2025
fa9b895
cleanup
naoyam Jul 30, 2025
f08d4ed
update
naoyam Jul 30, 2025
8823a4c
Merge remote-tracking branch 'origin/main' into scatter
naoyam Jul 30, 2025
e293562
merge fix
naoyam Jul 30, 2025
16c359f
cleanup
naoyam Aug 7, 2025
7d0cc96
Merge branch 'main' into scatter
naoyam Aug 7, 2025
650bf85
interface change
naoyam Aug 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/device_lower/pass/grid_serialization.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/index.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/inline_ptx.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/inplace_alias.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/insert_syncs.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/instrument.cpp
${NVFUSER_SRCS_DIR}/device_lower/pass/loop_rotation.cpp
Expand Down
7 changes: 6 additions & 1 deletion csrc/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
if (sop->getScatterOpType() == ScatterOpType::Set) {
// When value of index_tv are not unique, the behavior of Set is
// non-deterministic
indent() << gen(sop->output(0)) << " = " << gen(sop->input(2)) << ";\n";
indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
} else {
NVF_THROW("unkown scatterOp");
}
Expand Down Expand Up @@ -3674,6 +3674,11 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
const auto buffer_dtype = alloc->buffer()->dtype();

NVF_ERROR(alloc->buffer() != nullptr);

if (alloc->buffer()->isFusionOutput()) {
return;
}

alloc_set_.emplace(alloc->buffer());

if (!alloc->buffer()->isA<TensorView>()) {
Expand Down
5 changes: 5 additions & 0 deletions csrc/device_lower/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <device_lower/pass/grid_serialization.h>
#include <device_lower/pass/index.h>
#include <device_lower/pass/inline_ptx.h>
#include <device_lower/pass/inplace_alias.h>
#include <device_lower/pass/insert_syncs.h>
#include <device_lower/pass/instrument.h>
#include <device_lower/pass/loop_rotation.h>
Expand Down Expand Up @@ -270,6 +271,7 @@ GpuLower::GpuLower(Fusion* fusion, const CompileParams& cparams)
{"loadStoreOpInserter", loadStoreOpInserter},
{"insertGridSerializationSyncs", insertGridSerializationSyncs},
{"insertAllocations", insertAllocations},
{"setInplaceAlias", setInplaceAlias},
{"reuseMemoryAllocations", reuseMemoryAllocations},
{"CircularBufferPass", CircularBufferPass::run},
{"insertRawThreadSynchronization", insertRawThreadSynchronization},
Expand Down Expand Up @@ -593,6 +595,9 @@ void GpuLower::analysis(Fusion* fusion) {
validateLookupTV(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "validateLookupTV");

validateScatter(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "validateScatter");

// Find trivial global to global broadcast, squeeze, and set operations and
// mark their outputs as aliases of their inputs.
findTensorProducerAliases(fusion_);
Expand Down
24 changes: 11 additions & 13 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ class AllocationDomainSetup : private kir::IrVisitor {
})) {
return true;
}

// If a shared memory output produced by scatter has an
// allocation domain explicitly set, it's likely to be the
// valid allocation domain.
if (auto def = tv->definition();
def != nullptr && def->isA<ScatterOp>()) {
return true;
}
}
return false;
}
Expand Down Expand Up @@ -1172,10 +1180,8 @@ class AllocationInserter : public kir::ExprMutator {
return init_expr;
}

kir::Allocate* createAllocExpr(AllocationInformation& info, bool is_output) {
if (is_output) {
return nullptr;
}
kir::Allocate* createAllocExpr(AllocationInformation& info) {
// Note that Allocate nodes are created for fusion outputs too

TensorView* tv_to_alloc = info.buffer;
const MemoryType memory_type = tv_to_alloc->getMemoryType();
Expand Down Expand Up @@ -1343,19 +1349,11 @@ class AllocationInserter : public kir::ExprMutator {
init = nullptr;
}

const bool is_output = out->isFusionOutput();

// Don't need to alloc outputs, and if we don't need to initialize we're
// done.
if (is_output && init == nullptr) {
continue;
}

AllocationInformation allocation;
allocation.buffer = out_tv;
fillAllocationInformation(allocation, expr);

auto alloc_expr = createAllocExpr(allocation, is_output);
auto alloc_expr = createAllocExpr(allocation);
auto init_expr = createInitExpr(allocation, init);

// Check that all circular buffer depth match
Expand Down
19 changes: 10 additions & 9 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Val* IndexLowering::lowerSrcIndex(

Val* IndexLowering::lowerDstIndex(
Val* dst,
const std::unordered_map<int, Val*>& override_index,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool generate_pointer,
DataType as_type) const {
if (auto tv = dynamic_cast<TensorView*>(dst)) {
Expand Down Expand Up @@ -356,19 +356,20 @@ void IndexLowering::handle(const GatherOp* top) {
}

void IndexLowering::handle(const ScatterOp* sop) {
auto lowered_index = lowerSrcIndex(sop->indexTv(), sop->output(0));
auto lowered_src = lowerSrcIndex(sop->srcTv(), sop->output(0));
// At this point, out and self are aliased, so they can be used
// interchangeably.

lowered_index = IrBuilder::maybeCastExpr(DataType::Index, lowered_index);
auto lowered_index = lowerSrcIndex(sop->index(), sop->out());
auto lowered_src = lowerSrcIndex(sop->src(), sop->out());

const std::unordered_map<int, Val*> override_index_out = {
{sop->dim(), lowered_index}};
auto lowered_out = lowerDstIndex(sop->output(0), override_index_out);
const std::unordered_map<IterDomain*, Val*> override_index = {
{sop->getIndexedID(), lowered_index}};
auto lowered_out = lowerDstIndex(sop->out(), override_index);

pushBack(IrBuilder::create<ScatterOp>(
sop->getScatterOpType(),
lowered_out,
sop->selfTv(),
/*out=*/lowered_out,
/*self=*/lowered_out,
sop->dim(),
lowered_index,
lowered_src));
Expand Down
2 changes: 1 addition & 1 deletion csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class IndexLowering : private OptOutConstDispatch {

Val* lowerDstIndex(
Val* dst,
const std::unordered_map<int, Val*>& override_index = {},
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool generate_pointer = false,
DataType as_type = DataType::Null) const;

Expand Down
137 changes: 137 additions & 0 deletions csrc/device_lower/pass/inplace_alias.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <device_lower/lower2device.h>
#include <device_lower/pass/inplace_alias.h>
#include <device_lower/utils.h>
#include <ir/builder.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>

namespace nvfuser {

namespace {

// Properties gathered from a given Kernel. Populated by
// InplaceAliasInfoBuilder and consumed read-only by InplaceAliasMutator.
struct InplaceAliasInfo {
  // Map to find the Allocate node for each tensor
  std::unordered_map<TensorView*, kir::Allocate*> alloc_map;
  // Disjoint sets to group aliased tensors
  DisjointSets<TensorView*> aliased_tvs;
  // Unique tensor for each tensor alias group representing its real
  // allocation
  std::unordered_map<DisjointSets<TensorView*>::DisjointSet, TensorView*>
      real_alloc_map;
};

// Scans a Kernel and collects the aliasing information needed to make
// scatter inputs and outputs share a single allocation.
class InplaceAliasInfoBuilder : public kir::IrVisitor {
 public:
  // Read-only accessor for the gathered info. The returned reference
  // remains valid only while this builder is alive.
  const InplaceAliasInfo& info() const {
    return info_;
  }

  using IrVisitor::handle;

  // Groups the input and output of each scatter op into the same alias
  // set and picks the tensor that should own the real allocation.
  void handle(ScatterOp* sop) final {
    auto in_tv = sop->in()->as<TensorView>();
    auto out_tv = sop->out()->as<TensorView>();

    // Note that in_tv and out_tv are already validated to be safe to
    // alias each other by validateScatter

    NVF_ERROR(
        info_.alloc_map.find(in_tv) != info_.alloc_map.end(),
        "No allocation mapping found for scatter input: ",
        in_tv->toString());
    NVF_ERROR(
        info_.alloc_map.find(out_tv) != info_.alloc_map.end(),
        "No allocation mapping found for scatter output: ",
        out_tv->toString());

    auto in_tv_alias_it = info_.aliased_tvs.find(in_tv);
    if (in_tv_alias_it != info_.aliased_tvs.end()) {
      // The input already belongs to an alias group; just extend it.
      info_.aliased_tvs.appendToSet(out_tv, in_tv_alias_it->second);
    } else {
      info_.aliased_tvs.mapEntries(in_tv, out_tv);
      in_tv_alias_it = info_.aliased_tvs.find(in_tv);
      // Pick the input as the actual allocation of this tensor group
      info_.real_alloc_map.emplace(in_tv_alias_it->second, in_tv);
    }

    // If the output is also a fusion output, use it as the real
    // allocation.
    if (out_tv->isFusionOutput()) {
      info_.real_alloc_map[in_tv_alias_it->second] = out_tv;
    }
  }

  // Keeps track of tensor allocations. Allocations that already alias
  // another buffer are ignored.
  void handle(kir::Allocate* alloc) final {
    if (auto alloc_tv = dynamic_cast<TensorView*>(alloc->buffer());
        alloc_tv != nullptr && alloc->alias() == nullptr) {
      // Perform the insertion outside of the assertion macro so the
      // side effect never depends on how NVF_ERROR is compiled.
      const bool inserted = info_.alloc_map.emplace(alloc_tv, alloc).second;
      NVF_ERROR(
          inserted,
          "Duplicate allocation found for: ",
          alloc_tv->toString());
    }
  }

 private:
  InplaceAliasInfo info_;
};

// Rewrites the kir::Allocate node of each aliased tensor so that it
// points to the allocation of its group's representative tensor.
class InplaceAliasMutator : public kir::ExprMutator {
 public:
  // Explicit to avoid accidental implicit conversion from
  // InplaceAliasInfo. The info must outlive this mutator since only a
  // reference is kept.
  explicit InplaceAliasMutator(const InplaceAliasInfo& info) : info_(info) {}

 protected:
  using ExprMutator::handle;

  void handle(kir::Allocate* alloc) final {
    auto tv = dynamic_cast<TensorView*>(alloc->buffer());
    if (tv == nullptr) {
      // Ignore non-tensor allocation
      return;
    }

    auto alias_it = info_.aliased_tvs.find(tv);
    if (alias_it == info_.aliased_tvs.end()) {
      // Not aliased
      return;
    }

    auto real_alloc_tv = info_.real_alloc_map.at(alias_it->second);
    if (tv == real_alloc_tv) {
      // This tensor is the actual allocation
      return;
    }

    auto real_alloc = info_.alloc_map.at(real_alloc_tv);

    // Recreate the Allocate with its alias attribute pointing at the
    // group's real allocation.
    auto new_alloc = IrBuilder::create<kir::Allocate>(
        alloc->buffer(),
        alloc->memoryType(),
        alloc->shape(),
        alloc->zeroInit(),
        /*reset_to_zero=*/false,
        real_alloc);

    registerReplace(alloc, new_alloc);
  }

 private:
  const InplaceAliasInfo& info_;
};

} // namespace

// Entry point of the inplace-alias lowering pass: collect alias info
// from the kernel expressions, then rewrite the affected Allocate
// nodes.
std::vector<Expr*> setInplaceAlias(const std::vector<Expr*>& exprs) {
  // Pass 1: gather allocations and scatter-induced alias groups.
  InplaceAliasInfoBuilder info_builder;
  info_builder.handle(exprs);
  // Pass 2: redirect aliased allocations to their group representative.
  InplaceAliasMutator mutator(info_builder.info());
  return mutator.traverseAndInsert(exprs);
}

} // namespace nvfuser
32 changes: 32 additions & 0 deletions csrc/device_lower/pass/inplace_alias.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

#include <vector>

namespace nvfuser {

class Expr;

// This lowering pass creates aliasing relationships between
// kir::Allocate nodes. Currently, this is only designed for the
// inplace scatter op.
//
// More specifically, it first gathers all the information by scanning
// the given Kernel, including mappings of tensors to their kir::Allocate
// nodes and a grouping of tensors that alias each other. That
// information is then used to mutate the Kernel such that all tensors
// in the same alias group point to the same actual allocation. This is
// done by updating kir::Allocate's alias attribute.
//
// The selection of the actual allocation for each tensor group is
// done by just picking the first tensor, except when the group has a
// fusion output, in which case the output is selected.
std::vector<Expr*> setInplaceAlias(const std::vector<Expr*>& exprs);

} // namespace nvfuser
Loading