Skip to content

Commit 36c83a5

Browse files
Varchocopybara-github
authored andcommitted
Enable PropagationEdgesAttr verification for sdy.data_flow_edge operations.
PiperOrigin-RevId: 792322564
1 parent 2e36a94 commit 36c83a5

File tree

3 files changed

+102
-41
lines changed

3 files changed

+102
-41
lines changed

shardy/dialect/sdy/ir/verifiers.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,12 +1043,6 @@ SmallVector<TensorShardingAttr> getShardingsReferenceByPropagationEdge(
10431043

10441044
LogicalResult verifyPropagationEdgesShardingAttr(
10451045
PropagationEdgesAttr propagationEdges, Operation* op) {
1046-
// TODO(b/429645141): add PropagationEdgesAttr verification for
1047-
// `DataFlowEdgeOp`
1048-
if (isa<DataFlowEdgeOp>(op)) {
1049-
return success();
1050-
}
1051-
10521046
SmallVector<TensorShardingAttr> shardings =
10531047
getShardingsReferenceByPropagationEdge(propagationEdges, op);
10541048

shardy/dialect/sdy/transforms/propagation/debugging/source_sharding.cc

Lines changed: 93 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -361,46 +361,71 @@ PropagationEdgesAttr createPropagationEdges(Operation* op,
361361
MLIRContext* context) {
362362
Builder builder(context);
363363
StepToAxisPropagationDetailsMap perStepEdgesForAxis;
364+
365+
// Build a temporary map to keep track of the edge with the minimum step
366+
// index for each source/target/axis combination. Due to sharding constraints
367+
// we have logic to specifically add edges outside of propagation. If similar
368+
// edges are then re-included during propagation, we need to filter them out.
369+
llvm::DenseMap<
370+
AxisRefAttr,
371+
llvm::DenseMap<EdgeValueRefAttr,
372+
llvm::DenseMap<EdgeValueRefAttr, PropagationEdge>>>
373+
minStepEdges;
364374
for (const auto& [axisRef, edges] : axisToEdges) {
365375
for (const PropagationEdge& edge : edges) {
366-
auto source =
376+
auto sourceAttr =
367377
EdgeValueRefAttr::get(context, edge.source.type, edge.source.index);
368-
auto target =
378+
auto targetAttr =
369379
EdgeValueRefAttr::get(context, edge.target.type, edge.target.index);
370-
perStepEdgesForAxis[edge.propagationStep][axisRef][source].insert(target);
380+
auto& sourceMap = minStepEdges[axisRef];
381+
auto& targetMap = sourceMap[sourceAttr];
382+
auto [it, inserted] = targetMap.try_emplace(targetAttr, edge);
383+
if (!inserted) {
384+
if (edge.propagationStep < it->second.propagationStep) {
385+
it->second = edge;
386+
}
387+
}
388+
}
389+
}
390+
391+
// Regroup the edges (with the minimum step index per Axis) by step index.
392+
for (const auto& [axisRef, sourceMap] : minStepEdges) {
393+
for (const auto& [sourceAttr, targetMap] : sourceMap) {
394+
for (const auto& [targetAttr, edge] : targetMap) {
395+
perStepEdgesForAxis[edge.propagationStep][axisRef][sourceAttr].insert(
396+
targetAttr);
397+
}
371398
}
372399
}
373400

374401
SmallVector<PropagationOneStepAttr> perStepEdges;
375402
for (const auto& [step, edgesForAxis] : perStepEdgesForAxis) {
376-
SmallVector<AxisToPropagationDetailsAttr> axis_entries;
403+
SmallVector<AxisToPropagationDetailsAttr> axisEntries;
377404
for (const auto& [axisRef, edges] : edgesForAxis) {
378-
// There should only be one source in the edge map.
379-
assert(edges.size() == 1);
380-
EdgeValueRefAttr source = edges.begin()->first;
381-
DenseSet<EdgeValueRefAttr> targets = edges.begin()->second;
382-
// Sort the targets for deterministic ordering in the output attr.
383-
SmallVector<EdgeValueRefAttr> targetsArray(targets.begin(),
384-
targets.end());
385-
llvm::stable_sort(targetsArray,
386-
[](EdgeValueRefAttr a, EdgeValueRefAttr b) {
387-
if (a.getType() == b.getType()) {
388-
return a.getIndex() < b.getIndex();
389-
}
390-
return a.getType() < b.getType();
391-
});
392-
AxisToPropagationDetailsAttr axisToPropagationDetails =
393-
AxisToPropagationDetailsAttr::get(context, axisRef, source,
394-
targetsArray);
395-
axis_entries.push_back(axisToPropagationDetails);
405+
for (const auto& [source, targets] : edges) {
406+
// Sort the targets for deterministic ordering in the output attr.
407+
SmallVector<EdgeValueRefAttr> targetsArray(targets.begin(),
408+
targets.end());
409+
llvm::stable_sort(targetsArray,
410+
[](EdgeValueRefAttr a, EdgeValueRefAttr b) {
411+
if (a.getType() == b.getType()) {
412+
return a.getIndex() < b.getIndex();
413+
}
414+
return a.getType() < b.getType();
415+
});
416+
AxisToPropagationDetailsAttr axisToPropagationDetails =
417+
AxisToPropagationDetailsAttr::get(context, axisRef, source,
418+
targetsArray);
419+
axisEntries.push_back(axisToPropagationDetails);
420+
}
396421
}
397422
// Sort the axes by name for deterministic ordering in the output attr.
398-
llvm::stable_sort(axis_entries, [](AxisToPropagationDetailsAttr a,
399-
AxisToPropagationDetailsAttr b) {
423+
llvm::stable_sort(axisEntries, [](AxisToPropagationDetailsAttr a,
424+
AxisToPropagationDetailsAttr b) {
400425
return a.getAxisName() < b.getAxisName();
401426
});
402427
perStepEdges.push_back(
403-
PropagationOneStepAttr::get(context, step, axis_entries));
428+
PropagationOneStepAttr::get(context, step, axisEntries));
404429
}
405430

406431
// Sort the edges by step index.
@@ -666,6 +691,47 @@ void prepareFuncResultToEdgesHandler(
666691
}
667692
}
668693

694+
void prepareShardingConstraintToEdgesHandler(
695+
ModuleOp moduleOp, OperationToEdgesMap& operationToEdgesMap) {
696+
moduleOp.walk([&](ShardingConstraintOp shardingConstraintOp) {
697+
auto sharding = shardingConstraintOp.getSharding();
698+
if (!sharding) {
699+
return;
700+
}
701+
for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) {
702+
for (AxisRefAttr axisRef : dimSharding.getAxes()) {
703+
operationToEdgesMap[shardingConstraintOp][axisRef].push_back(
704+
PropagationEdge{/*source=*/EdgeNode{EdgeNodeType::RESULT, 0},
705+
/*target=*/EdgeNode{EdgeNodeType::OPERAND, 0},
706+
/*propagationStep=*/0});
707+
}
708+
}
709+
});
710+
711+
// Input sources of `ManualComputationOp` act as a sharding constraint (and
712+
// the ApplyShardingConstrains pass treats them as such). Due to this, we need
713+
// to create appropriate propagation edges for them.
714+
moduleOp.walk([&](ManualComputationOp manualComputationOp) {
715+
int64_t i = 0;
716+
for (const auto& sharding :
717+
manualComputationOp.getInShardings().getShardings()) {
718+
auto edgeOp =
719+
DataFlowEdgeOp::lookup(manualComputationOp.getBody().getArgument(i));
720+
assert(edgeOp);
721+
for (DimensionShardingAttr dimSharding : sharding.getDimShardings()) {
722+
for (AxisRefAttr axisRef : dimSharding.getAxes()) {
723+
// The ManualComputationOp "owns" this initial edge.
724+
operationToEdgesMap[manualComputationOp][axisRef].push_back(
725+
PropagationEdge{/*source=*/EdgeNode{EdgeNodeType::OPERAND, i},
726+
/*target=*/EdgeNode{EdgeNodeType::RESULT, i},
727+
/*propagationStep=*/0});
728+
}
729+
}
730+
i++;
731+
}
732+
});
733+
}
734+
669735
OriginSharding lookUpValueOriginSharding(
670736
Value value, AxisRefAttr axisRef,
671737
const ValueToOriginShardingMap& valueToOriginShardingMap) {
@@ -772,6 +838,8 @@ void SourceShardingHandler::prepareHandler(ModuleOp moduleOp) {
772838
}
773839
if (mappings->debugPropagationEdgeSharding) {
774840
prepareFuncResultToEdgesHandler(moduleOp, mappings->funcResultToEdgesMap);
841+
prepareShardingConstraintToEdgesHandler(moduleOp,
842+
mappings->operationToEdgesMap);
775843
}
776844
if (mappings->debugShardingOrigins ||
777845
mappings->debugPropagationEdgeSharding) {

shardy/dialect/sdy/transforms/propagation/debugging/test/edge_shardings.mlir

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ sdy.mesh @mesh = <["a"=2, "b"=2]>
167167
// CHECK-LABEL: manual_computation_multiple_results
168168
// CHECK-SAME: %arg0: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", ?}, {"a", ?}]>})
169169
// CHECK-SAME: -> (tensor<16x32xf32> {sdy.propagation_edges = #sdy.propagation_edges<[
170-
// CHECK-SAME: {step-0 = [{"a" = operand-0 -> [result-0]}, {"b" = operand-0 -> [result-0]}]},
171-
// CHECK-SAME: {step-6 = [{"a" = operand-0 -> [result-0]}]}]>,
170+
// CHECK-SAME: {step-0 = [{"a" = operand-0 -> [result-0]}, {"b" = operand-0 -> [result-0]}]}]>,
172171
// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{?}, {"a", ?}]>},
173172
// CHECK-SAME: tensor<32x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b", ?}, {"a", ?}]>}) {
174173
func.func @manual_computation_multiple_results(%arg0: tensor<32x32xf32>) -> (tensor<16x32xf32>, tensor<32x32xf32>) {
@@ -184,6 +183,8 @@ func.func @manual_computation_multiple_results(%arg0: tensor<32x32xf32>) -> (ten
184183
// CHECK-SAME: #sdy.propagation_edges<[
185184
// CHECK-SAME: {step-1 = [{"b" = result-0 -> [operand-0]}]},
186185
// CHECK-SAME: {step-5 = [{"a" = result-0 -> [operand-0]}]}]>],
186+
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[
187+
// CHECK-SAME: {step-0 = [{"b" = operand-0 -> [result-0]}]}]>,
187188
// CHECK-SAME: sdy.result_propagation_edges = [
188189
// CHECK-SAME: #sdy.propagation_edges<[
189190
// CHECK-SAME: {step-3 = [{"a" = operand-0 -> [result-0]}]}]>,
@@ -250,20 +251,18 @@ func.func @sub_axes_merging_reshape(
250251

251252
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=8]>
252253

253-
// TODO(b/434949739): Describe how the propagation edge is created due to the
254-
// apply-sharding-constraints pass.
255254
// CHECK-LABEL: two_sharding_constraint
256255
// CHECK-SAME: %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>}
257256
// CHECK-SAME: -> (tensor<8x8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = operand-0 -> [result-0]}]}, {step-6 = [{"a" = operand-0 -> [result-0]}]}]>,
258257
// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}, {"b", ?}]>}) {
259258
func.func @two_sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
260259
// CHECK-NEXT: %[[SC_1:.*]] = sdy.sharding_constraint %arg0 <@mesh, [{"a"}, {"b", ?}]> {
261-
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-1 = [{"a" = result-0 -> [operand-0]}]}, {step-5 = [{"b" = result-0 -> [operand-0]}]}]>} : tensor<8x8xf32>
260+
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"a" = result-0 -> [operand-0]}]}, {step-5 = [{"b" = result-0 -> [operand-0]}]}]>} : tensor<8x8xf32>
262261
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %[[SC_1]], %[[SC_1]] {
263-
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-2 = [{"a" = operand-0 -> [result-0]}]}, {step-4 = [{"b" = result-0 -> [operand-0, operand-1]}]}]>,
262+
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-2 = [{"a" = operand-0 -> [result-0]}]}, {step-4 = [{"b" = result-0 -> [operand-0, operand-1]}]}]>
264263
// CHECK-SAME: sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a", ?}, {"b", ?}]>]>} : tensor<8x8xf32>
265264
// CHECK-NEXT: %[[SC_2:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a", ?}, {"b"}]> {
266-
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-3 = [{"a" = operand-0 -> [result-0]}, {"b" = result-0 -> [operand-0]}]}]>} : tensor<8x8xf32>
265+
// CHECK-SAME: sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"b" = result-0 -> [operand-0]}]}, {step-3 = [{"a" = operand-0 -> [result-0]}]}]>} : tensor<8x8xf32>
267266
// CHECK-NEXT: return %[[SC_2]]
268267
%0 = sdy.sharding_constraint %arg0 <@mesh, [{"a"}, {?}]> : tensor<8x8xf32>
269268
%1 = stablehlo.add %0, %0 : tensor<8x8xf32>
@@ -275,17 +274,17 @@ func.func @two_sharding_constraint(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> {
275274

276275
sdy.mesh @mesh = <["a"=2, "b"=2, "c"=8]>
277276

278-
// TODO(b/434949739): Describe how the propagation edge is created due to the
279-
// apply-sharding-constraints pass.
280277
// CHECK-LABEL: push_sharding_constraints_to_func_results
281278
// CHECK-SAME: %arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>},
282279
// CHECK-SAME: %arg1: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>})
283-
// CHECK-SAME: -> (tensor<8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"a" = operand-0 -> [result-0]}]}, {step-1 = [{"a" = operand-0 -> [result-0]}]}]>,
280+
// CHECK-SAME: -> (tensor<8xf32> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"a" = operand-0 -> [result-0]}]}]>,
284281
// CHECK-SAME: sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>},
285282
// CHECK-SAME: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a", ?}]>}) {
286283
func.func @push_sharding_constraints_to_func_results(
287284
%arg0: tensor<8xf32>, %arg1: tensor<8xf32>
288285
) -> (tensor<8xf32>, tensor<8xf32>) {
286+
// CHECK: %[[C1:.*]] = sdy.sharding_constraint %arg0 <@mesh, [{"a"}]> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"a" = result-0 -> [operand-0]}]}]>} : tensor<8xf32>
287+
// CHECK: %[[C2:.*]] = sdy.sharding_constraint %arg1 <@mesh, [{"a"}]> {sdy.propagation_edges = #sdy.propagation_edges<[{step-0 = [{"a" = result-0 -> [operand-0]}]}]>} : tensor<8xf32>
289288
%1 = sdy.sharding_constraint %arg0 <@mesh, [{"a"}]> : tensor<8xf32>
290289
%2 = sdy.sharding_constraint %arg1 <@mesh, [{"a"}]> : tensor<8xf32>
291290
return %1, %2 : tensor<8xf32>, tensor<8xf32>

0 commit comments

Comments
 (0)