Skip to content

Commit 89f7cc9

Browse files
Refactor to create the mesh only once and pass it to the methods of the shardy partitioner.
This is not a pure refactoring: the default/minimal version now also supports the case where the meshes are the same but have a different device order. PiperOrigin-RevId: 800425531
1 parent 8ca89c5 commit 89f7cc9

File tree

3 files changed

+71
-66
lines changed

3 files changed

+71
-66
lines changed

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 18 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
197197
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
198198
// different on device order.
199199
void insertExplicitReshards(Operation* op,
200-
const SmallVector<TensorShardingAttr>& inShardings,
201-
const SmallVector<TensorShardingAttr>& outShardings,
200+
ArrayRef<TensorShardingAttr> inShardings,
201+
ArrayRef<TensorShardingAttr> outShardings,
202202
const ShardingProjection& shardingProjection,
203203
UpdateTensorShardings updateTensorShardings,
204204
IRRewriter& rewriter,
@@ -739,10 +739,8 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
739739
}
740740

741741
TensorShardingAttr getShardingOfTensorIndex(
742-
const int64_t tensorIndex,
743-
const SmallVector<TensorShardingAttr>& inShardings,
744-
const SmallVector<TensorShardingAttr>& outShardings,
745-
const int64_t numOperands) {
742+
const int64_t tensorIndex, ArrayRef<TensorShardingAttr> inShardings,
743+
ArrayRef<TensorShardingAttr> outShardings, const int64_t numOperands) {
746744
return tensorIndex < numOperands ? inShardings[tensorIndex]
747745
: outShardings[tensorIndex - numOperands];
748746
}
@@ -763,8 +761,8 @@ Mesh getMeshOrDefault(TensorShardingAttr sharding,
763761
// 2. Both tensors have the same mesh but may have different device orders.
764762
// 3. The factor shardings are not compatible.
765763
AxesPerFactorWithMesh findCommonAxesOnUnaryOperation(
766-
const SmallVector<TensorShardingAttr>& inShardings,
767-
const SmallVector<TensorShardingAttr>& outShardings,
764+
ArrayRef<TensorShardingAttr> inShardings,
765+
ArrayRef<TensorShardingAttr> outShardings,
768766
const ShardingProjection& shardingProjection,
769767
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
770768
const SymbolTable& symbolTable, const Mesh& mesh) {
@@ -858,8 +856,8 @@ void distributeAxisRefsToBatchingFactors(
858856
}
859857
}
860858

861-
Mesh getMostCommonMesh(const SmallVector<TensorShardingAttr>& inShardings,
862-
const SmallVector<TensorShardingAttr>& outShardings,
859+
Mesh getMostCommonMesh(ArrayRef<TensorShardingAttr> inShardings,
860+
ArrayRef<TensorShardingAttr> outShardings,
863861
OpShardingRuleAttr shardingRule,
864862
const SymbolTable& symbolTable,
865863
const Mesh& defaultMesh) {
@@ -882,8 +880,8 @@ Mesh getMostCommonMesh(const SmallVector<TensorShardingAttr>& inShardings,
882880
}
883881

884882
AxesPerFactorWithMesh findCommonAxes(
885-
const SmallVector<TensorShardingAttr>& inShardings,
886-
const SmallVector<TensorShardingAttr>& outShardings,
883+
ArrayRef<TensorShardingAttr> inShardings,
884+
ArrayRef<TensorShardingAttr> outShardings,
887885
const ShardingProjection& shardingProjection,
888886
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
889887
const SymbolTable& symbolTable, const Mesh& defaultMesh) {
@@ -1009,43 +1007,24 @@ bool differentOperandShardingFromFirstResult(Operation* op) {
10091007
});
10101008
}
10111009

1012-
void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter,
1013-
const SymbolTable& symbolTable,
1014-
OpShardingRuleAttr shardingRule,
1015-
const bool onFullVersion) {
1010+
void insertExplicitReshardsOnOp(
1011+
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
1012+
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
1013+
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
1014+
const bool onFullVersion, const Mesh& defaultMesh) {
10161015
if (!onFullVersion) {
10171016
return;
10181017
}
10191018

1020-
SmallVector<int64_t> tensorSizes = getTensorSizes(op);
1021-
SmallVector<TensorShardingAttr> inShardings = getShardings(op->getOperands());
1022-
SmallVector<TensorShardingAttr> outShardings = getShardings(op->getResults());
1023-
std::optional<StringRef> meshName = getCommonMeshName(
1024-
inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true);
1025-
if (!meshName.has_value()) {
1026-
// This means none of the operands or results have a sharding attribute or
1027-
// the sharding attributes use different meshes. Skip if so.
1028-
// TODO(enver): Actually, we are moving towards supporting multiple explicit
1029-
// reshards so operands and results are all bound by the same mesh.
1030-
return;
1031-
}
1032-
1033-
Mesh defaultMesh(getMeshAttr(symbolTable, *meshName), *meshName);
1034-
assert(defaultMesh.attr() && "unknown mesh");
1035-
// TODO(enver): Support maximal meshes.
1036-
if (defaultMesh.attr().isMaximal()) {
1037-
return;
1038-
}
1039-
10401019
ShardingProjection shardingProjection = ShardingProjection::build(
10411020
inShardings, outShardings, shardingRule, defaultMesh.attr(),
10421021
/*closedIfMissing=*/true);
10431022

10441023
UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
10451024
shardingRule.getNumResults());
1046-
AxesPerFactorWithMesh commonAxesPerFactorWithMesh =
1047-
findCommonAxes(inShardings, outShardings, shardingProjection,
1048-
shardingRule, tensorSizes, symbolTable, defaultMesh);
1025+
AxesPerFactorWithMesh commonAxesPerFactorWithMesh = findCommonAxes(
1026+
inShardings, outShardings, shardingProjection, shardingRule,
1027+
getTensorSizes(op), symbolTable, defaultMesh);
10491028
if (commonAxesPerFactorWithMesh.empty()) {
10501029
return;
10511030
}

shardy/dialect/sdy/transforms/export/explicit_reshards_util.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,13 @@ bool differentOperandShardingFromFirstResult(Operation* op);
9191
// sharding of `op` is compatible with its sharding rule.
9292
//
9393
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
94-
void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter,
94+
void insertExplicitReshardsOnOp(Operation* op,
95+
ArrayRef<TensorShardingAttr> inShardings,
96+
ArrayRef<TensorShardingAttr> outShardings,
97+
IRRewriter& rewriter,
9598
const SymbolTable& symbolTable,
9699
OpShardingRuleAttr shardingRule,
97-
bool onFullVersion);
100+
bool onFullVersion, const Mesh& defaultMesh);
98101

99102
} // namespace sdy
100103
} // namespace mlir

shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc

Lines changed: 48 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -153,29 +153,17 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
153153
// return %reshard : tensor<4x8xf32>
154154
// ```
155155
template <class OpTy>
156-
void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable,
157-
OpShardingRuleAttr shardingRule) {
158-
SmallVector<TensorShardingAttr> inShardingAttrs =
159-
getShardings(op.getOperands());
160-
ArrayRef<TensorShardingAttr> outShardingAttrs =
161-
getShardings(op.getOperation());
162-
if (outShardingAttrs.empty()) {
156+
void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
157+
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
158+
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
159+
const Mesh& mesh) {
160+
if (outShardings.empty()) {
163161
// Result doesn't have a sharding.
164162
return;
165163
}
166-
std::optional<StringRef> meshName =
167-
getCommonMeshName(inShardingAttrs, outShardingAttrs, symbolTable,
168-
/*ignoreDeviceIds=*/false);
169-
if (!meshName.has_value()) {
170-
// This means none of the operands or results have a sharding attribute
171-
// or the sharding attributes use different meshes. Skip if so.
172-
return;
173-
}
174-
MeshAttr mesh = getMeshAttr(symbolTable, meshName.value());
175-
assert(mesh && "unknown mesh");
176164
ShardingProjection shardingProjection =
177-
ShardingProjection::build(inShardingAttrs, outShardingAttrs, shardingRule,
178-
mesh, /*closedIfMissing=*/true);
165+
ShardingProjection::build(inShardings, outShardings, shardingRule,
166+
mesh.attr(), /*closedIfMissing=*/true);
179167

180168
const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
181169
const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
@@ -263,13 +251,34 @@ void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable,
263251
setSharding(op.getResult(),
264252
resultSharding.createTensorShardingAttr(
265253
op.getContext(), shardingRule.getResultMapping(0),
266-
shardingRule.getFactorSizes(), meshName.value(), mesh));
254+
shardingRule.getFactorSizes(), mesh.name(), mesh.attr()));
267255
rewriter.setInsertionPointAfter(op);
268256
auto reshardOp = rewriter.create<ReshardOp>(op.getLoc(), op.getResult(),
269-
outShardingAttrs.front());
257+
outShardings.front());
270258
rewriter.replaceAllUsesExcept(op.getResult(), reshardOp, reshardOp);
271259
}
272260

261+
std::optional<Mesh> getMesh(ArrayRef<TensorShardingAttr> inShardings,
262+
ArrayRef<TensorShardingAttr> outShardings,
263+
const SymbolTable& symbolTable) {
264+
std::optional<StringRef> meshName = getCommonMeshName(
265+
inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true);
266+
if (!meshName.has_value()) {
267+
// This means none of the operands or results have a sharding attribute or
268+
// the sharding attributes use different meshes.
269+
// TODO(enver): Actually, we are moving towards supporting multiple
270+
// explicit reshards so operands and results are all bound by the same
271+
// mesh.
272+
return std::nullopt;
273+
}
274+
MeshAttr meshAttr = getMeshAttr(symbolTable, *meshName);
275+
assert(meshAttr && "unknown mesh");
276+
if (meshAttr.isMaximal()) {
277+
return std::nullopt;
278+
}
279+
return Mesh(meshAttr, *meshName);
280+
}
281+
273282
struct InsertExplicitReshardsPass
274283
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
275284
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
@@ -338,20 +347,34 @@ struct InsertExplicitReshardsPass
338347
return;
339348
}
340349

350+
SmallVector<TensorShardingAttr> inShardings =
351+
getShardings(op->getOperands());
352+
SmallVector<TensorShardingAttr> outShardings =
353+
getShardings(op->getResults());
354+
355+
std::optional<Mesh> mesh =
356+
getMesh(inShardings, outShardings, symbolTable);
357+
if (!mesh.has_value()) {
358+
return;
359+
}
360+
341361
if (!onFullVersion) {
342362
TypeSwitch<Operation*>(op)
343363
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
344-
processDot(dotOp, rewriter, symbolTable, shardingRule);
364+
processDot(dotOp, inShardings, outShardings, rewriter,
365+
symbolTable, shardingRule, *mesh);
345366
})
346367
.Case<stablehlo::DotGeneralOp>(
347368
[&](stablehlo::DotGeneralOp dotGeneralOp) {
348-
processDot(dotGeneralOp, rewriter, symbolTable, shardingRule);
369+
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
370+
symbolTable, shardingRule, *mesh);
349371
});
350372
return;
351373
}
352374

353-
insertExplicitReshardsOnOp(op, rewriter, symbolTable, shardingRule,
354-
onFullVersion);
375+
insertExplicitReshardsOnOp(op, inShardings, outShardings, rewriter,
376+
symbolTable, shardingRule, onFullVersion,
377+
*mesh);
355378

356379
// TODO(enver): Remove sharding rules from ops.
357380
});

0 commit comments

Comments
 (0)