@@ -135,7 +135,16 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
   const int max_unique_ids_per_partition =
       stacked_table_metadata.max_unique_ids_per_partition;
   const absl::string_view stacked_table_name = stacked_table_metadata.name;
-  // Minibatching is enabled and we need to create buckets for minibatching.
+  // This function can be called in two passes for minibatching. The logic for
+  // stats collection and ID dropping depends on the pass.
+  //
+  // Pass 1: Check if minibatching is required (`create_buckets` is false).
+  //   - No IDs are dropped.
+  //   - Stats are collected on all observed IDs to compute splits.
+  //
+  // Pass 2: Create buckets (`create_buckets` is true).
+  //   - A dummy stats object is used (stats are not re-computed).
+  //   - IDs may be dropped if they exceed capacity.
   const bool create_buckets = options.enable_minibatching &&
                               (std::is_same_v<SplitType, MinibatchingSplit>);
 
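For reference, a minimal sketch of the pass selection above. `MinibatchingSplit` and `options.enable_minibatching` come from the diff; `NoSplit`, `Options`, and `ComputeCreateBuckets` are hypothetical stand-ins for the surrounding code, which templates the grouping function on `SplitType`:

```cpp
#include <iostream>
#include <type_traits>

// Hypothetical split types: the template parameter selects the pass.
struct NoSplit {};            // pass 1: detect whether minibatching is needed
struct MinibatchingSplit {};  // pass 2: create minibatching buckets

// Hypothetical slice of the real options struct.
struct Options {
  bool enable_minibatching = false;
};

// Mirrors the condition in the diff: buckets are created only when
// minibatching is enabled AND this instantiation is the bucket-creation pass.
template <typename SplitType>
bool ComputeCreateBuckets(const Options& options) {
  return options.enable_minibatching &&
         std::is_same_v<SplitType, MinibatchingSplit>;
}

int main() {
  Options options{/*enable_minibatching=*/true};
  std::cout << ComputeCreateBuckets<NoSplit>(options) << "\n";            // 0
  std::cout << ComputeCreateBuckets<MinibatchingSplit>(options) << "\n";  // 1
}
```

So with minibatching enabled, the pass-1 instantiation still runs with `create_buckets == false`: it only observes limit overflows, and only the pass-2 instantiation buckets and (possibly) drops IDs.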
@@ -193,36 +202,45 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
           : 0;
       const uint32_t row_id = coo_tensor.row_id;
 
-      if (bucket_id != prev_bucket_id || col_id != prev_col_id) {
-        unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
-      }
-
       // If the row ids and col ids are both same as the previous one,
       // dedup the id by adding the gains.
       if (col_id == prev_col_id && row_id == prev_row_id) {
         grouped_coo_tensors.MergeWithLastCoo(coo_tensor);
       } else {
+        const bool is_new_col =
+            (bucket_id != prev_bucket_id || col_id != prev_col_id);
+        // For stats, we need to count this ID if it is not a duplicate.
         ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
-        // If either max_unique_ids_per_partition or max_ids_per_partition is
-        // exceeded, we drop the id. For minibatching, if even the smallest
-        // bucket exceeds the capacity, we drop the id, since minibatching can't
-        // help us.
-        const bool over_capacity =
-            unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
-                max_unique_ids_per_partition ||
-            ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
-                max_ids_per_partition;
-        if (over_capacity) {
+        if (is_new_col) {
+          unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
+        }
+
+        // We do NOT drop IDs when minibatching is enabled and we are in the
+        // first pass (`create_buckets=false`), as we need to detect limit
+        // overflows to decide if minibatching is required. So, we only check if
+        // limits would be exceeded in cases where we might drop an ID.
+        bool would_exceed_limits = false;
+        if (!options.enable_minibatching || create_buckets) {
+          would_exceed_limits =
+              (ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
+               max_ids_per_partition) ||
+              (is_new_col &&
+               (unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
+                max_unique_ids_per_partition));
+        }
+
+        // If adding the ID would exceed limits and ID dropping is allowed, drop
+        // it.
+        if (would_exceed_limits && allow_id_dropping) {
           // Dropped id.
           ++stats.dropped_id_count;
-          continue;
         } else {
           grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
+          prev_col_id = col_id;
+          prev_row_id = row_id;
+          prev_bucket_id = bucket_id;
         }
       }
-      prev_col_id = col_id;
-      prev_row_id = row_id;
-      prev_bucket_id = bucket_id;
     }
     grouped_coo_tensors.FillRemainingScBuckets();
 
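The new control flow is easier to see pulled out of the loop. A minimal sketch, with simplified per-(partition, bucket) counters standing in for the real 2D stats arrays; `ShouldDropId` and `PartitionBucketStats` are hypothetical names:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical simplified counters for one (partition, bucket) cell; the
// real code indexes 2D arrays by (global_sc_id, bucket_id).
struct PartitionBucketStats {
  int32_t ids = 0;
  int32_t unique_ids = 0;
};

// Mirrors the decision in the diff: stats are always counted, but limits are
// only enforced when an ID may actually be dropped, i.e. when minibatching is
// disabled or we are already in the bucket-creation pass.
bool ShouldDropId(PartitionBucketStats& stats, bool is_new_col,
                  bool enable_minibatching, bool create_buckets,
                  bool allow_id_dropping, int max_ids_per_partition,
                  int max_unique_ids_per_partition) {
  stats.ids += 1;
  if (is_new_col) {
    stats.unique_ids += 1;
  }

  bool would_exceed_limits = false;
  if (!enable_minibatching || create_buckets) {
    would_exceed_limits =
        stats.ids > max_ids_per_partition ||
        (is_new_col && stats.unique_ids > max_unique_ids_per_partition);
  }
  // Pass 1 of minibatching always returns false here: every ID is kept so
  // the split computation sees the true overflow counts.
  return would_exceed_limits && allow_id_dropping;
}

int main() {
  PartitionBucketStats stats;
  // Pass 1 (minibatching enabled, buckets not yet created): never drops,
  // even once the limits of 2 are exceeded.
  for (int i = 0; i < 5; ++i) {
    assert(!ShouldDropId(stats, /*is_new_col=*/true,
                         /*enable_minibatching=*/true,
                         /*create_buckets=*/false,
                         /*allow_id_dropping=*/true,
                         /*max_ids_per_partition=*/2,
                         /*max_unique_ids_per_partition=*/2));
  }
  // Pass 2 (create_buckets=true): the same counters now trigger a drop.
  assert(ShouldDropId(stats, /*is_new_col=*/true,
                      /*enable_minibatching=*/true, /*create_buckets=*/true,
                      /*allow_id_dropping=*/true,
                      /*max_ids_per_partition=*/2,
                      /*max_unique_ids_per_partition=*/2));
}
```

Two other points visible in the diff: dropping is now additionally gated on `allow_id_dropping`, where the old code dropped unconditionally once `over_capacity` was hit; and moving the `prev_col_id`/`prev_row_id`/`prev_bucket_id` updates into the keep branch replaces the old `continue` with equivalent behavior, so a dropped ID still never becomes the dedup reference for the IDs that follow it.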