diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp
index 50a0788418a..425bfbd8f29 100644
--- a/csrc/codegen.cpp
+++ b/csrc/codegen.cpp
@@ -1303,13 +1303,73 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }
 
   void handle(const ScatterOp* sop) final {
-    // generate code like T_output[... T_index[...]] = op(T_src[...]);
-    if (sop->getScatterOpType() == ScatterOpType::Set) {
-      // When value of index_tv are not unique, the behavior of Set is
-      // non-deterministic
-      indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
-    } else {
-      NVF_THROW("unkown scatterOp");
+    if (sop->accumulate()) {
+      handleScatterAccumulate(sop);
+      return;
+    }
+
+    // Generate code like T_output[... T_index[...]] = op(T_src[...]);
+    //
+    // When values of index_tv are not unique, the behavior of Set is
+    // non-deterministic
+    indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
+  }
+
+  // Atomic-based accumulation. Only supported with integer data or when
+  // non-determinism is explicitly permitted.
+  void handleScatterAccumulate(const ScatterOp* sop) {
+    const bool non_deterministic = isFloatingPointType(sop->src()->dtype()) &&
+        (sop->accumulateOp() != BinaryOpType::Max &&
+         sop->accumulateOp() != BinaryOpType::Min);
+
+    NVF_ERROR(
+        !at::globalContext().deterministicAlgorithms() || !non_deterministic,
+        "Trying to use non-deterministic instructions even though "
+        "deterministic algorithm is requested: ",
+        sop->toString());
+
+    NVF_ERROR(
+        sop->src()->dtype() == DataType::Int ||
+            sop->src()->dtype() == DataType::Int32 ||
+            sop->src()->dtype() == DataType::Float ||
+            sop->src()->dtype() == DataType::Double,
+        "Data type not supported: ",
+        sop->src()->dtype());
+
+    const auto dst = gen(sop->out());
+    const auto src = gen(sop->src());
+
+    indent();
+
+    switch (sop->accumulateOp()) {
+      case BinaryOpType::Add:
+        if (sop->in()->dtype() == DataType::Int) {
+          // atomicAdd does not provide an overload for int64_t
+          code_ << "atomicAdd("
+                << "reinterpret_cast<unsigned long long*>(&" << dst << "), "
+                << "static_cast<unsigned long long>(" << src << "));\n";
+        } else {
+          code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n";
+        }
+        break;
+      case BinaryOpType::Max:
+        // CUDA doesn't provide atomicMax for float. Could be
+        // implemented using atomicCAS
+        NVF_ERROR(
+            isIntegralType(sop->src()->dtype()),
+            "Floating point max accumulation not supported");
+        code_ << "atomicMax(" << "&" << dst << ", " << src << ");\n";
+        break;
+      case BinaryOpType::Min:
+        // CUDA doesn't provide atomicMin for float. Could be
+        // implemented using atomicCAS
+        NVF_ERROR(
+            isIntegralType(sop->src()->dtype()),
+            "Floating point min accumulation not supported");
+        code_ << "atomicMin(" << "&" << dst << ", " << src << ");\n";
+        break;
+      default:
+        NVF_THROW("Unsupported accumulation op: ", sop->accumulateOp());
     }
   }
 
diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index fc31951f026..1c6eb5d2455 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -367,12 +367,12 @@ void IndexLowering::handle(const ScatterOp* sop) {
   auto lowered_out = lowerDstIndex(sop->out(), override_index);
 
   pushBack(IrBuilder::create<ScatterOp>(
-      sop->getScatterOpType(),
       /*out=*/lowered_out,
       /*self=*/lowered_out,
       sop->dim(),
      lowered_index,
-      lowered_src));
+      lowered_src,
+      sop->accumulate() ? std::optional<BinaryOpType>(sop->accumulateOp()) : std::nullopt));
 
   GpuLower::current()->propagateExprInfo(sop, back());
 }
 
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 757c5826c62..dc9c4dca6e2 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -254,12 +254,12 @@ class ScatterOp : public Expr {
   using Expr::Expr;
   ScatterOp(
       IrBuilderPasskey,
-      ScatterOpType type,
       Val* out,
       Val* self,
      int64_t dim,
       Val* index,
-      Val* src);
+      Val* src,
+      std::optional<BinaryOpType> accumulate_op = std::nullopt);
 
   NVFUSER_DECLARE_CLONE_AND_CREATE
 
@@ -295,8 +295,13 @@ class ScatterOp : public Expr {
 
   IterDomain* getIndexedID() const;
 
-  ScatterOpType getScatterOpType() const {
-    return attribute<ScatterOpType>(1);
+  bool accumulate() const {
+    return attribute<bool>(1);
+  }
+
+  BinaryOpType accumulateOp() const {
+    NVF_ERROR(accumulate());
+    return attribute<BinaryOpType>(2);
   }
 };
 
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 5d81516346f..72d9713b239 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -286,29 +286,36 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GatherOp)
 
 ScatterOp::ScatterOp(
     IrBuilderPasskey passkey,
-    ScatterOpType type,
     Val* out,
     Val* self,
     int64_t dim,
     Val* index,
-    Val* src)
+    Val* src,
+    std::optional<BinaryOpType> accumulate_op)
     : Expr(passkey) {
   addInput(self);
   addInput(index);
   addInput(src);
   addOutput(out);
   addDataAttribute(dim);
-  addDataAttribute(type);
+  // Whether this is an accumulating scatter
+  addDataAttribute(accumulate_op.has_value());
+  if (accumulate_op.has_value()) {
+    addDataAttribute(accumulate_op.value());
+  }
 }
 
 std::string ScatterOp::toString(int indent_size) const {
   std::stringstream ss;
   indent(ss, indent_size) << output(0)->toString() << "\n";
   indent_size++;
-  indent(ss, indent_size) << " =" << getScatterOpType() << "(";
+  indent(ss, indent_size) << " = scatter(";
   ss << "in = " << in()->toString() << ", dim = " << dim()
-     << ", src = " << src()->toString() << ", idx = " << index()->toString()
-     << " )\n";
+     << ", src = " << src()->toString() << ", idx = " << index()->toString();
+  if (accumulate()) {
+    ss << ", accumulate = " << accumulateOp();
+  }
+  ss << " )\n";
   return ss.str();
 }
 
@@ -326,15 +333,47 @@ std::vector<PolymorphicValue> ScatterOp::evaluate(
   const auto& input = inputs.at(0).as<at::Tensor>();
   const auto& index = inputs.at(1).as<at::Tensor>();
   auto dimension = dim();
-  if (src()->isA<TensorView>()) {
-    return {
-        at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
-  } else {
-    return {at::scatter(
+  if (accumulate()) {
+    std::string accumulate_op_str;
+    switch (accumulateOp()) {
+      case BinaryOpType::Add:
+        accumulate_op_str = "sum";
+        break;
+      case BinaryOpType::Mul:
+        accumulate_op_str = "prod";
+        break;
+      case BinaryOpType::Max:
+        accumulate_op_str = "amax";
+        break;
+      case BinaryOpType::Min:
+        accumulate_op_str = "amin";
+        break;
+      default:
+        NVF_THROW("Unsupported accumulation op: ", accumulateOp());
+    }
+    // at::scatter_reduce doesn't seem to support scalar src. at::scatter
+    // does support it, but it appears to be deprecated and only supports
+    // add and multiply accumulation.
+    NVF_ERROR(
+        src()->isA<TensorView>(),
+        "at::scatter_reduce does not support scalar src argument");
+    return {at::scatter_reduce(
         input,
         dimension,
         index,
-        PolymorphicValue_functions::toScalar(inputs.back()))};
+        inputs.at(2).as<at::Tensor>(),
+        accumulate_op_str)};
+  } else {
+    if (src()->isA<TensorView>()) {
+      return {
+          at::scatter(input, dimension, index, inputs.at(2).as<at::Tensor>())};
+    } else {
+      return {at::scatter(
+          input,
+          dimension,
+          index,
+          PolymorphicValue_functions::toScalar(inputs.at(2)))};
+    }
   }
 }
 
diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp
index 97dc5cdac0e..778c0b4f7a0 100644
--- a/csrc/ops/indexing.cpp
+++ b/csrc/ops/indexing.cpp
@@ -161,12 +161,12 @@ TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) {
   return out_tensor->as<TensorView>();
 }
 
-TensorView* scatterOp(
-    ScatterOpType type,
+TensorView* scatter(
     TensorView* self,
     int64_t dim,
     TensorView* index,
-    Val* src) {
+    Val* src,
+    std::optional<BinaryOpType> accumulate_op) {
   auto self_dom = TensorDomain::noReductions(self->getLogicalDomain());
   auto idx_dom = TensorDomain::noReductions(index->getLogicalDomain());
 
@@ -215,16 +215,20 @@ TensorView* scatterOp(
           /*skip_loop_validation=*/true),
       self->getDataType().value());
 
-  IrBuilder::create<ScatterOp>(type, out_tensor, self, dim, index, src);
-  return out_tensor->as<TensorView>();
-}
+  if (accumulate_op.has_value()) {
+    NVF_ERROR(
+        accumulate_op.value() == BinaryOpType::Add ||
+            accumulate_op.value() == BinaryOpType::Mul ||
+            accumulate_op.value() == BinaryOpType::Max ||
+            accumulate_op.value() == BinaryOpType::Min,
+        "Unsupported accumulation op: ",
+        accumulate_op.value());
+  }
 
-TensorView* scatter(
-    TensorView* self,
-    int64_t dim,
-    TensorView* index,
-    Val* src) {
-  return scatterOp(ScatterOpType::Set, self, dim, index, src);
+  IrBuilder::create<ScatterOp>(
+      out_tensor, self, dim, index, src, accumulate_op);
+
+  return out_tensor->as<TensorView>();
 }
 
 TensorView* takeAlongAxis(TensorView* inp, TensorView* index, int64_t dim) {
diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h
index 6d8a5d28df7..b61cec4f29d 100644
--- a/csrc/ops/indexing.h
+++ b/csrc/ops/indexing.h
@@ -32,14 +32,6 @@ NVF_API TensorView* indexPutAccumulate(
 // torch.gather
 NVF_API TensorView* gather(TensorView* input, int64_t dim, TensorView* index);
 
-// TODO: Revisit the interface design. ScatterOpType could be just BinaryOpType
-NVF_API TensorView* scatterOp(
-    ScatterOpType type,
-    TensorView* self,
-    int64_t dim,
-    TensorView* index,
-    Val* src);
-
 // Provides torch.scatter. It is designed to represent the ouf-of-place
 // scatter operation, i.e., the returned tensor, out_tv, is defined as
 // follows:
@@ -72,7 +64,8 @@ NVF_API TensorView* scatter(
     TensorView* self,
     int64_t dim,
     TensorView* index,
-    Val* src);
+    Val* src,
+    std::optional<BinaryOpType> accumulate_op = std::nullopt);
 
 //! numpy.take_along_axis
 //! (https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html)
diff --git a/csrc/type.cpp b/csrc/type.cpp
index 887732be54e..5376fb3ccc2 100644
--- a/csrc/type.cpp
+++ b/csrc/type.cpp
@@ -1455,13 +1455,6 @@ std::ostream& operator<<(std::ostream& out, const BinaryOpType botype) {
   return out << binary_op_type2string(botype);
 }
 
-std::ostream& operator<<(std::ostream& out, const ScatterOpType sotype) {
-  if (sotype == ScatterOpType::Set) {
-    return out << "scatter";
-  }
-  NVF_THROW("No scatterOp type found for scatterOp.");
-}
-
 std::ostream& operator<<(std::ostream& out, const TernaryOpType totype) {
   return out << ternary_op_type2string(totype);
 }
diff --git a/csrc/type.h b/csrc/type.h
index 4be1c8fb2a7..2bf5fda84ec 100644
--- a/csrc/type.h
+++ b/csrc/type.h
@@ -650,8 +650,6 @@ enum class BinaryOpType {
   Complex
 };
 
-enum class ScatterOpType { Set };
-
 enum class RNGOpType {
   Uniform, // Uniform in [0, 1)
   UniformRange, // Uniform in [low, high]
@@ -1006,7 +1004,6 @@ NVF_API std::ostream& operator<<(std::ostream&, const DataType);
 std::ostream& operator<<(std::ostream&, const UnaryOpType);
 NVF_API std::ostream& operator<<(std::ostream&, const BinaryOpType);
 std::ostream& operator<<(std::ostream&, const TernaryOpType);
-std::ostream& operator<<(std::ostream&, const ScatterOpType);
 std::ostream& operator<<(std::ostream&, const RNGOpType);
 NVF_API std::ostream& operator<<(std::ostream&, const ParallelType);
 NVF_API std::ostream& operator<<(std::ostream&, const MemoryType);
diff --git a/tests/cpp/test_moe.cpp b/tests/cpp/test_moe.cpp
index 80b68f12544..dbb02ec8e43 100644
--- a/tests/cpp/test_moe.cpp
+++ b/tests/cpp/test_moe.cpp
@@ -59,10 +59,6 @@ class SgLangMoETest : public NVFuserFixtureParamTest<bool> {
 };
 
 TEST_P(SgLangMoETest, ComputeProblemSizes) {
-  if (manual_scheduling) {
-    GTEST_SKIP() << "No manual scheduling implemented";
-  }
-
   auto fusion_ptr = std::make_unique<Fusion>();
   Fusion& fusion = *fusion_ptr.get();
   FusionGuard fg(&fusion);
@@ -78,16 +74,39 @@ TEST_P(SgLangMoETest, ComputeProblemSizes) {
 
   auto tv3 = ones({IrBuilder::create<Val>(num_tokens * topk)}, DataType::Int);
 
-  auto tv4 = indexPutAccumulate(tv2, tv1, tv3);
+  auto tv4 = scatter(tv2, 0, tv1, tv3, BinaryOpType::Add);
 
   fusion.addOutput(tv4);
 
   auto options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);
   auto t0 = at::randint(0, num_experts, {num_tokens, topk}, options);
 
-  FusionExecutorCache executor_cache(std::move(fusion_ptr));
-  auto outputs = executor_cache.runFusionWithInputs({t0});
-  testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__);
+  if (manual_scheduling) {
+    auto tv4_cache = tv4->cacheBefore();
+
+    // Scheduling all tensors as 1D tensors
+    for (auto tv : fusion.allTvs()) {
+      tv->flatten();
+      tv->axis(0)->parallelize(ParallelType::TIDx);
+    }
+
+    tv2->setMemoryType(MemoryType::Shared);
+    tv2->setAllocationDomain(tv2->getLogicalDomain(), true);
+    tv4_cache->setMemoryType(MemoryType::Shared);
+    tv4_cache->setAllocationDomain(tv4_cache->getLogicalDomain(), true);
+
+    KernelExecutor ke;
+    ke.compile(&fusion, {t0});
+
Fix pending: " + "https://github.com/NVIDIA/Fuser/pull/5107"; + auto outputs = ke.run({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + } else { + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outputs, {t0}, __LINE__, __FILE__); + } } TEST_P(SgLangMoETest, ComputeExpertOffsets) { diff --git a/tests/cpp/test_scatter.cpp b/tests/cpp/test_scatter.cpp index c41d7d945c6..9b349a30e0a 100644 --- a/tests/cpp/test_scatter.cpp +++ b/tests/cpp/test_scatter.cpp @@ -393,4 +393,93 @@ INSTANTIATE_TEST_SUITE_P( return os.str(); }); +class ScatterAccumulateTest + : public NVFuserFixtureParamTest< + std::tuple> { + protected: + void SetUp() override { + NVFuserTest::SetUp(); + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + std::tie(m, n, dtype, accumulate_op) = GetParam(); + } + + protected: + int64_t m = 8; + int64_t n = 128; + PrimDataType dtype = PrimDataType::Int; + BinaryOpType accumulate_op = BinaryOpType::Add; +}; + +TEST_P(ScatterAccumulateTest, BlockParallelScatterAccumulate) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto tv0 = makeContigConcreteTensor({m}, dtype); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor({n}, DataType::Int); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor({n}, dtype); + fusion.addInput(tv2); + + auto tv3 = set(tv0); + auto tv4 = scatter(tv3, 0, tv1, tv2, accumulate_op); + + auto tv5 = set(tv4); + fusion.addOutput(tv5); + + tv3->axis(0)->parallelize(ParallelType::TIDx); + tv4->axis(0)->parallelize(ParallelType::TIDx); + tv5->axis(0)->parallelize(ParallelType::TIDx); + + // Scatter input must use the same memory as the output + tv3->setMemoryType(MemoryType::Shared); + tv3->setAllocationDomain(tv3->getLogicalDomain(), true); + tv4->setMemoryType(MemoryType::Shared); + tv4->setAllocationDomain(tv4->getLogicalDomain(), true); + + auto options = + at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0); + auto options_int = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + auto t0 = isIntegralType(dtype) ? at::randint(0, 100, {m}, options) + : at::randn({m}, options); + auto t1 = at::randint(0, m, {n}, options_int); + auto t2 = isIntegralType(dtype) ? at::randint(0, 100, {n}, options) + : at::randn({n}, options); + + KernelExecutor ke; + if (isFloatingPointType(dtype) && + (accumulate_op == BinaryOpType::Max || + accumulate_op == BinaryOpType::Min)) { + EXPECT_THAT( + [&]() { ke.compile(&fusion, {t0, t1, t2}); }, + testing::ThrowsMessage( + testing::HasSubstr("accumulation not supported"))); + } else { + ke.compile(&fusion, {t0, t1, t2}); + auto outputs = ke.run({t0, t1, t2}); + testValidate(&fusion, outputs, {t0, t1, t2}, __LINE__, __FILE__); + } +} + +INSTANTIATE_TEST_SUITE_P( + , + ScatterAccumulateTest, + testing::Combine( + testing::Values(8, 32), + testing::Values(8, 32, 128), + testing::Values(PrimDataType::Int, PrimDataType::Float), + testing::Values( + BinaryOpType::Add, + BinaryOpType::Max, + BinaryOpType::Min)), + [](const testing::TestParamInfo< + std::tuple>& info) { + std::stringstream ss; + ss << std::get<0>(info.param) << "_" << std::get<1>(info.param) << "_" + << std::get<2>(info.param) << "_" << std::get<3>(info.param); + return ss.str(); + }); + } // namespace nvfuser