Skip to content

Commit fbf0af7

Browse files
Fixing batch_dim_name attribute (#20674)
* fix the trainer's incorrect assumption that the batch dimension is always the first one in the mesh * need functools.partial * lint * fix test failure when distribution=None * lint2 * fix for test failure * added data sharding for 3D+ meshes * lint3 * added @property for batch_dim_name + refactoring * fix typo
1 parent ab3c8f5 commit fbf0af7

File tree

5 files changed

+32
-18
lines changed

5 files changed

+32
-18
lines changed

keras/src/backend/jax/distribution_lib.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
100100
return global_value
101101

102102

103-
def distribute_data_input(per_process_batch, layout):
103+
def distribute_data_input(per_process_batch, layout, batch_dim_name):
104104
"""Distribute the input data with the corresponding layout.
105105
106106
Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
117117
if not isinstance(layout, jax.sharding.Sharding):
118118
layout = _to_jax_layout(layout)
119119

120-
mesh_shape = list(layout.mesh.shape.values())
121-
num_model_replicas_total = mesh_shape[0] # batch dimension of the mesh
122-
mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
120+
num_model_replicas_total = layout.mesh.shape[batch_dim_name]
121+
122+
mesh_model_dim_size = 1
123+
for name, dim_size in layout.mesh.shape.items():
124+
if not name == batch_dim_name:
125+
mesh_model_dim_size *= dim_size
126+
123127
num_model_replicas_per_process = num_model_replicas_total / num_processes()
124128
per_process_batch_size = per_process_batch.shape[0]
125129

keras/src/backend/jax/distribution_lib_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
337337
mesh, jax.sharding.PartitionSpec("batch", None)
338338
)
339339

340-
result = backend_dlib.distribute_data_input(per_process_batch, layout)
340+
result = backend_dlib.distribute_data_input(
341+
per_process_batch, layout, "batch"
342+
)
341343

342344
# Check the shape of the global batch array
343345
self.assertEqual(

keras/src/backend/jax/trainer.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import collections
22
import itertools
3+
from functools import partial
34

45
import jax
56
import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(
988989

989990
def _distribute_data(data, layouts=None):
990991
distribution = distribution_lib.distribution()
992+
991993
if distribution is not None:
992994
if layouts is None:
993995
layouts = tree.map_structure(
994996
lambda d: distribution.get_data_layout(d.shape),
995997
data,
996998
)
997-
return tree.map_structure(
998-
jax_distribution_lib.distribute_data_input, data, layouts
999+
jax_dist_data_input = partial(
1000+
jax_distribution_lib.distribute_data_input,
1001+
batch_dim_name=distribution.batch_dim_name,
9991002
)
1003+
return tree.map_structure(jax_dist_data_input, data, layouts)
10001004

10011005
return tree.map_structure(jax.device_put, data)
10021006

keras/src/distribution/distribution_lib.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,9 @@ class Distribution:
287287
device_mesh: A `DeviceMesh` instance.
288288
"""
289289

290-
def __init__(self, device_mesh):
290+
def __init__(self, device_mesh, batch_dim_name=None):
291291
self._device_mesh = device_mesh
292+
self._batch_dim_name = batch_dim_name
292293

293294
def get_data_layout(self, data_shape):
294295
"""Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
341342
def device_mesh(self):
342343
return self._device_mesh
343344

345+
@property
346+
def batch_dim_name(self):
347+
return self._batch_dim_name
348+
344349
def distribute_dataset(self, dataset):
345350
"""Create a distributed dataset instance from the original user dataset.
346351
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
395400
else:
396401
self._initialize_mesh_from_list_devices()
397402

398-
self._batch_dim_name = self.device_mesh.axis_names[0]
399403
# Those following attributes might get convert to public methods.
400404
self._num_process = distribution_lib.num_processes()
401405
self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
408412
"Expect `mesh` to be an instance of `DeviceMesh`. "
409413
f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
410414
)
411-
super().__init__(device_mesh)
415+
super().__init__(device_mesh, device_mesh.axis_names[0])
412416
if self.device_mesh.devices.ndim != 1:
413417
warnings.warn(
414418
"Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
424428
axis_names=[DEFAULT_BATCH_DIM_NAME],
425429
devices=devices,
426430
)
427-
super().__init__(device_mesh)
431+
super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
428432

429433
def _initialize_mesh_from_list_devices(self):
430434
devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
433437
axis_names=[DEFAULT_BATCH_DIM_NAME],
434438
devices=devices,
435439
)
436-
super().__init__(device_mesh)
440+
super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)
437441

438442
def get_data_layout(self, data_shape):
439443
data_shard_spec = [None] * len(data_shape)
440-
data_shard_spec[0] = self._batch_dim_name # Shard on the first dim
444+
data_shard_spec[0] = self.batch_dim_name # Shard on the first dim
441445
return TensorLayout(data_shard_spec, self.device_mesh)
442446

443447
def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):
590594

591595
def get_data_layout(self, data_shape):
592596
data_shard_spec = [None] * len(data_shape)
593-
data_shard_spec[0] = self._batch_dim_name # Shard on the first dim
597+
data_shard_spec[0] = self.batch_dim_name # Shard on the first dim
594598
return TensorLayout(data_shard_spec, self.device_mesh)
595599

596600
def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
631635
# Note that this might be smaller than one if model replicas are sharded
632636
# across multiple processes.
633637
mesh_batch_dim_index = self.device_mesh.axis_names.index(
634-
self._batch_dim_name
638+
self.batch_dim_name
635639
)
636640
num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
637641
if num_model_replicas == 1:

keras/src/distribution/distribution_lib_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
186186
device_mesh = distribution.device_mesh
187187
self.assertEqual(len(device_mesh.devices), 8)
188188
self.assertEqual(device_mesh.axis_names, ["data"])
189-
self.assertEqual(distribution._batch_dim_name, "data")
189+
self.assertEqual(distribution.batch_dim_name, "data")
190190

191191
self.assertFalse(distribution._is_multi_process)
192192
self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
197197
device_mesh = distribution.device_mesh
198198
self.assertEqual(len(device_mesh.devices), 8)
199199
self.assertEqual(device_mesh.axis_names, ["batch"])
200-
self.assertEqual(distribution._batch_dim_name, "batch")
200+
self.assertEqual(distribution.batch_dim_name, "batch")
201201

202202
@mock.patch.object(
203203
distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
211211
device_mesh = distribution.device_mesh
212212
self.assertEqual(len(device_mesh.devices), 8)
213213
self.assertEqual(device_mesh.axis_names, ["batch"])
214-
self.assertEqual(distribution._batch_dim_name, "batch")
214+
self.assertEqual(distribution.batch_dim_name, "batch")
215215

216216
def test_get_data_layout(self):
217217
distribution = distribution_lib.DataParallel(

0 commit comments

Comments (0)