@@ -864,7 +864,7 @@ def __init__(
864864 raise ValueError (
865865 "ManagedDeviceMesh doesn't support both mesh and parent are None."
866866 )
867- self .mesh = mesh
867+ self ._mesh = mesh
868868 self .mesh_dim_names = mesh_dim_names
869869 self .replicate_pg = replicate_pg
870870 self .replicate_dim = replicate_dim
@@ -893,17 +893,17 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
893893 elif mesh_dim_names in self .flatten_meshes :
894894 return self .flatten_meshes [mesh_dim_names ]
895895 else :
896- assert self .mesh is not None
897- return self .mesh [mesh_dim_names ]
896+ assert self ._mesh is not None
897+ return self ._mesh [mesh_dim_names ]
898898 else :
899899 assert isinstance (mesh_dim_names , tuple )
900900 if self .replicate_dim_name in mesh_dim_names :
901- assert self .mesh is not None
902- return self .mesh [mesh_dim_names ]
901+ assert self ._mesh is not None
902+ return self ._mesh [mesh_dim_names ]
903903 else :
904- assert self .mesh is not None
904+ assert self ._mesh is not None
905905 return ManagedDeviceMesh (
906- self .mesh [mesh_dim_names ],
906+ self ._mesh [mesh_dim_names ],
907907 mesh_dim_names ,
908908 self .replicate_pg ,
909909 mesh_dim_names .index (self .replicate_dim_name ),
@@ -924,8 +924,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
924924 elif dim == self .replicate_dim :
925925 return self .replicate_pg
926926 else :
927- assert self .mesh is not None
928- return self .mesh .get_group (self ._real_mesh_dim (dim ))
927+ assert self ._mesh is not None
928+ return self ._mesh .get_group (self ._real_mesh_dim (dim ))
929929
930930 def _flatten (self , mesh_dim_name : Optional [str ]) -> "DeviceMesh" :
931931 flatten_mesh = _FlattenDeviceMesh (self )
@@ -939,32 +939,32 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
939939
940940 def size (self , mesh_dim : Optional [int ] = None ) -> int :
941941 if mesh_dim is None :
942- if self .mesh is None :
942+ if self ._mesh is None :
943943 return self .replicate_pg .size ()
944944 else :
945- assert self .mesh is not None
946- return self .mesh .size () * self .replicate_pg .size ()
945+ assert self ._mesh is not None
946+ return self ._mesh .size () * self .replicate_pg .size ()
947947 elif mesh_dim == self .replicate_dim :
948948 return self .replicate_pg .size ()
949949 else :
950- assert self .mesh is not None
951- return self .mesh .size (self ._real_mesh_dim (mesh_dim ))
950+ assert self ._mesh is not None
951+ return self ._mesh .size (self ._real_mesh_dim (mesh_dim ))
952952
953953 @property
954954 def ndim (self ) -> int :
955- assert self .mesh is not None
956- return self .mesh .ndim + 1
955+ assert self ._mesh is not None
956+ return self ._mesh .ndim + 1
957957
958958 @property
959959 def shape (self ) -> Tuple [int , ...]:
960- assert self .mesh is not None
961- ret : List [int ] = list (self .mesh .shape )
960+ assert self ._mesh is not None
961+ ret : List [int ] = list (self ._mesh .shape )
962962 ret .insert (self .replicate_dim , self .replicate_pg .size ())
963963 return tuple (ret )
964964
965965 def get_rank (self ) -> int :
966- assert self .mesh is not None
967- return self .mesh .get_rank ()
966+ assert self ._mesh is not None
967+ return self ._mesh .get_rank ()
968968
969969 def get_local_rank (self , mesh_dim : Optional [Union [int , str ]] = None ) -> int :
970970 if isinstance (mesh_dim , str ):
@@ -973,33 +973,37 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
973973 dim = 0 if mesh_dim is None else int (mesh_dim )
974974
975975 if mesh_dim is None :
976- if self .mesh is None :
976+ if self ._mesh is None :
977977 return get_rank (self .replicate_pg )
978978
979979 assert self .replicate_dim == 0 , "replicate_dim must be the first one"
980- assert self .mesh is not None
981- other_dim_size = self .mesh .size ()
982- assert self .mesh is not None
983- other_dim_rank = self .mesh .get_local_rank ()
980+ assert self ._mesh is not None
981+ other_dim_size = self ._mesh .size ()
982+ assert self ._mesh is not None
983+ other_dim_rank = self ._mesh .get_local_rank ()
984984 replicate_pg_rank = get_rank (self .replicate_pg )
985985 return other_dim_size * replicate_pg_rank + other_dim_rank
986986 elif dim == self .replicate_dim :
987987 return get_rank (self .replicate_pg )
988988 else :
989- assert self .mesh is not None
990- return self .mesh .get_local_rank (self ._real_mesh_dim (dim ))
989+ assert self ._mesh is not None
990+ return self ._mesh .get_local_rank (self ._real_mesh_dim (dim ))
991991
992992 def get_coordinate (self ) -> Optional [List [int ]]:
993993 """
994994 Return the relative indices of this rank relative to all
995995 dimensions of the mesh. If this rank is not part of the mesh, return None.
996996 """
997- assert self .mesh is not None
998- return self .mesh ._coordinate_on_dim if self .mesh ._coordinate_on_dim else None
997+ assert self ._mesh is not None
998+ return self ._mesh ._coordinate_on_dim if self ._mesh ._coordinate_on_dim else None
999999
10001000 def get_all_groups (self ) -> List [BaseProcessGroup ]:
10011001 raise NotImplementedError
10021002
1003+ @property
1004+ def mesh (self ):
1005+ return self ._mesh .mesh
1006+
10031007
10041008class _FlattenDeviceMesh (DeviceMesh ):
10051009 def __init__ (self , managed_mesh : ManagedDeviceMesh ) -> None :
0 commit comments