5656def  translate_keras_rs_configuration (
5757    feature_configs : types .Nested [FeatureConfig ],
5858    table_stacking : str  |  Sequence [str ] |  Sequence [Sequence [str ]],
59+     num_replicas_in_sync : int ,
5960) ->  tuple [
6061    types .Nested [tf .tpu .experimental .embedding .FeatureConfig ],
6162    tf .tpu .experimental .embedding .SparseCoreEmbeddingConfig ,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
7273    """ 
7374    tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ] =  {}
7475    feature_configs  =  keras .tree .map_structure (
75-         lambda  f : translate_keras_rs_feature_config (f , tables ), feature_configs 
76+         lambda  f : translate_keras_rs_feature_config (
77+             f , tables , num_replicas_in_sync 
78+         ),
79+         feature_configs ,
7680    )
7781
7882    # max_ids_per_chip_per_sample 
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
107111def  translate_keras_rs_feature_config (
108112    feature_config : FeatureConfig ,
109113    tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ],
114+     num_replicas_in_sync : int ,
110115) ->  tf .tpu .experimental .embedding .FeatureConfig :
111116    """Translates a Keras RS feature config to a TensorFlow TPU feature config. 
112117
@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
120125    Returns: 
121126      The TensorFlow TPU feature config. 
122127    """ 
128+     if  num_replicas_in_sync  <=  0 :
129+         raise  ValueError (
130+             "`num_replicas_in_sync` must be positive, " 
131+             f"but got { num_replicas_in_sync }  ." 
132+         )
133+ 
123134    table  =  tables .get (feature_config .table , None )
124135    if  table  is  None :
125136        table  =  translate_keras_rs_table_config (feature_config .table )
126137        tables [feature_config .table ] =  table 
127138
139+     if  len (feature_config .output_shape ) <  2 :
140+         raise  ValueError (
141+             f"Invalid `output_shape` { feature_config .output_shape }   in " 
142+             f"`FeatureConfig` { feature_config }  . It must have at least 2 " 
143+             "dimensions: a batch dimension and an embedding dimension." 
144+         )
145+ 
146+     # Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it. 
147+     output_shape  =  list (feature_config .output_shape [0 :- 1 ])
148+ 
149+     batch_size  =  output_shape [0 ]
150+     per_replica_batch_size : int  |  None  =  None 
151+     if  batch_size  is  not   None :
152+         if  batch_size  %  num_replicas_in_sync  !=  0 :
153+             raise  ValueError (
154+                 f"Invalid `output_shape` { feature_config .output_shape }   in " 
155+                 f"`FeatureConfig` { feature_config }  . Batch size { batch_size }   is " 
156+                 f"not a multiple of the number of TPUs { num_replicas_in_sync }  ." 
157+             )
158+         per_replica_batch_size  =  batch_size  //  num_replicas_in_sync 
159+ 
160+     # TensorFlow's TPUEmbedding wants the per replica batch size. 
161+     output_shape  =  [per_replica_batch_size ] +  output_shape [1 :]
162+ 
128163    # max_sequence_length 
129164    return  tf .tpu .experimental .embedding .FeatureConfig (
130165        name = feature_config .name ,
131166        table = table ,
132-         output_shape = feature_config .output_shape [
133-             0 :- 1 
134-         ],  # exclude last dimension 
167+         output_shape = output_shape ,
135168    )
136169
137170
0 commit comments