Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 18 additions & 39 deletions shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
// different on device order.
void insertExplicitReshards(Operation* op,
const SmallVector<TensorShardingAttr>& inShardings,
const SmallVector<TensorShardingAttr>& outShardings,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
const ShardingProjection& shardingProjection,
UpdateTensorShardings updateTensorShardings,
IRRewriter& rewriter,
Expand Down Expand Up @@ -739,10 +739,8 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
}

TensorShardingAttr getShardingOfTensorIndex(
const int64_t tensorIndex,
const SmallVector<TensorShardingAttr>& inShardings,
const SmallVector<TensorShardingAttr>& outShardings,
const int64_t numOperands) {
const int64_t tensorIndex, ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings, const int64_t numOperands) {
return tensorIndex < numOperands ? inShardings[tensorIndex]
: outShardings[tensorIndex - numOperands];
}
Expand All @@ -763,8 +761,8 @@ Mesh getMeshOrDefault(TensorShardingAttr sharding,
// 2. Both tensors have the same mesh but may have different device orders.
// 3. The factor shardings are not compatible.
AxesPerFactorWithMesh findCommonAxesOnUnaryOperation(
const SmallVector<TensorShardingAttr>& inShardings,
const SmallVector<TensorShardingAttr>& outShardings,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
const SymbolTable& symbolTable, const Mesh& mesh) {
Expand Down Expand Up @@ -858,8 +856,8 @@ void distributeAxisRefsToBatchingFactors(
}
}

Mesh getMostCommonMesh(const SmallVector<TensorShardingAttr>& inShardings,
const SmallVector<TensorShardingAttr>& outShardings,
Mesh getMostCommonMesh(ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
OpShardingRuleAttr shardingRule,
const SymbolTable& symbolTable,
const Mesh& defaultMesh) {
Expand All @@ -882,8 +880,8 @@ Mesh getMostCommonMesh(const SmallVector<TensorShardingAttr>& inShardings,
}

AxesPerFactorWithMesh findCommonAxes(
const SmallVector<TensorShardingAttr>& inShardings,
const SmallVector<TensorShardingAttr>& outShardings,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
const ShardingProjection& shardingProjection,
OpShardingRuleAttr shardingRule, ArrayRef<int64_t> tensorSizes,
const SymbolTable& symbolTable, const Mesh& defaultMesh) {
Expand Down Expand Up @@ -1009,43 +1007,24 @@ bool differentOperandShardingFromFirstResult(Operation* op) {
});
}

void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter,
const SymbolTable& symbolTable,
OpShardingRuleAttr shardingRule,
const bool onFullVersion) {
void insertExplicitReshardsOnOp(
Operation* op, ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
const bool onFullVersion, const Mesh& defaultMesh) {
if (!onFullVersion) {
return;
}

SmallVector<int64_t> tensorSizes = getTensorSizes(op);
SmallVector<TensorShardingAttr> inShardings = getShardings(op->getOperands());
SmallVector<TensorShardingAttr> outShardings = getShardings(op->getResults());
std::optional<StringRef> meshName = getCommonMeshName(
inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true);
if (!meshName.has_value()) {
// This means none of the operands or results have a sharding attribute or
// the sharding attributes use different meshes. Skip if so.
// TODO(enver): Actually, we are moving towards supporting multiple explicit
// reshards so operands and results are all bound by the same mesh.
return;
}

Mesh defaultMesh(getMeshAttr(symbolTable, *meshName), *meshName);
assert(defaultMesh.attr() && "unknown mesh");
// TODO(enver): Support maximal meshes.
if (defaultMesh.attr().isMaximal()) {
return;
}

ShardingProjection shardingProjection = ShardingProjection::build(
inShardings, outShardings, shardingRule, defaultMesh.attr(),
/*closedIfMissing=*/true);

UpdateTensorShardings updateTensorShardings(shardingRule.getNumOperands(),
shardingRule.getNumResults());
AxesPerFactorWithMesh commonAxesPerFactorWithMesh =
findCommonAxes(inShardings, outShardings, shardingProjection,
shardingRule, tensorSizes, symbolTable, defaultMesh);
AxesPerFactorWithMesh commonAxesPerFactorWithMesh = findCommonAxes(
inShardings, outShardings, shardingProjection, shardingRule,
getTensorSizes(op), symbolTable, defaultMesh);
if (commonAxesPerFactorWithMesh.empty()) {
return;
}
Expand Down
7 changes: 5 additions & 2 deletions shardy/dialect/sdy/transforms/export/explicit_reshards_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,13 @@ bool differentOperandShardingFromFirstResult(Operation* op);
// sharding of `op` is compatible with its sharding rule.
//
// Refer to the documentation of `InsertExplicitReshardsPass` for more details.
void insertExplicitReshardsOnOp(Operation* op, IRRewriter& rewriter,
void insertExplicitReshardsOnOp(Operation* op,
ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
IRRewriter& rewriter,
const SymbolTable& symbolTable,
OpShardingRuleAttr shardingRule,
bool onFullVersion);
bool onFullVersion, const Mesh& defaultMesh);

} // namespace sdy
} // namespace mlir
Expand Down
73 changes: 48 additions & 25 deletions shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,29 +153,17 @@ void insertExplicitReshardsOnDataFlowOp(ShardableDataFlowOpInterface& op,
// return %reshard : tensor<4x8xf32>
// ```
template <class OpTy>
void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable,
OpShardingRuleAttr shardingRule) {
SmallVector<TensorShardingAttr> inShardingAttrs =
getShardings(op.getOperands());
ArrayRef<TensorShardingAttr> outShardingAttrs =
getShardings(op.getOperation());
if (outShardingAttrs.empty()) {
void processDot(OpTy op, ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
const Mesh& mesh) {
if (outShardings.empty()) {
// Result doesn't have a sharding.
return;
}
std::optional<StringRef> meshName =
getCommonMeshName(inShardingAttrs, outShardingAttrs, symbolTable,
/*ignoreDeviceIds=*/false);
if (!meshName.has_value()) {
// This means none of the operands or results have a sharding attribute
// or the sharding attributes use different meshes. Skip if so.
return;
}
MeshAttr mesh = getMeshAttr(symbolTable, meshName.value());
assert(mesh && "unknown mesh");
ShardingProjection shardingProjection =
ShardingProjection::build(inShardingAttrs, outShardingAttrs, shardingRule,
mesh, /*closedIfMissing=*/true);
ShardingProjection::build(inShardings, outShardings, shardingRule,
mesh.attr(), /*closedIfMissing=*/true);

const TensorFactorShardings& lhsSharding = shardingProjection.getOperand(0);
const TensorFactorShardings& rhsSharding = shardingProjection.getOperand(1);
Expand Down Expand Up @@ -263,13 +251,34 @@ void processDot(OpTy op, IRRewriter& rewriter, const SymbolTable& symbolTable,
setSharding(op.getResult(),
resultSharding.createTensorShardingAttr(
op.getContext(), shardingRule.getResultMapping(0),
shardingRule.getFactorSizes(), meshName.value(), mesh));
shardingRule.getFactorSizes(), mesh.name(), mesh.attr()));
rewriter.setInsertionPointAfter(op);
auto reshardOp = rewriter.create<ReshardOp>(op.getLoc(), op.getResult(),
outShardingAttrs.front());
outShardings.front());
rewriter.replaceAllUsesExcept(op.getResult(), reshardOp, reshardOp);
}

std::optional<Mesh> getMesh(ArrayRef<TensorShardingAttr> inShardings,
ArrayRef<TensorShardingAttr> outShardings,
const SymbolTable& symbolTable) {
std::optional<StringRef> meshName = getCommonMeshName(
inShardings, outShardings, symbolTable, /*ignoreDeviceIds=*/true);
if (!meshName.has_value()) {
// This means none of the operands or results have a sharding attribute or
// the sharding attributes use different meshes.
// TODO(enver): Actually, we are moving towards supporting multiple
// explicit reshards so operands and results are all bound by the same
// mesh.
return std::nullopt;
}
MeshAttr meshAttr = getMeshAttr(symbolTable, *meshName);
assert(meshAttr && "unknown mesh");
if (meshAttr.isMaximal()) {
return std::nullopt;
}
return Mesh(meshAttr, *meshName);
}

struct InsertExplicitReshardsPass
: public impl::InsertExplicitReshardsPassBase<InsertExplicitReshardsPass> {
using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase;
Expand Down Expand Up @@ -338,20 +347,34 @@ struct InsertExplicitReshardsPass
return;
}

SmallVector<TensorShardingAttr> inShardings =
getShardings(op->getOperands());
SmallVector<TensorShardingAttr> outShardings =
getShardings(op->getResults());

std::optional<Mesh> mesh =
getMesh(inShardings, outShardings, symbolTable);
if (!mesh.has_value()) {
return;
}

if (!onFullVersion) {
TypeSwitch<Operation*>(op)
.Case<stablehlo::DotOp>([&](stablehlo::DotOp dotOp) {
processDot(dotOp, rewriter, symbolTable, shardingRule);
processDot(dotOp, inShardings, outShardings, rewriter,
symbolTable, shardingRule, *mesh);
})
.Case<stablehlo::DotGeneralOp>(
[&](stablehlo::DotGeneralOp dotGeneralOp) {
processDot(dotGeneralOp, rewriter, symbolTable, shardingRule);
processDot(dotGeneralOp, inShardings, outShardings, rewriter,
symbolTable, shardingRule, *mesh);
});
return;
}

insertExplicitReshardsOnOp(op, rewriter, symbolTable, shardingRule,
onFullVersion);
insertExplicitReshardsOnOp(op, inShardings, outShardings, rewriter,
symbolTable, shardingRule, onFullVersion,
*mesh);

// TODO(enver): Remove sharding rules from ops.
});
Expand Down
Loading