Skip to content

Commit 6048003

Browse files
authored
Lowering scatter accumulate (#4764)
Adds an optional accumulate parameter to `ScatterOp` so that it can be used both with and without accumulation. I'll look into consolidating `IndexPutAccumulateOp` as well in the future.
1 parent 84c46fe commit 6048003

File tree

10 files changed

+263
-64
lines changed

10 files changed

+263
-64
lines changed

csrc/codegen.cpp

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,13 +1303,73 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
13031303
}
13041304

13051305
void handle(const ScatterOp* sop) final {
1306-
// generate code like T_output[... T_index[...]] = op(T_src[...]);
1307-
if (sop->getScatterOpType() == ScatterOpType::Set) {
1308-
// When value of index_tv are not unique, the behavior of Set is
1309-
// non-deterministic
1310-
indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
1311-
} else {
1312-
NVF_THROW("unkown scatterOp");
1306+
if (sop->accumulate()) {
1307+
handleScatterAccumulate(sop);
1308+
return;
1309+
}
1310+
1311+
// Generate code like T_output[... T_index[...]] = op(T_src[...]);
1312+
//
1313+
// When value of index_tv are not unique, the behavior of Set is
1314+
// non-deterministic
1315+
indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
1316+
}
1317+
1318+
// Atomic-based accumulation. Only supported with integer data or
1319+
// non determinism is excplicitly permitted
1320+
void handleScatterAccumulate(const ScatterOp* sop) {
1321+
const bool non_deterministic = isFloatingPointType(sop->src()->dtype()) &&
1322+
(sop->accumulateOp() != BinaryOpType::Max ||
1323+
sop->accumulateOp() != BinaryOpType::Min);
1324+
1325+
NVF_ERROR(
1326+
!at::globalContext().deterministicAlgorithms() || !non_deterministic,
1327+
"Trying to use non-deterministic instructions even though "
1328+
"deterministic algorithm is requested: ",
1329+
sop->toString());
1330+
1331+
NVF_ERROR(
1332+
sop->src()->dtype() == DataType::Int ||
1333+
sop->src()->dtype() == DataType::Int32 ||
1334+
sop->src()->dtype() == DataType::Float ||
1335+
sop->src()->dtype() == DataType::Double,
1336+
"Data type not supported: ",
1337+
sop->src()->dtype());
1338+
1339+
const auto dst = gen(sop->out());
1340+
const auto src = gen(sop->src());
1341+
1342+
indent();
1343+
1344+
switch (sop->accumulateOp()) {
1345+
case BinaryOpType::Add:
1346+
if (sop->in()->dtype() == DataType::Int) {
1347+
// atomicAdd does not provide an overload for int64_t
1348+
code_ << "atomicAdd("
1349+
<< "reinterpret_cast<unsigned long long*>(&" << dst << "), "
1350+
<< "static_cast<unsigned long long>(" << src << "));\n";
1351+
} else {
1352+
code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n";
1353+
}
1354+
break;
1355+
case BinaryOpType::Max:
1356+
// CUDA doesn't provide atomicMax for float. Could be
1357+
// implemented using atomicCAS
1358+
NVF_ERROR(
1359+
isIntegralType(sop->src()->dtype()),
1360+
"Floating point max accumulation not supported");
1361+
code_ << "atomicMax(" << "&" << dst << ", " << src << ");\n";
1362+
break;
1363+
case BinaryOpType::Min:
1364+
// CUDA doesn't provide atomicMin for float. Could be
1365+
// implemented using atomicCAS
1366+
NVF_ERROR(
1367+
isIntegralType(sop->src()->dtype()),
1368+
"Floating point min accumulation not supported");
1369+
code_ << "atomicMin(" << "&" << dst << ", " << src << ");\n";
1370+
break;
1371+
default:
1372+
NVF_THROW("Unsupported accumulation op: ", sop->accumulateOp());
13131373
}
13141374
}
13151375

csrc/device_lower/pass/index.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,12 +367,12 @@ void IndexLowering::handle(const ScatterOp* sop) {
367367
auto lowered_out = lowerDstIndex(sop->out(), override_index);
368368

369369
pushBack(IrBuilder::create<ScatterOp>(
370-
sop->getScatterOpType(),
371370
/*out=*/lowered_out,
372371
/*self=*/lowered_out,
373372
sop->dim(),
374373
lowered_index,
375-
lowered_src));
374+
lowered_src,
375+
sop->accumulate() ? std::optional(sop->accumulateOp()) : std::nullopt));
376376
GpuLower::current()->propagateExprInfo(sop, back());
377377
}
378378

csrc/ir/internal_nodes.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,12 @@ class ScatterOp : public Expr {
254254
using Expr::Expr;
255255
ScatterOp(
256256
IrBuilderPasskey,
257-
ScatterOpType type,
258257
Val* out,
259258
Val* self,
260259
int64_t dim,
261260
Val* index,
262-
Val* src);
261+
Val* src,
262+
std::optional<BinaryOpType> accumulate_op = std::nullopt);
263263

264264
NVFUSER_DECLARE_CLONE_AND_CREATE
265265

@@ -295,8 +295,13 @@ class ScatterOp : public Expr {
295295

296296
IterDomain* getIndexedID() const;
297297

298-
ScatterOpType getScatterOpType() const {
299-
return attribute<ScatterOpType>(1);
298+
bool accumulate() const {
299+
return attribute<bool>(1);
300+
}
301+
302+
// The binary op used for accumulation. Only valid when accumulate() is
// true; the attribute at index 2 exists only in that case.
BinaryOpType accumulateOp() const {
  NVF_ERROR(accumulate());
  return attribute<BinaryOpType>(2);
}
301306
};
302307

csrc/ir/nodes.cpp

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -286,29 +286,36 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GatherOp)
286286

287287
ScatterOp::ScatterOp(
    IrBuilderPasskey passkey,
    Val* out,
    Val* self,
    int64_t dim,
    Val* index,
    Val* src,
    std::optional<BinaryOpType> accumulate_op)
    : Expr(passkey) {
  addInput(self);
  addInput(index);
  addInput(src);
  addOutput(out);
  addDataAttribute(dim);
  // Attribute 1 records whether this scatter accumulates; attribute 2
  // (the accumulation op) is only added when it does.
  const bool is_accumulate = accumulate_op.has_value();
  addDataAttribute(is_accumulate);
  if (is_accumulate) {
    addDataAttribute(*accumulate_op);
  }
}
303307

304308
// Pretty-prints the op, e.g.:
//   T_out
//    = scatter(in = T_in, dim = 0, src = T_src, idx = T_idx, accumulate = add )
std::string ScatterOp::toString(int indent_size) const {
  std::stringstream ss;
  indent(ss, indent_size) << output(0)->toString() << "\n";
  ++indent_size;
  indent(ss, indent_size) << " = scatter(";
  ss << "in = " << in()->toString() << ", dim = " << dim()
     << ", src = " << src()->toString() << ", idx = " << index()->toString();
  // The accumulation op is only shown for accumulating scatters.
  if (accumulate()) {
    ss << ", accumulate = " << accumulateOp();
  }
  ss << " )\n";
  return ss.str();
}
314321

@@ -326,15 +333,47 @@ std::vector<PolymorphicValue> ScatterOp::evaluate(
326333
const auto& input = inputs.at(0).as<at::Tensor>();
327334
const auto& index = inputs.at(1).as<at::Tensor>();
328335
auto dimension = dim();
329-
if (src()->isA<TensorView>()) {
330-
return {
331-
at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
332-
} else {
333-
return {at::scatter(
336+
if (accumulate()) {
337+
std::string accumulate_op_str;
338+
switch (accumulateOp()) {
339+
case BinaryOpType::Add:
340+
accumulate_op_str = "sum";
341+
break;
342+
case BinaryOpType::Mul:
343+
accumulate_op_str = "prod";
344+
break;
345+
case BinaryOpType::Max:
346+
accumulate_op_str = "amax";
347+
break;
348+
case BinaryOpType::Min:
349+
accumulate_op_str = "amin";
350+
break;
351+
default:
352+
NVF_THROW("Unsupported accumulation op: ", accumulateOp());
353+
}
354+
// at::scatter_reduce doesn't seem to support scalar
355+
// src. at::scatter does support but it seems it's deprecated and
356+
// only supports add and multiply accumulation.
357+
NVF_ERROR(
358+
src()->isA<TensorView>(),
359+
"at::scatter_reduce does not support scalar src argument");
360+
return {at::scatter_reduce(
334361
input,
335362
dimension,
336363
index,
337-
PolymorphicValue_functions::toScalar(inputs.back()))};
364+
inputs.at(2).as<at::Tensor>(),
365+
accumulate_op_str)};
366+
} else {
367+
if (src()->isA<TensorView>()) {
368+
return {
369+
at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
370+
} else {
371+
return {at::scatter(
372+
input,
373+
dimension,
374+
index,
375+
PolymorphicValue_functions::toScalar(inputs.at(2)))};
376+
}
338377
}
339378
}
340379

csrc/ops/indexing.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,12 @@ TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) {
161161
return out_tensor->as<TensorView>();
162162
}
163163

164-
TensorView* scatterOp(
165-
ScatterOpType type,
164+
TensorView* scatter(
166165
TensorView* self,
167166
int64_t dim,
168167
TensorView* index,
169-
Val* src) {
168+
Val* src,
169+
std::optional<BinaryOpType> accumulate_op) {
170170
auto self_dom = TensorDomain::noReductions(self->getLogicalDomain());
171171
auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain());
172172

@@ -215,16 +215,20 @@ TensorView* scatterOp(
215215
/*skip_loop_validation=*/true),
216216
self->getDataType().value());
217217

218-
IrBuilder::create<ScatterOp>(type, out_tensor, self, dim, index, src);
219-
return out_tensor->as<TensorView>();
220-
}
218+
if (accumulate_op.has_value()) {
219+
NVF_ERROR(
220+
accumulate_op.value() == BinaryOpType::Add ||
221+
accumulate_op.value() == BinaryOpType::Mul ||
222+
accumulate_op.value() == BinaryOpType::Max ||
223+
accumulate_op.value() == BinaryOpType::Min,
224+
"Unsupported accumulation op: ",
225+
accumulate_op.value());
226+
}
221227

222-
TensorView* scatter(
223-
TensorView* self,
224-
int64_t dim,
225-
TensorView* index,
226-
Val* src) {
227-
return scatterOp(ScatterOpType::Set, self, dim, index, src);
228+
IrBuilder::create<ScatterOp>(
229+
out_tensor, self, dim, index, src, accumulate_op);
230+
231+
return out_tensor->as<TensorView>();
228232
}
229233

230234
TensorView* takeAlongAxis(TensorView* inp, TensorView* index, int64_t dim) {

csrc/ops/indexing.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ NVF_API TensorView* indexPutAccumulate(
3232
// torch.gather
3333
NVF_API TensorView* gather(TensorView* input, int64_t dim, TensorView* index);
3434

35-
// TODO: Revisit the interface design. ScatterOpType could be just BinaryOpType
36-
NVF_API TensorView* scatterOp(
37-
ScatterOpType type,
38-
TensorView* self,
39-
int64_t dim,
40-
TensorView* index,
41-
Val* src);
42-
4335
// Provides torch.scatter. It is designed to represent the out-of-place
4436
// scatter operation, i.e., the returned tensor, out_tv, is defined as
4537
// follows:
@@ -72,7 +64,8 @@ NVF_API TensorView* scatter(
7264
TensorView* self,
7365
int64_t dim,
7466
TensorView* index,
75-
Val* src);
67+
Val* src,
68+
std::optional<BinaryOpType> accumulate_op = std::nullopt);
7669

7770
//! numpy.take_along_axis
7871
//! (https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)

csrc/type.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,13 +1455,6 @@ std::ostream& operator<<(std::ostream& out, const BinaryOpType botype) {
14551455
return out << binary_op_type2string(botype);
14561456
}
14571457

1458-
std::ostream& operator<<(std::ostream& out, const ScatterOpType sotype) {
1459-
if (sotype == ScatterOpType::Set) {
1460-
return out << "scatter";
1461-
}
1462-
NVF_THROW("No scatterOp type found for scatterOp.");
1463-
}
1464-
14651458
std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) {
14661459
return out << ternary_op_type2string(totype);
14671460
}

csrc/type.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,6 @@ enum class BinaryOpType {
650650
Complex
651651
};
652652

653-
enum class ScatterOpType { Set };
654-
655653
enum class RNGOpType {
656654
Uniform, // Uniform in [0, 1)
657655
UniformRange, // Uniform in [low, high]
@@ -1006,7 +1004,6 @@ NVF_API std::ostream& operator<<(std::ostream&, const DataType);
10061004
std::ostream& operator<<(std::ostream&, const UnaryOpType);
10071005
NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
10081006
std::ostream& operator<<(std::ostream&, const TernaryOpType);
1009-
std::ostream& operator<<(std::ostream&, const ScatterOpType);
10101007
std::ostream& operator<<(std::ostream&, const RNGOpType);
10111008
NVF_API std::ostream& operator<<(std::ostream&, const ParallelType);
10121009
NVF_API std::ostream& operator<<(std::ostream&, const MemoryType);

tests/cpp/test_moe.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ class SgLangMoETest : public NVFuserFixtureParamTest<MoEConfig> {
5959
};
6060

6161
TEST_P(SgLangMoETest, ComputeProblemSizes) {
62-
if (manual_scheduling) {
63-
GTEST_SKIP() << "No manual scheduling implemented";
64-
}
65-
6662
auto fusion_ptr = std::make_unique<Fusion>();
6763
Fusion& fusion = *fusion_ptr.get();
6864
FusionGuard fg(&fusion);
@@ -78,16 +74,39 @@ TEST_P(SgLangMoETest, ComputeProblemSizes) {
7874

7975
auto tv3 = ones({IrBuilder::create<Val>(num_tokens * topk)}, DataType::Int);
8076

81-
auto tv4 = indexPutAccumulate(tv2, tv1, tv3);
77+
auto tv4 = scatter(tv2, 0, tv1, tv3, BinaryOpType::Add);
8278

8379
fusion.addOutput(tv4);
8480

8581
auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
8682
auto t0 = at::randint(0, num_experts, {num_tokens, topk}, options);
8783

88-
FusionExecutorCache executor_cache(std::move(fusion_ptr));
89-
auto outputs = executor_cache.runFusionWithInputs({t0});
90-
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
84+
if (manual_scheduling) {
85+
auto tv4_cache = tv4->cacheBefore();
86+
87+
// Scheduling all tensors as 1D tensors
88+
for (auto tv : fusion.allTvs()) {
89+
tv->flatten();
90+
tv->axis(0)->parallelize(ParallelType::TIDx);
91+
}
92+
93+
tv2->setMemoryType(MemoryType::Shared);
94+
tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
95+
tv4_cache->setMemoryType(MemoryType::Shared);
96+
tv4_cache->setAllocationDomain(tv4_cache->getLogicalDomain(), true);
97+
98+
KernelExecutor ke;
99+
ke.compile(&fusion, {t0});
100+
101+
GTEST_SKIP() << "Missing predication. Fix pending: "
102+
"https://github.com/NVIDIA/Fuser/pull/5107";
103+
auto outputs = ke.run({t0});
104+
testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__);
105+
} else {
106+
FusionExecutorCache executor_cache(std::move(fusion_ptr));
107+
auto outputs = executor_cache.runFusionWithInputs({t0});
108+
testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
109+
}
91110
}
92111

93112
TEST_P(SgLangMoETest, ComputeExpertOffsets) {

0 commit comments

Comments
 (0)