Merged
36 commits
7f634f5
WIP
naoyam Jul 7, 2025
c654486
cleanup
naoyam Jul 7, 2025
8ef8e3e
cleanup
naoyam Jul 7, 2025
549e407
enable codegen of argsort+scatter
naoyam Jul 7, 2025
ea81050
Use IterDomain::merge instead of manually creating a Merge
naoyam Jul 8, 2025
b32ccf3
Merge branch 'main' into simplify_flatten
naoyam Jul 8, 2025
8003a94
Convert indexPutAccumulate to scatter when possible
naoyam Jul 8, 2025
21054f0
Merge remote-tracking branch 'origin/simplify_flatten' into scatter
naoyam Jul 8, 2025
2a5379c
enable codegen of compute_problem_sizes
naoyam Jul 8, 2025
bc6020d
remove old test
naoyam Jul 8, 2025
d2d127b
scatter with shmem
naoyam Jul 8, 2025
6a8c74e
cleanup
naoyam Jul 9, 2025
dce9246
Merge remote-tracking branch 'origin/main' into scatter
naoyam Jul 9, 2025
f725ad9
cleanup
naoyam Jul 9, 2025
96bb1d0
cleanup
naoyam Jul 9, 2025
0fc4aff
cleanup
naoyam Jul 9, 2025
d8291fc
fix
naoyam Jul 9, 2025
59d73b2
cleanup
naoyam Jul 9, 2025
09617ea
fix
naoyam Jul 9, 2025
2fc5184
test fix
naoyam Jul 9, 2025
ee36099
Moved the change of the loop domain to a scheduling routine
naoyam Jul 9, 2025
fd2b83b
bug fix
naoyam Jul 10, 2025
6cd1c3b
cleanup
naoyam Jul 10, 2025
4180b03
format
naoyam Jul 9, 2025
0d9ed76
WIP
naoyam Jul 9, 2025
ff6c3f3
Merge branch 'main' into scatter-accumulate
naoyam Sep 2, 2025
b4bea4a
cleanup
naoyam Sep 3, 2025
2e382ac
cleanup
naoyam Sep 3, 2025
4d8ac68
cleanup
naoyam Sep 3, 2025
d3fc769
cleanup
naoyam Sep 3, 2025
47c9d4d
fix
naoyam Sep 3, 2025
90b1217
replayce indexPutAccumualte with scatter
naoyam Sep 3, 2025
2b2f524
format
naoyam Sep 3, 2025
bbcc433
Merge remote-tracking branch 'origin/main' into scatter-accumulate
naoyam Sep 3, 2025
8c2a9e1
test update
naoyam Sep 3, 2025
3411324
Merge branch 'main' into scatter-accumulate
naoyam Sep 5, 2025
74 changes: 67 additions & 7 deletions csrc/codegen.cpp
@@ -1303,13 +1303,73 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}

void handle(const ScatterOp* sop) final {
// generate code like T_output[... T_index[...]] = op(T_src[...]);
if (sop->getScatterOpType() == ScatterOpType::Set) {
// When values of index_tv are not unique, the behavior of Set is
// non-deterministic
indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
} else {
NVF_THROW("unkown scatterOp");
if (sop->accumulate()) {
handleScatterAccumulate(sop);
return;
}

// Generate code like T_output[... T_index[...]] = op(T_src[...]);
//
// When values of index_tv are not unique, the behavior of Set is
// non-deterministic
indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
}

// Atomic-based accumulation. Only supported with integer data or when
// non-determinism is explicitly permitted
void handleScatterAccumulate(const ScatterOp* sop) {
// Max/min accumulation is order-independent, so it stays deterministic
// even for floating-point data.
const bool non_deterministic = isFloatingPointType(sop->src()->dtype()) &&
(sop->accumulateOp() != BinaryOpType::Max &&
sop->accumulateOp() != BinaryOpType::Min);

NVF_ERROR(
!at::globalContext().deterministicAlgorithms() || !non_deterministic,
"Trying to use non-deterministic instructions even though "
"deterministic algorithm is requested: ",
sop->toString());

NVF_ERROR(
sop->src()->dtype() == DataType::Int ||
sop->src()->dtype() == DataType::Int32 ||
sop->src()->dtype() == DataType::Float ||
sop->src()->dtype() == DataType::Double,
"Data type not supported: ",
sop->src()->dtype());

const auto dst = gen(sop->out());
const auto src = gen(sop->src());

indent();

switch (sop->accumulateOp()) {
case BinaryOpType::Add:
if (sop->in()->dtype() == DataType::Int) {
// atomicAdd does not provide an overload for int64_t
Collaborator: 😮‍💨

Collaborator: Out of curiosity, do you happen to know why atomicAdd has an overload for uint64_t but not for int64_t, while the max/min versions do? The programming guide discusses only floating-point types... https://docs.nvidia.com/cuda/cuda-c-programming-guide/#atomicadd

Collaborator (author): No idea.

code_ << "atomicAdd("
<< "reinterpret_cast<unsigned long long*>(&" << dst << "), "
<< "static_cast<unsigned long long>(" << src << "));\n";
} else {
code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n";
}
break;
case BinaryOpType::Max:
// CUDA doesn't provide atomicMax for float. Could be
// implemented using atomicCAS
NVF_ERROR(
isIntegralType(sop->src()->dtype()),
"Floating point max accumulation not supported");
code_ << "atomicMax(" << "&" << dst << ", " << src << ");\n";
break;
case BinaryOpType::Min:
// CUDA doesn't provide atomicMin for float. Could be
// implemented using atomicCAS
NVF_ERROR(
isIntegralType(sop->src()->dtype()),
"Floating point min accumulation not supported");
code_ << "atomicMin(" << "&" << dst << ", " << src << ");\n";
break;
default:
NVF_THROW("Unsupported accumulation op: ", sop->accumulateOp());
}
}
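
For reference, the accumulation emitted for an Int (int64_t) destination with BinaryOpType::Add takes roughly the following shape. This is a hand-written sketch of the generated CUDA, not codegen output, and the helper name is hypothetical:

__device__ void scatterAddInt64(int64_t* dst, int64_t src) {
  // atomicAdd has no int64_t overload, so the destination is
  // reinterpreted as unsigned long long; two's-complement wraparound
  // makes the unsigned add bit-identical to the signed one.
  atomicAdd(
      reinterpret_cast<unsigned long long*>(dst),
      static_cast<unsigned long long>(src));
}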

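The "could be implemented using atomicCAS" notes above refer to the standard CAS loop. A minimal sketch of a float atomicMax built that way (hypothetical helper, not part of this PR; NaN handling left aside):

__device__ float atomicMaxFloat(float* addr, float val) {
  unsigned int* addr_u = reinterpret_cast<unsigned int*>(addr);
  unsigned int old = *addr_u;
  unsigned int assumed;
  do {
    assumed = old;
    if (__uint_as_float(assumed) >= val) {
      break;  // current value already >= val; nothing to do
    }
    old = atomicCAS(addr_u, assumed, __float_as_uint(val));
  } while (old != assumed);
  return __uint_as_float(old);
}

Note that max/min accumulation stays deterministic even for floats, since the final value is independent of the order in which updates land.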
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/index.cpp
@@ -367,12 +367,12 @@ void IndexLowering::handle(const ScatterOp* sop) {
auto lowered_out = lowerDstIndex(sop->out(), override_index);

pushBack(IrBuilder::create<ScatterOp>(
sop->getScatterOpType(),
/*out=*/lowered_out,
/*self=*/lowered_out,
sop->dim(),
lowered_index,
lowered_src));
lowered_src,
sop->accumulate() ? std::optional(sop->accumulateOp()) : std::nullopt));
GpuLower::current()->propagateExprInfo(sop, back());
}

13 changes: 9 additions & 4 deletions csrc/ir/internal_nodes.h
@@ -254,12 +254,12 @@ class ScatterOp : public Expr {
using Expr::Expr;
ScatterOp(
IrBuilderPasskey,
ScatterOpType type,
Val* out,
Val* self,
int64_t dim,
Val* index,
Val* src);
Val* src,
std::optional<BinaryOpType> accumulate_op = std::nullopt);

NVFUSER_DECLARE_CLONE_AND_CREATE

@@ -295,8 +295,13 @@

IterDomain* getIndexedID() const;

ScatterOpType getScatterOpType() const {
return attribute<ScatterOpType>(1);
bool accumulate() const {
return attribute<bool>(1);
}

BinaryOpType accumulateOp() const {
NVF_ERROR(accumulate());
return attribute<BinaryOpType>(2);
}
};

63 changes: 51 additions & 12 deletions csrc/ir/nodes.cpp
@@ -286,29 +286,36 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GatherOp)

ScatterOp::ScatterOp(
IrBuilderPasskey passkey,
ScatterOpType type,
Val* out,
Val* self,
int64_t dim,
Val* index,
Val* src)
Val* src,
std::optional<BinaryOpType> accumulate_op)
: Expr(passkey) {
addInput(self);
addInput(index);
addInput(src);
addOutput(out);
addDataAttribute(dim);
addDataAttribute(type);
// Attribute 1: whether this is an accumulating scatter
addDataAttribute(accumulate_op.has_value());
if (accumulate_op.has_value()) {
addDataAttribute(accumulate_op.value());
}
}

std::string ScatterOp::toString(int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << output(0)->toString() << "\n";
indent_size++;
indent(ss, indent_size) << " =" << getScatterOpType() << "(";
indent(ss, indent_size) << " = scatter(";
ss << "in = " << in()->toString() << ", dim = " << dim()
<< ", src = " << src()->toString() << ", idx = " << index()->toString()
<< " )\n";
<< ", src = " << src()->toString() << ", idx = " << index()->toString();
if (accumulate()) {
ss << ", accumulate = " << accumulateOp();
}
ss << " )\n";
return ss.str();
}

@@ -326,15 +333,47 @@ std::vector<PolymorphicValue> ScatterOp::evaluate(
const auto& input = inputs.at(0).as<at::Tensor>();
const auto& index = inputs.at(1).as<at::Tensor>();
auto dimension = dim();
if (src()->isA<TensorView>()) {
return {
at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
} else {
return {at::scatter(
if (accumulate()) {
std::string accumulate_op_str;
switch (accumulateOp()) {
case BinaryOpType::Add:
accumulate_op_str = "sum";
break;
case BinaryOpType::Mul:
accumulate_op_str = "prod";
break;
case BinaryOpType::Max:
accumulate_op_str = "amax";
break;
case BinaryOpType::Min:
accumulate_op_str = "amin";
break;
default:
NVF_THROW("Unsupported accumulation op: ", accumulateOp());
}
// at::scatter_reduce doesn't seem to support a scalar src. at::scatter
// does, but its reduce variant appears to be deprecated and only
// supports add and multiply accumulation.
NVF_ERROR(
src()->isA<TensorView>(),
"at::scatter_reduce does not support scalar src argument");
return {at::scatter_reduce(
input,
dimension,
index,
PolymorphicValue_functions::toScalar(inputs.back()))};
inputs.at(2).as<at::Tensor>(),
accumulate_op_str)};
} else {
if (src()->isA<TensorView>()) {
return {
at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
} else {
return {at::scatter(
input,
dimension,
index,
PolymorphicValue_functions::toScalar(inputs.at(2)))};
}
}
}
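
As a standalone cross-check of the reduce-string mapping above, a minimal ATen snippet (illustrative sketch; the function name is made up):

#include <ATen/ATen.h>

// BinaryOpType::Add -> "sum", Mul -> "prod", Max -> "amax", Min -> "amin".
at::Tensor scatterAccumulateReference(
    const at::Tensor& self,
    int64_t dim,
    const at::Tensor& index,
    const at::Tensor& src) {
  // include_self defaults to true in at::scatter_reduce, so the
  // destination's existing values take part in the accumulation, as in
  // the evaluate() call above.
  return at::scatter_reduce(self, dim, index, src, "sum");
}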

28 changes: 16 additions & 12 deletions csrc/ops/indexing.cpp
@@ -161,12 +161,12 @@ TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) {
return out_tensor->as<TensorView>();
}

TensorView* scatterOp(
ScatterOpType type,
TensorView* scatter(
TensorView* self,
int64_t dim,
TensorView* index,
Val* src) {
Val* src,
std::optional<BinaryOpType> accumulate_op) {
auto self_dom = TensorDomain::noReductions(self->getLogicalDomain());
auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain());

@@ -215,16 +215,20 @@
/*skip_loop_validation=*/true),
self->getDataType().value());

IrBuilder::create<ScatterOp>(type, out_tensor, self, dim, index, src);
return out_tensor->as<TensorView>();
}
if (accumulate_op.has_value()) {
NVF_ERROR(
accumulate_op.value() == BinaryOpType::Add ||
accumulate_op.value() == BinaryOpType::Mul ||
accumulate_op.value() == BinaryOpType::Max ||
accumulate_op.value() == BinaryOpType::Min,
"Unsupported accumulation op: ",
accumulate_op.value());
}

TensorView* scatter(
TensorView* self,
int64_t dim,
TensorView* index,
Val* src) {
return scatterOp(ScatterOpType::Set, self, dim, index, src);
IrBuilder::create<ScatterOp>(
out_tensor, self, dim, index, src, accumulate_op);

return out_tensor->as<TensorView>();
}
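
A minimal usage sketch of the extended API (tv_self, tv_index, and tv_src are hypothetical, previously defined TensorViews):

// Plain out-of-place scatter, as before:
TensorView* plain = scatter(tv_self, /*dim=*/0, tv_index, tv_src);

// Accumulating scatter; any op other than Add/Mul/Max/Min trips the
// NVF_ERROR validation above:
TensorView* accum =
    scatter(tv_self, /*dim=*/0, tv_index, tv_src, BinaryOpType::Add);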

TensorView* takeAlongAxis(TensorView* inp, TensorView* index, int64_t dim) {
11 changes: 2 additions & 9 deletions csrc/ops/indexing.h
@@ -32,14 +32,6 @@ NVF_API TensorView* indexPutAccumulate(
// torch.gather
NVF_API TensorView* gather(TensorView* input, int64_t dim, TensorView* index);

// TODO: Revisit the interface design. ScatterOpType could be just BinaryOpType
NVF_API TensorView* scatterOp(
ScatterOpType type,
TensorView* self,
int64_t dim,
TensorView* index,
Val* src);

// Provides torch.scatter. It is designed to represent the out-of-place
// scatter operation, i.e., the returned tensor, out_tv, is defined as
// follows:
@@ -72,7 +64,8 @@ NVF_API TensorView* scatter(
TensorView* self,
int64_t dim,
TensorView* index,
Val* src);
Val* src,
std::optional<BinaryOpType> accumulate_op = std::nullopt);
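
For a concrete reading of these semantics, a reference loop for dim = 0 on 2-D tensors (an illustrative sketch with std::vector standing in for tensors, not the implementation):

#include <cstdint>
#include <vector>

using Matrix = std::vector<std::vector<int64_t>>;

// out starts as a copy of self (out-of-place); each src element is
// routed along dim 0 by index. With an accumulate_op, the plain
// assignment becomes out[...] = op(out[...], src[i][j]).
Matrix scatterReference(Matrix self, const Matrix& index, const Matrix& src) {
  Matrix out = std::move(self);
  for (size_t i = 0; i < src.size(); ++i) {
    for (size_t j = 0; j < src[i].size(); ++j) {
      out[index[i][j]][j] = src[i][j];
    }
  }
  return out;
}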

//! numpy.take_along_axis
//! (https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)
7 changes: 0 additions & 7 deletions csrc/type.cpp
@@ -1455,13 +1455,6 @@ std::ostream& operator<<(std::ostream& out, const BinaryOpType botype) {
return out << binary_op_type2string(botype);
}

std::ostream& operator<<(std::ostream& out, const ScatterOpType sotype) {
if (sotype == ScatterOpType::Set) {
return out << "scatter";
}
NVF_THROW("No scatterOp type found for scatterOp.");
}

std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) {
return out << ternary_op_type2string(totype);
}
3 changes: 0 additions & 3 deletions csrc/type.h
@@ -650,8 +650,6 @@ enum class BinaryOpType {
Complex
};

enum class ScatterOpType { Set };

enum class RNGOpType {
Uniform, // Uniform in [0, 1)
UniformRange, // Uniform in [low, high]
@@ -1006,7 +1004,6 @@ NVF_API std::ostream& operator<<(std::ostream&, const DataType);
std::ostream& operator<<(std::ostream&, const UnaryOpType);
NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
std::ostream& operator<<(std::ostream&, const TernaryOpType);
std::ostream& operator<<(std::ostream&, const ScatterOpType);
std::ostream& operator<<(std::ostream&, const RNGOpType);
NVF_API std::ostream& operator<<(std::ostream&, const ParallelType);
NVF_API std::ostream& operator<<(std::ostream&, const MemoryType);
35 changes: 27 additions & 8 deletions tests/cpp/test_moe.cpp
@@ -59,10 +59,6 @@ class SgLangMoETest : public NVFuserFixtureParamTest<MoEConfig> {
};

TEST_P(SgLangMoETest, ComputeProblemSizes) {
if (manual_scheduling) {
GTEST_SKIP() << "No manual scheduling implemented";
}

auto fusion_ptr = std::make_unique<Fusion>();
Fusion& fusion = *fusion_ptr.get();
FusionGuard fg(&fusion);
@@ -78,16 +74,39 @@

auto tv3 = ones({IrBuilder::create<Val>(num_tokens * topk)}, DataType::Int);

auto tv4 = indexPutAccumulate(tv2, tv1, tv3);
auto tv4 = scatter(tv2, 0, tv1, tv3, BinaryOpType::Add);

fusion.addOutput(tv4);

auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
auto t0 = at::randint(0, num_experts, {num_tokens, topk}, options);

FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0});
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
if (manual_scheduling) {
auto tv4_cache = tv4->cacheBefore();

// Scheduling all tensors as 1D tensors
for (auto tv : fusion.allTvs()) {
tv->flatten();
tv->axis(0)->parallelize(ParallelType::TIDx);
}

tv2->setMemoryType(MemoryType::Shared);
tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
tv4_cache->setMemoryType(MemoryType::Shared);
Collaborator: In this scheduling, we do an atomic write to shared memory and then write to global memory afterwards. For my own curiosity: I think we could also skip cacheBefore on tv4 and rely on atomicAdd directly to global memory, and that should still work?!

Collaborator (author): Yes, that should work too.

tv4_cache->setAllocationDomain(tv4_cache->getLogicalDomain(), true);

KernelExecutor ke;
ke.compile(&fusion, {t0});

GTEST_SKIP() << "Missing predication. Fix pending: "
"https://github.com/NVIDIA/Fuser/pull/5107";

Collaborator (author): I'll update this part once #5107 is done.

auto outputs = ke.run({t0});
testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
} else {
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs({t0});
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
}
}
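
For completeness, the variant discussed in the review thread above (skipping cacheBefore and letting the atomics land in global memory) would reduce the manual schedule to roughly the following untested sketch, which keeps the same fixture names and makes no claim about being the preferred schedule:

// Without tv4->cacheBefore(), the ScatterOp's atomicAdd targets tv4's
// global-memory buffer directly instead of a shared-memory staging
// buffer.
for (auto tv : fusion.allTvs()) {
  tv->flatten();
  tv->axis(0)->parallelize(ParallelType::TIDx);
}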

TEST_P(SgLangMoETest, ComputeExpertOffsets) {
Expand Down