Skip to content

Commit 3c58c0b

Browse files
Hard fail for multi-result op if unreduced axes among results are not all the same.
PiperOrigin-RevId: 807627730
1 parent c18df47 commit 3c58c0b

File tree

4 files changed

+27
-8
lines changed

4 files changed

+27
-8
lines changed

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -963,6 +963,10 @@ bool differentOperandShardingFromFirstResult(Operation* op) {
963963
});
964964
}
965965

966+
ArrayRef<AxisRefAttr> getUnreducedAxes(TensorShardingAttr sharding) {
967+
return sharding ? sharding.getUnreducedAxes() : ArrayRef<AxisRefAttr>();
968+
}
969+
966970
void insertExplicitReshardsOnOp(Operation* op,
967971
ArrayRef<TensorShardingAttr> inShardings,
968972
ArrayRef<TensorShardingAttr> outShardings,

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ std::optional<ArrayRef<AxisRefAttr>> getFactorSharding(
8282
// operand shardings. If `op` does not have any results, returns false;
8383
bool differentOperandShardingFromFirstResult(Operation* op);
8484

85+
// Returns unreduced axes of given `sharding`. If `sharding` is null, returns
86+
// empty axes.
87+
ArrayRef<AxisRefAttr> getUnreducedAxes(TensorShardingAttr sharding);
88+
8589
// Inserts explicit reshards on the operands and results of `op` such that the
8690
// sharding of `op` is compatible with its sharding rule.
8791
//

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,17 @@ void insertAllReduceOnOpIfUnreducedToReplicated(
333333
return;
334334
}
335335

336+
TensorShardingAttr firstResultSharding = getSharding(op->getResult(0));
337+
if (op->getNumResults() > 1) {
338+
ArrayRef<AxisRefAttr> firstResultUnreducedAxes =
339+
getUnreducedAxes(firstResultSharding);
340+
for (OpResult result : op->getResults().drop_front()) {
341+
SDY_CHECK(firstResultUnreducedAxes ==
342+
getUnreducedAxes(getSharding(result)))
343+
<< "Unreduced axes mismatch between results for multi-result op.";
344+
}
345+
}
346+
336347
// For each operand that has unreduced axes, insert an all-reduce if
337348
// any of the unreduced axes isn't unreduced in the target sharding.
338349
//
@@ -341,9 +352,8 @@ void insertAllReduceOnOpIfUnreducedToReplicated(
341352
rewriter.setInsertionPoint(op);
342353
for (OpOperand& operand : op->getOpOperands()) {
343354
if (TensorShardingAttr inSharding = getSharding(operand.get())) {
344-
insertAllReduceIfUnreducedToReplicated(operand, inSharding,
345-
getSharding(op->getResult(0)),
346-
symbolTable, rewriter);
355+
insertAllReduceIfUnreducedToReplicated(
356+
operand, inSharding, firstResultSharding, symbolTable, rewriter);
347357
}
348358
}
349359
}

shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/unreduced.mlir

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,16 +134,17 @@ func.func @reduce_multiple_results_unreduced(
134134
%arg0: tensor<2x64x13xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}, {}, {}]>},
135135
%arg1: tensor<2x64x13xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"x", "y"}, {}, {}]>})
136136
-> (tensor<64xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}], unreduced={"x"}>},
137-
tensor<64xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}], unreduced={"y"}>}) {
137+
tensor<64xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}], unreduced={"x":(1)2}>}) {
138138
%0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
139139
%1 = stablehlo.constant dense<0> : tensor<i32>
140140
// CHECK: %[[REDUCE:.*]]:2 = stablehlo.reduce(%arg0 init: %cst), (%arg1 init: %c) across dimensions = [0, 2]
141-
// CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"y"}>]>}
141+
// CHECK-SAME: {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"x"}>]>}
142142
// CHECK: %[[ALL_REDUCE1:.*]] = sdy.all_reduce {"y"} %[[REDUCE]]#0 out_sharding=<@mesh, [{}], unreduced={"x"}> : tensor<64xf32>
143-
// CHECK-NEXT: %[[ALL_REDUCE2:.*]] = sdy.all_reduce {"x"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"y"}> : tensor<64xi32>
144-
// CHECK-NEXT: return %[[ALL_REDUCE1]], %[[ALL_REDUCE2]] : tensor<64xf32>, tensor<64xi32>
143+
// CHECK-NEXT: %[[ALL_REDUCE2:.*]] = sdy.all_reduce {"y"} %[[REDUCE]]#1 out_sharding=<@mesh, [{}], unreduced={"x"}> : tensor<64xi32>
144+
// CHECK-NEXT: %[[ALL_REDUCE3:.*]] = sdy.all_reduce {"x":(2)2} %[[ALL_REDUCE2]] out_sharding=<@mesh, [{}], unreduced={"x":(1)2}> : tensor<64xi32>
145+
// CHECK-NEXT: return %[[ALL_REDUCE1]], %[[ALL_REDUCE3]] : tensor<64xf32>, tensor<64xi32>
145146
%2:2 = stablehlo.reduce(%arg0 init: %0), (%arg1 init: %1) across dimensions = [0, 2]
146-
{sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"y"}>]>} :
147+
{sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}], unreduced={"x"}>, <@mesh, [{}], unreduced={"x"}>]>} :
147148
(tensor<2x64x13xf32>, tensor<2x64x13xi32>, tensor<f32>, tensor<i32>) -> (tensor<64xf32>, tensor<64xi32>)
148149
reducer(%arg2: tensor<f32>, %arg4: tensor<f32>) (%arg3: tensor<i32>, %arg5: tensor<i32>) {
149150
%3 = stablehlo.add %arg2, %arg4 : tensor<f32>

0 commit comments

Comments (0)