@@ -287,8 +287,9 @@ class Distribution:
         device_mesh: A `DeviceMesh` instance.
     """
 
-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
         self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name
 
     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
     def device_mesh(self):
         return self._device_mesh
 
+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
     def distribute_dataset(self, dataset):
         """Create a distributed dataset instance from the original user dataset.
 
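For context, a minimal sketch (hypothetical, not part of this diff) of how a `Distribution` subclass is expected to use the new constructor argument together with the read-only `batch_dim_name` property; it assumes `Distribution` and `TensorLayout` from the module patched above are in scope:

class MyDistribution(Distribution):
    def __init__(self, device_mesh):
        # Sketch assumption: treat the first mesh axis as the batch axis.
        super().__init__(device_mesh, batch_dim_name=device_mesh.axis_names[0])

    def get_data_layout(self, data_shape):
        # Shard only the leading (batch) dimension across the batch axis.
        data_shard_spec = [None] * len(data_shape)
        data_shard_spec[0] = self.batch_dim_name
        return TensorLayout(data_shard_spec, self.device_mesh)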
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         else:
             self._initialize_mesh_from_list_devices()
 
-        self._batch_dim_name = self.device_mesh.axis_names[0]
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def _initialize_mesh_from_list_devices(self):
         devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
 
     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)
 
     def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
         # Note that this might be smaller than one if model replicas are sharded
         # across multiple processes.
         mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
         )
         num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
         if num_model_replicas == 1:
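At the usage level, a short sketch (assuming the public `keras.distribution` API and the default `"batch"` axis name; the printed values are expectations, not verified output) of how the promoted property surfaces on a `DataParallel` instance:

from keras.distribution import DataParallel, list_devices

# Build a data-parallel distribution over all locally visible devices.
distribution = DataParallel(devices=list_devices())

# The batch axis name is now exposed by the base `Distribution` class.
print(distribution.batch_dim_name)  # expected: "batch" (DEFAULT_BATCH_DIM_NAME)

# Data layouts shard only the leading (batch) dimension over that axis.
layout = distribution.get_data_layout((32, 28, 28, 1))
print(layout.axes)  # expected: ("batch", None, None, None)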