Skip to content

Commit 984aecf

Browse files
Refactor to simplify mesh extraction in getMeshOrDefault.
It is supposed to be pure. PiperOrigin-RevId: 800907631
1 parent 08d0de3 commit 984aecf

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -751,10 +751,9 @@ Mesh getMeshOrDefault(TensorShardingAttr sharding,
751751
if (!sharding) {
752752
return defaultMesh;
753753
}
754-
StringRef meshName =
755-
getCommonMeshName({sharding}, {}, symbolTable, /*ignoreDeviceIds=*/false)
756-
.value();
757-
return Mesh(getMeshAttr(symbolTable, meshName), meshName);
754+
// NOTE: sharding always has a meshOrRef because it is a required parameter.
755+
return Mesh(sharding.getMesh(symbolTable),
756+
cast<FlatSymbolRefAttr>(sharding.getMeshOrRef()).getValue());
758757
}
759758

760759
// Assumes that:

0 commit comments

Comments
 (0)