Skip to content

Commit 2fbdc2c

Browse files
authored
DistributedEmbedding: do not hardcode the number of SparseCores with TF. (#147)
The number of SparseCore chips per TPU is now retrieved from the strategy. This makes the `distributed_embedding_tests.py` pass on V6e.
1 parent 46543dc commit 2fbdc2c

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ def test_correctness(
611611

612612
self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
613613

614-
tables = layer.get_embedding_tables()
614+
with self._strategy.scope():
615+
tables = layer.get_embedding_tables()
616+
615617
emb = tables["table"]
616618

617619
if input_type == "dense":

keras_rs/src/layers/embedding/tensorflow/distributed_embedding.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,18 @@ def _sparsecore_call(
281281
def _sparsecore_get_embedding_tables(self) -> dict[str, types.Tensor]:
282282
tables: dict[str, types.Tensor] = {}
283283
strategy = tf.distribute.get_strategy()
284-
# 4 is the number of sparsecores per chip
285-
num_shards = strategy.num_replicas_in_sync * 4
284+
if not self._is_tpu_strategy(strategy):
285+
raise RuntimeError(
286+
"`DistributedEmbedding.get_embedding_tables` needs to be "
287+
"called under the TPUStrategy that DistributedEmbedding was "
288+
f"created with, but is being called under strategy {strategy}. "
289+
"Please use `with strategy.scope()` when calling "
290+
"`get_embedding_tables`."
291+
)
292+
293+
tpu_hardware = strategy.extended.tpu_hardware_feature
294+
num_sc_per_device = tpu_hardware.num_embedding_devices_per_chip
295+
num_shards = strategy.num_replicas_in_sync * num_sc_per_device
286296

287297
def populate_table(
288298
feature_config: tf.tpu.experimental.embedding.FeatureConfig,

0 commit comments

Comments
 (0)