@@ -1303,13 +1303,73 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
   }
 
   void handle(const ScatterOp* sop) final {
-    // generate code like T_output[... T_index[...]] = op(T_src[...]);
-    if (sop->getScatterOpType() == ScatterOpType::Set) {
-      // When value of index_tv are not unique, the behavior of Set is
-      // non-deterministic
-      indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
-    } else {
-      NVF_THROW("unkown scatterOp");
+    if (sop->accumulate()) {
+      handleScatterAccumulate(sop);
+      return;
+    }
+
+    // Generate code like T_output[... T_index[...]] = op(T_src[...]);
+    //
+    // When values of index_tv are not unique, the behavior of Set is
+    // non-deterministic.
+    indent() << gen(sop->out()) << " = " << gen(sop->src()) << ";\n";
+  }
+
+  // Atomic-based accumulation. Only supported with integer data or when
+  // non-determinism is explicitly permitted.
+  void handleScatterAccumulate(const ScatterOp* sop) {
+    const bool non_deterministic = isFloatingPointType(sop->src()->dtype()) &&
+        (sop->accumulateOp() != BinaryOpType::Max &&
+         sop->accumulateOp() != BinaryOpType::Min);
+
+    NVF_ERROR(
+        !at::globalContext().deterministicAlgorithms() || !non_deterministic,
+        "Trying to use non-deterministic instructions even though "
+        "deterministic algorithm is requested: ",
+        sop->toString());
+
+    NVF_ERROR(
+        sop->src()->dtype() == DataType::Int ||
+            sop->src()->dtype() == DataType::Int32 ||
+            sop->src()->dtype() == DataType::Float ||
+            sop->src()->dtype() == DataType::Double,
+        "Data type not supported: ",
+        sop->src()->dtype());
+
+    const auto dst = gen(sop->out());
+    const auto src = gen(sop->src());
+
+    indent();
+
+    switch (sop->accumulateOp()) {
+      case BinaryOpType::Add:
+        if (sop->in()->dtype() == DataType::Int) {
+          // atomicAdd does not provide an overload for int64_t
+          code_ << "atomicAdd("
+                << "reinterpret_cast<unsigned long long*>(&" << dst << "), "
+                << "static_cast<unsigned long long>(" << src << "));\n";
+        } else {
+          code_ << "atomicAdd(" << "&" << dst << ", " << src << ");\n";
+        }
+        break;
+      case BinaryOpType::Max:
+        // CUDA doesn't provide atomicMax for float. Could be
+        // implemented using atomicCAS
+        NVF_ERROR(
+            isIntegralType(sop->src()->dtype()),
+            "Floating point max accumulation not supported");
+        code_ << "atomicMax(" << "&" << dst << ", " << src << ");\n";
+        break;
+      case BinaryOpType::Min:
+        // CUDA doesn't provide atomicMin for float. Could be
+        // implemented using atomicCAS
+        NVF_ERROR(
+            isIntegralType(sop->src()->dtype()),
+            "Floating point min accumulation not supported");
+        code_ << "atomicMin(" << "&" << dst << ", " << src << ");\n";
+        break;
+      default:
+        NVF_THROW("Unsupported accumulation op: ", sop->accumulateOp());
     }
   }
 
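A note on the Add path above: CUDA's atomicAdd has no int64_t overload, so for DataType::Int the generator reinterprets the destination as unsigned long long and casts the source value, which is exact under two's complement. A minimal sketch of what the emitted statement amounts to; the helper name and the flat out/index/src parameters are illustrative assumptions, not code from this PR:

#include <cstdint>

// Sketch only: stand-in for the generated T_output[...T_index[...]] update
// when the element type is int64_t. Reinterpreting the signed destination as
// unsigned long long is safe for addition because two's complement wrap-around
// produces the same bit pattern as a signed add.
__device__ void scatterAddInt64(int64_t* out, int64_t index, int64_t src) {
  atomicAdd(
      reinterpret_cast<unsigned long long*>(&out[index]),
      static_cast<unsigned long long>(src));
}

Both unsigned long long and int64_t are 64-bit on CUDA targets, so the reinterpret_cast changes only signedness, not object size.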
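The Max/Min cases reject floating-point inputs because CUDA only provides integer atomicMax/atomicMin; the comments note that a float version could be built on atomicCAS. A rough sketch of that idea, assuming a hypothetical atomicMaxFloat helper and ignoring NaN handling:

// Sketch only: one way a float atomicMax could be layered on atomicCAS.
// The loop re-reads the current value and retries the compare-and-swap until
// it succeeds or the stored value is already >= val.
__device__ float atomicMaxFloat(float* address, float val) {
  unsigned int* address_as_uint = reinterpret_cast<unsigned int*>(address);
  unsigned int old = *address_as_uint;
  unsigned int assumed;
  do {
    assumed = old;
    if (__uint_as_float(assumed) >= val) {
      break;  // stored value is already the max, nothing to do
    }
    old = atomicCAS(address_as_uint, assumed, __float_as_uint(val));
  } while (old != assumed);
  return __uint_as_float(old);
}

Unlike floating-point atomicAdd, the result of such a max (NaNs aside) does not depend on the order in which threads apply their updates, which is why Max and Min are exempted from the non_deterministic check above.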