diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index 136052137..8a1c10296 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -301,6 +301,11 @@ def _get_writable_feature_hash_sizes(self) -> List[int]: return feature_hash_sizes def _get_virtual_table_feature_num_buckets(self) -> List[int]: + """ + Returns the number of buckets for each KVZCH feature in the GroupedEmbeddingConfigs. + If a feature is not a KVZCH feature, the list will have world_size for that feature's corresponding position. + This is needed as KVZCH features have to be processed for input_dist with non-KVZCH features. + """ feature_num_buckets: List[int] = [] for group_config in self._grouped_embedding_configs: for embedding_table in group_config.embedding_tables: @@ -312,6 +317,10 @@ def _get_virtual_table_feature_num_buckets(self) -> List[int]: [embedding_table.total_num_buckets] * embedding_table.num_features() ) + else: + feature_num_buckets.extend( + [self._world_size] * embedding_table.num_features() + ) return feature_num_buckets