@@ -197,8 +197,8 @@ bool shouldReshardToCommonMesh(TensorShardingAttr sharding, const Mesh& mesh,
197
197
// Assumes all tensor shardings have the same mesh as `mesh` on axes but may be
198
198
// different on device order.
199
199
void insertExplicitReshards (Operation* op,
200
- const SmallVector <TensorShardingAttr>& inShardings,
201
- const SmallVector <TensorShardingAttr>& outShardings,
200
+ ArrayRef <TensorShardingAttr> inShardings,
201
+ ArrayRef <TensorShardingAttr> outShardings,
202
202
const ShardingProjection& shardingProjection,
203
203
UpdateTensorShardings updateTensorShardings,
204
204
IRRewriter& rewriter,
@@ -739,10 +739,8 @@ std::optional<int64_t> findTensorIndexToPreferOnUnaryOperation(
739
739
}
740
740
741
741
// Returns the sharding stored for the tensor at position `tensorIndex`.
//
// Indices below `numOperands` select from `inShardings` directly. Any other
// index selects from `outShardings` at position `tensorIndex - numOperands`.
TensorShardingAttr getShardingOfTensorIndex(
    const int64_t tensorIndex, ArrayRef<TensorShardingAttr> inShardings,
    ArrayRef<TensorShardingAttr> outShardings, const int64_t numOperands) {
  if (tensorIndex < numOperands) {
    return inShardings[tensorIndex];
  }
  // Rebase the index so it is relative to the start of `outShardings`.
  return outShardings[tensorIndex - numOperands];
}
@@ -763,8 +761,8 @@ Mesh getMeshOrDefault(TensorShardingAttr sharding,
763
761
// 2. Both tensors have the same mesh but may have different device orders.
764
762
// 3. The factor shardings are not compatible.
765
763
AxesPerFactorWithMesh findCommonAxesOnUnaryOperation (
766
- const SmallVector <TensorShardingAttr>& inShardings,
767
- const SmallVector <TensorShardingAttr>& outShardings,
764
+ ArrayRef <TensorShardingAttr> inShardings,
765
+ ArrayRef <TensorShardingAttr> outShardings,
768
766
const ShardingProjection& shardingProjection,
769
767
OpShardingRuleAttr shardingRule, ArrayRef<int64_t > tensorSizes,
770
768
const SymbolTable& symbolTable, const Mesh& mesh) {
@@ -858,8 +856,8 @@ void distributeAxisRefsToBatchingFactors(
858
856
}
859
857
}
860
858
861
- Mesh getMostCommonMesh (const SmallVector <TensorShardingAttr>& inShardings,
862
- const SmallVector <TensorShardingAttr>& outShardings,
859
+ Mesh getMostCommonMesh (ArrayRef <TensorShardingAttr> inShardings,
860
+ ArrayRef <TensorShardingAttr> outShardings,
863
861
OpShardingRuleAttr shardingRule,
864
862
const SymbolTable& symbolTable,
865
863
const Mesh& defaultMesh) {
@@ -882,8 +880,8 @@ Mesh getMostCommonMesh(const SmallVector<TensorShardingAttr>& inShardings,
882
880
}
883
881
884
882
AxesPerFactorWithMesh findCommonAxes (
885
- const SmallVector <TensorShardingAttr>& inShardings,
886
- const SmallVector <TensorShardingAttr>& outShardings,
883
+ ArrayRef <TensorShardingAttr> inShardings,
884
+ ArrayRef <TensorShardingAttr> outShardings,
887
885
const ShardingProjection& shardingProjection,
888
886
OpShardingRuleAttr shardingRule, ArrayRef<int64_t > tensorSizes,
889
887
const SymbolTable& symbolTable, const Mesh& defaultMesh) {
@@ -1009,43 +1007,24 @@ bool differentOperandShardingFromFirstResult(Operation* op) {
1009
1007
});
1010
1008
}
1011
1009
1012
- void insertExplicitReshardsOnOp (Operation* op, IRRewriter& rewriter,
1013
- const SymbolTable& symbolTable,
1014
- OpShardingRuleAttr shardingRule,
1015
- const bool onFullVersion) {
1010
+ void insertExplicitReshardsOnOp (
1011
+ Operation* op, ArrayRef<TensorShardingAttr> inShardings,
1012
+ ArrayRef<TensorShardingAttr> outShardings, IRRewriter& rewriter,
1013
+ const SymbolTable& symbolTable, OpShardingRuleAttr shardingRule,
1014
+ const bool onFullVersion, const Mesh& defaultMesh) {
1016
1015
if (!onFullVersion) {
1017
1016
return ;
1018
1017
}
1019
1018
1020
- SmallVector<int64_t > tensorSizes = getTensorSizes (op);
1021
- SmallVector<TensorShardingAttr> inShardings = getShardings (op->getOperands ());
1022
- SmallVector<TensorShardingAttr> outShardings = getShardings (op->getResults ());
1023
- std::optional<StringRef> meshName = getCommonMeshName (
1024
- inShardings, outShardings, symbolTable, /* ignoreDeviceIds=*/ true );
1025
- if (!meshName.has_value ()) {
1026
- // This means none of the operands or results have a sharding attribute or
1027
- // the sharding attributes use different meshes. Skip if so.
1028
- // TODO(enver): Actually, we are moving towards supporting multiple explicit
1029
- // reshards so operands and results are all bound by the same mesh.
1030
- return ;
1031
- }
1032
-
1033
- Mesh defaultMesh (getMeshAttr (symbolTable, *meshName), *meshName);
1034
- assert (defaultMesh.attr () && " unknown mesh" );
1035
- // TODO(enver): Support maximal meshes.
1036
- if (defaultMesh.attr ().isMaximal ()) {
1037
- return ;
1038
- }
1039
-
1040
1019
ShardingProjection shardingProjection = ShardingProjection::build (
1041
1020
inShardings, outShardings, shardingRule, defaultMesh.attr (),
1042
1021
/* closedIfMissing=*/ true );
1043
1022
1044
1023
UpdateTensorShardings updateTensorShardings (shardingRule.getNumOperands (),
1045
1024
shardingRule.getNumResults ());
1046
- AxesPerFactorWithMesh commonAxesPerFactorWithMesh =
1047
- findCommonAxes ( inShardings, outShardings, shardingProjection,
1048
- shardingRule, tensorSizes , symbolTable, defaultMesh);
1025
+ AxesPerFactorWithMesh commonAxesPerFactorWithMesh = findCommonAxes (
1026
+ inShardings, outShardings, shardingProjection, shardingRule ,
1027
+ getTensorSizes (op) , symbolTable, defaultMesh);
1049
1028
if (commonAxesPerFactorWithMesh.empty ()) {
1050
1029
return ;
1051
1030
}
0 commit comments