@@ -361,46 +361,71 @@ PropagationEdgesAttr createPropagationEdges(Operation* op,
361
361
MLIRContext* context) {
362
362
Builder builder (context);
363
363
StepToAxisPropagationDetailsMap perStepEdgesForAxis;
364
+
365
+ // Build a temporary map to keep track of the edge with the minimum step
366
+ // index for each source/target/axis combination. Due to sharding constraints
367
+ // we have logic to specifically add edges outside of propagation. If similar
368
+ // edges are then re-included during propagation, we need to filter them out.
369
+ llvm::DenseMap<
370
+ AxisRefAttr,
371
+ llvm::DenseMap<EdgeValueRefAttr,
372
+ llvm::DenseMap<EdgeValueRefAttr, PropagationEdge>>>
373
+ minStepEdges;
364
374
for (const auto & [axisRef, edges] : axisToEdges) {
365
375
for (const PropagationEdge& edge : edges) {
366
- auto source =
376
+ auto sourceAttr =
367
377
EdgeValueRefAttr::get (context, edge.source .type , edge.source .index );
368
- auto target =
378
+ auto targetAttr =
369
379
EdgeValueRefAttr::get (context, edge.target .type , edge.target .index );
370
- perStepEdgesForAxis[edge.propagationStep ][axisRef][source].insert (target);
380
+ auto & sourceMap = minStepEdges[axisRef];
381
+ auto & targetMap = sourceMap[sourceAttr];
382
+ auto [it, inserted] = targetMap.try_emplace (targetAttr, edge);
383
+ if (!inserted) {
384
+ if (edge.propagationStep < it->second .propagationStep ) {
385
+ it->second = edge;
386
+ }
387
+ }
388
+ }
389
+ }
390
+
391
+ // Regroup the edges (with the minimum step index per Axis) by step index.
392
+ for (const auto & [axisRef, sourceMap] : minStepEdges) {
393
+ for (const auto & [sourceAttr, targetMap] : sourceMap) {
394
+ for (const auto & [targetAttr, edge] : targetMap) {
395
+ perStepEdgesForAxis[edge.propagationStep ][axisRef][sourceAttr].insert (
396
+ targetAttr);
397
+ }
371
398
}
372
399
}
373
400
374
401
SmallVector<PropagationOneStepAttr> perStepEdges;
375
402
for (const auto & [step, edgesForAxis] : perStepEdgesForAxis) {
376
- SmallVector<AxisToPropagationDetailsAttr> axis_entries ;
403
+ SmallVector<AxisToPropagationDetailsAttr> axisEntries ;
377
404
for (const auto & [axisRef, edges] : edgesForAxis) {
378
- // There should only be one source in the edge map.
379
- assert (edges.size () == 1 );
380
- EdgeValueRefAttr source = edges.begin ()->first ;
381
- DenseSet<EdgeValueRefAttr> targets = edges.begin ()->second ;
382
- // Sort the targets for deterministic ordering in the output attr.
383
- SmallVector<EdgeValueRefAttr> targetsArray (targets.begin (),
384
- targets.end ());
385
- llvm::stable_sort (targetsArray,
386
- [](EdgeValueRefAttr a, EdgeValueRefAttr b) {
387
- if (a.getType () == b.getType ()) {
388
- return a.getIndex () < b.getIndex ();
389
- }
390
- return a.getType () < b.getType ();
391
- });
392
- AxisToPropagationDetailsAttr axisToPropagationDetails =
393
- AxisToPropagationDetailsAttr::get (context, axisRef, source,
394
- targetsArray);
395
- axis_entries.push_back (axisToPropagationDetails);
405
+ for (const auto & [source, targets] : edges) {
406
+ // Sort the targets for deterministic ordering in the output attr.
407
+ SmallVector<EdgeValueRefAttr> targetsArray (targets.begin (),
408
+ targets.end ());
409
+ llvm::stable_sort (targetsArray,
410
+ [](EdgeValueRefAttr a, EdgeValueRefAttr b) {
411
+ if (a.getType () == b.getType ()) {
412
+ return a.getIndex () < b.getIndex ();
413
+ }
414
+ return a.getType () < b.getType ();
415
+ });
416
+ AxisToPropagationDetailsAttr axisToPropagationDetails =
417
+ AxisToPropagationDetailsAttr::get (context, axisRef, source,
418
+ targetsArray);
419
+ axisEntries.push_back (axisToPropagationDetails);
420
+ }
396
421
}
397
422
// Sort the axes by name for deterministic ordering in the output attr.
398
- llvm::stable_sort (axis_entries , [](AxisToPropagationDetailsAttr a,
399
- AxisToPropagationDetailsAttr b) {
423
+ llvm::stable_sort (axisEntries , [](AxisToPropagationDetailsAttr a,
424
+ AxisToPropagationDetailsAttr b) {
400
425
return a.getAxisName () < b.getAxisName ();
401
426
});
402
427
perStepEdges.push_back (
403
- PropagationOneStepAttr::get (context, step, axis_entries ));
428
+ PropagationOneStepAttr::get (context, step, axisEntries ));
404
429
}
405
430
406
431
// Sort the edges by step index.
@@ -666,6 +691,47 @@ void prepareFuncResultToEdgesHandler(
666
691
}
667
692
}
668
693
694
+ void prepareShardingConstraintToEdgesHandler (
695
+ ModuleOp moduleOp, OperationToEdgesMap& operationToEdgesMap) {
696
+ moduleOp.walk ([&](ShardingConstraintOp shardingConstraintOp) {
697
+ auto sharding = shardingConstraintOp.getSharding ();
698
+ if (!sharding) {
699
+ return ;
700
+ }
701
+ for (DimensionShardingAttr dimSharding : sharding.getDimShardings ()) {
702
+ for (AxisRefAttr axisRef : dimSharding.getAxes ()) {
703
+ operationToEdgesMap[shardingConstraintOp][axisRef].push_back (
704
+ PropagationEdge{/* source=*/ EdgeNode{EdgeNodeType::RESULT, 0 },
705
+ /* target=*/ EdgeNode{EdgeNodeType::OPERAND, 0 },
706
+ /* propagationStep=*/ 0 });
707
+ }
708
+ }
709
+ });
710
+
711
+ // Input sources of `ManualComputationOp` act as a sharding constraint (and
712
+ // the ApplyShardingConstrains pass treats them as such). Due to this, we need
713
+ // to create appropriate propagation edges for them.
714
+ moduleOp.walk ([&](ManualComputationOp manualComputationOp) {
715
+ int64_t i = 0 ;
716
+ for (const auto & sharding :
717
+ manualComputationOp.getInShardings ().getShardings ()) {
718
+ auto edgeOp =
719
+ DataFlowEdgeOp::lookup (manualComputationOp.getBody ().getArgument (i));
720
+ assert (edgeOp);
721
+ for (DimensionShardingAttr dimSharding : sharding.getDimShardings ()) {
722
+ for (AxisRefAttr axisRef : dimSharding.getAxes ()) {
723
+ // The ManualComputationOp "owns" this initial edge.
724
+ operationToEdgesMap[manualComputationOp][axisRef].push_back (
725
+ PropagationEdge{/* source=*/ EdgeNode{EdgeNodeType::OPERAND, i},
726
+ /* target=*/ EdgeNode{EdgeNodeType::RESULT, i},
727
+ /* propagationStep=*/ 0 });
728
+ }
729
+ }
730
+ i++;
731
+ }
732
+ });
733
+ }
734
+
669
735
OriginSharding lookUpValueOriginSharding (
670
736
Value value, AxisRefAttr axisRef,
671
737
const ValueToOriginShardingMap& valueToOriginShardingMap) {
@@ -772,6 +838,8 @@ void SourceShardingHandler::prepareHandler(ModuleOp moduleOp) {
772
838
}
773
839
if (mappings->debugPropagationEdgeSharding ) {
774
840
prepareFuncResultToEdgesHandler (moduleOp, mappings->funcResultToEdgesMap );
841
+ prepareShardingConstraintToEdgesHandler (moduleOp,
842
+ mappings->operationToEdgesMap );
775
843
}
776
844
if (mappings->debugShardingOrigins ||
777
845
mappings->debugPropagationEdgeSharding ) {
0 commit comments