From cc9fdee63da595854558d5513c02a4acc79af104 Mon Sep 17 00:00:00 2001 From: shardy authors Date: Mon, 1 Sep 2025 03:22:35 -0700 Subject: [PATCH] Refactor to create mesh only once and pass to methods of shardy partitioner. It is not a pure refactoring. It now also supports, on the default/minimal version, the case where the meshes are the same but have different device orders. PiperOrigin-RevId: 801741963 --- .../export/explicit_reshards_util.cc | 57 +++++---------- .../export/explicit_reshards_util.h | 7 +- .../export/insert_explicit_reshards.cc | 73 ++++++++++++------- 3 files changed, 71 insertions(+), 66 deletions(-) diff --git a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc index 90567a0c..0d112f2e 100644 --- a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc +++ b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc @@ -197,8 +197,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh, // Assumes all tensor shardings have the same mesh as `mesh` on axes but may be // different on device order. void insertExplicitReshards(Operation* op, - const SmallVector& inShardings, - const SmallVector& outShardings, + ArrayRef inShardings, + ArrayRef outShardings, const ShardingProjection& shardingProjection, UpdateTensorShardings updateTensorShardings, IRRewriter& rewriter, @@ -739,10 +739,8 @@ std::optional findTensorIndexToPreferOnUnaryOperation( } TensorShardingAttr getShardingOfTensorIndex( - const int64_t tensorIndex, - const SmallVector& inShardings, - const SmallVector& outShardings, - const int64_t numOperands) { + const int64_t tensorIndex, ArrayRef inShardings, + ArrayRef outShardings, const int64_t numOperands) { return tensorIndex < numOperands ? inShardings[tensorIndex] : outShardings[tensorIndex - numOperands]; } @@ -763,8 +761,8 @@ Mesh getMeshOrDefault(TensorShardingAttr sharding, // 2. 
Both tensors have the same mesh but may have different device orders. // 3. The factor shardings are not compatible. AxesPerFactorWithMesh findCommonAxesOnUnaryOperation( - const SmallVector& inShardings, - const SmallVector& outShardings, + ArrayRef inShardings, + ArrayRef outShardings, const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule, ArrayRef tensorSizes, const SymbolTable& symbolTable, const Mesh& mesh) { @@ -858,8 +856,8 @@ void distributeAxisRefsToBatchingFactors( } } -Mesh getMostCommonMesh(const SmallVector& inShardings, - const SmallVector& outShardings, +Mesh getMostCommonMesh(ArrayRef inShardings, + ArrayRef outShardings, OpShardingRuleAttr shardingRule, const SymbolTable& symbolTable, const Mesh& defaultMesh) { @@ -882,8 +880,8 @@ Mesh getMostCommonMesh(const SmallVector& inShardings, } AxesPerFactorWithMesh findCommonAxes( - const SmallVector& inShardings, - const SmallVector& outShardings, + ArrayRef inShardings, + ArrayRef outShardings, const ShardingProjection& shardingProjection, OpShardingRuleAttr shardingRule, ArrayRef tensorSizes, const SymbolTable& symbolTable, const Mesh& defaultMesh) { @@ -1009,43 +1007,24 @@ bool differentOperandShardingFromFirstResult(Operation* op) { }); } -void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter, - const SymbolTable& symbolTable, - OpShardingRuleAttr shardingRule, - const bool onFullVersion) { +void insertExplicitReshardsOnOp( + Operation* op, ArrayRef inShardings, + ArrayRef outShardings, IRRewriter& rewriter, + const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule, + const bool onFullVersion, const Mesh& defaultMesh) { if (!onFullVersion) { return; } - SmallVector tensorSizes = getTensorSizes(op); - SmallVector inShardings = getShardings(op->getOperands()); - SmallVector outShardings = getShardings(op->getResults()); - std::optional meshName = getCommonMeshName( - inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true); - if 
(!meshName.has_value()) { - // This means none of the operands or results have a sharding attribute or - // the sharding attributes use different meshes. Skip if so. - // TODO(enver): Actually, we are moving towards supporting multiple explicit - // reshards so operands and results are all bound by the same mesh. - return; - } - - Mesh defaultMesh(getMeshAttr(symbolTable, *meshName), *meshName); - assert(defaultMesh.attr() && "unknown mesh"); - // TODO(enver): Support maximal meshes. - if (defaultMesh.attr().isMaximal()) { - return; - } - ShardingProjection shardingProjection = ShardingProjection::build( inShardings, outShardings, shardingRule, defaultMesh.attr(), /*closedIfMissing=*/true); UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(), shardingRule.getNumResults()); - AxesPerFactorWithMesh commonAxesPerFactorWithMesh = - findCommonAxes(inShardings, outShardings, shardingProjection, - shardingRule, tensorSizes, symbolTable, defaultMesh); + AxesPerFactorWithMesh commonAxesPerFactorWithMesh = findCommonAxes( + inShardings, outShardings, shardingProjection, shardingRule, + getTensorSizes(op), symbolTable, defaultMesh); if (commonAxesPerFactorWithMesh.empty()) { return; } diff --git a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.h b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.h index 00d33c7c..d1145713 100644 --- a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.h +++ b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.h @@ -91,10 +91,13 @@ bool differentOperandShardingFromFirstResult(Operation* op); // sharding of `op` is compatible with its sharding rule. // // Refer to the documentation of `InsertExplicitReshardsPass` for more details. 
-void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter, +void insertExplicitReshardsOnOp(Operation* op, + ArrayRef inShardings, + ArrayRef outShardings, + IRRewriter& rewriter, const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule, - bool onFullVersion); + bool onFullVersion, const Mesh& defaultMesh); } // namespace sdy } // namespace mlir diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc index 7cc314d2..16547c77 100644 --- a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc +++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc @@ -153,29 +153,17 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op, // return %reshard : tensor<4x8xf32> // ``` template -void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable, - OpShardingRuleAttr shardingRule) { - SmallVector inShardingAttrs = - getShardings(op.getOperands()); - ArrayRef outShardingAttrs = - getShardings(op.getOperation()); - if (outShardingAttrs.empty()) { +void processDot(OpTy op, ArrayRef inShardings, + ArrayRef outShardings, IRRewriter& rewriter, + const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule, + const Mesh& mesh) { + if (outShardings.empty()) { // Result doesn't have a sharding. return; } - std::optional meshName = - getCommonMeshName(inShardingAttrs, outShardingAttrs, symbolTable, - /*ignoreDeviceIds=*/false); - if (!meshName.has_value()) { - // This means none of the operands or results have a sharding attribute - // or the sharding attributes use different meshes. Skip if so. 
- return; - } - MeshAttr mesh = getMeshAttr(symbolTable, meshName.value()); - assert(mesh && "unknown mesh"); ShardingProjection shardingProjection = - ShardingProjection::build(inShardingAttrs, outShardingAttrs, shardingRule, - mesh, /*closedIfMissing=*/true); + ShardingProjection::build(inShardings, outShardings, shardingRule, + mesh.attr(), /*closedIfMissing=*/true); const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0); const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1); @@ -263,13 +251,34 @@ void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable, setSharding(op.getResult(), resultSharding.createTensorShardingAttr( op.getContext(), shardingRule.getResultMapping(0), - shardingRule.getFactorSizes(), meshName.value(), mesh)); + shardingRule.getFactorSizes(), mesh.name(), mesh.attr())); rewriter.setInsertionPointAfter(op); auto reshardOp = rewriter.create(op.getLoc(), op.getResult(), - outShardingAttrs.front()); + outShardings.front()); rewriter.replaceAllUsesExcept(op.getResult(), reshardOp, reshardOp); } +std::optional getMesh(ArrayRef inShardings, + ArrayRef outShardings, + const SymbolTable& symbolTable) { + std::optional meshName = getCommonMeshName( + inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true); + if (!meshName.has_value()) { + // This means none of the operands or results have a sharding attribute or + // the sharding attributes use different meshes. + // TODO(enver): Actually, we are moving towards supporting multiple + // explicit reshards so operands and results are all bound by the same + // mesh. 
+ return std::nullopt; + } + MeshAttr meshAttr = getMeshAttr(symbolTable, *meshName); + assert(meshAttr && "unknown mesh"); + if (meshAttr.isMaximal()) { + return std::nullopt; + } + return Mesh(meshAttr, *meshName); +} + struct InsertExplicitReshardsPass : public impl::InsertExplicitReshardsPassBase { using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase; @@ -338,20 +347,34 @@ struct InsertExplicitReshardsPass return; } + SmallVector inShardings = + getShardings(op->getOperands()); + SmallVector outShardings = + getShardings(op->getResults()); + + std::optional mesh = + getMesh(inShardings, outShardings, symbolTable); + if (!mesh.has_value()) { + return; + } + if (!onFullVersion) { TypeSwitch(op) .Case([&](stablehlo::DotOp dotOp) { - processDot(dotOp, rewriter, symbolTable, shardingRule); + processDot(dotOp, inShardings, outShardings, rewriter, + symbolTable, shardingRule, *mesh); }) .Case( [&](stablehlo::DotGeneralOp dotGeneralOp) { - processDot(dotGeneralOp, rewriter, symbolTable, shardingRule); + processDot(dotGeneralOp, inShardings, outShardings, rewriter, + symbolTable, shardingRule, *mesh); }); return; } - insertExplicitReshardsOnOp(op, rewriter, symbolTable, shardingRule, - onFullVersion); + insertExplicitReshardsOnOp(op, inShardings, outShardings, rewriter, + symbolTable, shardingRule, onFullVersion, + *mesh); // TODO(enver): Remove sharding rules from ops. });