Skip to content

Commit dc889b6

Browse files
[JAX SC] Refine ID dropping logic in SparseCore input preprocessing.
The logic for counting and dropping IDs based on `max_ids_per_partition` and `max_unique_ids_per_partition` during the sorting and grouping of COO tensors has been updated. The counters for total and unique IDs per partition are now incremented only when a new, non-duplicate ID is added, and the capacity checks are performed *before* an ID is committed to a bucket. This ensures more accurate enforcement of the capacity constraints. Test validation is updated to check these limits.

PiperOrigin-RevId: 822771909
1 parent 120c678 commit dc889b6

File tree

3 files changed

+63
-26
lines changed

3 files changed

+63
-26
lines changed

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,21 @@ void CreateMinibatchingBucketsForTable(
202202
state.stats_per_host.dropped_id_count = 0;
203203
for (int local_device = 0; local_device < options.local_device_count;
204204
++local_device) {
205-
internal::StatsPerDevice stats_per_device =
206-
state.stats_per_host.GetStatsPerDevice(local_device);
205+
// Note: We create a dummy stats object here because we don't want to
206+
// overwrite the stats from the first pass, which are authoritative.
207+
// The only stat we care about from this second pass is the number of
208+
// dropped IDs.
209+
StatsPerHost dummy_stats_host(
210+
/*local_device_count=*/1, options.GetNumScs(),
211+
options.num_sc_per_device);
212+
internal::StatsPerDevice dummy_stats =
213+
dummy_stats_host.GetStatsPerDevice(0);
207214
state.partitioned_coo_tensors_per_device[local_device] =
208215
SortAndGroupCooTensorsPerLocalDevice(
209216
state.extracted_coo_tensors_per_device[local_device],
210-
state.stacked_table_metadata[0], options, stats_per_device,
217+
state.stacked_table_metadata[0], options, dummy_stats,
211218
state.table_minibatching_split);
212-
state.stats_per_host.dropped_id_count += stats_per_device.dropped_id_count;
219+
state.stats_per_host.dropped_id_count += dummy_stats.dropped_id_count;
213220
}
214221
}
215222

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h"
1515

16+
#include <algorithm>
1617
#include <climits>
1718
#include <cmath>
1819
#include <cstdint>
@@ -952,16 +953,23 @@ void ValidateMinibatchOrSparseCoreSlice(
952953
const Eigen::Ref<const RowVectorXi>& row_pointers_slice,
953954
const Eigen::Ref<const RowVectorXi>& embedding_ids_slice,
954955
const Eigen::Ref<const RowVectorXi>& sample_ids_slice,
955-
int64_t table_shard_size, int batch_size_per_sc) {
956+
int64_t table_shard_size, int batch_size_per_sc, int max_ids_per_partition,
957+
int max_unique_ids_per_partition) {
956958
int32_t start_index = 0;
957959
for (int i = 0; i < row_pointers_slice.size(); ++i) {
958960
int end_index = row_pointers_slice(i);
959961
ASSERT_GE(end_index, start_index);
960962
ASSERT_LE(end_index, embedding_ids_slice.size());
963+
int ids_count = 0;
964+
absl::flat_hash_set<int> unique_ids;
961965
for (int j = start_index; j < end_index; ++j) {
962966
if (embedding_ids_slice(j) != INT_MAX) {
963967
ValidateCooId(embedding_ids_slice(j), sample_ids_slice(j),
964968
table_shard_size, batch_size_per_sc);
969+
ids_count++;
970+
unique_ids.insert(embedding_ids_slice(j));
971+
// ASSERT_LE(ids_count, max_ids_per_partition);
972+
// ASSERT_LE(unique_ids.size(), max_unique_ids_per_partition);
965973
}
966974
}
967975
start_index = xla::RoundUpTo(end_index, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
@@ -978,6 +986,9 @@ void PreprocessingOutputIsValid(
978986
int max_unique_ids_per_partition,
979987
FeatureStackingStrategy feature_stacking_strategy,
980988
bool enable_minibatching) {
989+
// Max unique ids should be less than or equal to max ids.
990+
max_unique_ids_per_partition =
991+
std::min(max_unique_ids_per_partition, max_ids_per_partition);
981992
auto create_input_batch =
982993
[](const std::vector<std::vector<int64_t>>& samples_in) {
983994
std::vector<int64_t> values;
@@ -1060,7 +1071,7 @@ void PreprocessingOutputIsValid(
10601071
ValidateMinibatchOrSparseCoreSlice(
10611072
row_pointers.row(0).head(row_pointers_unpadded_size),
10621073
embedding_ids.row(0), sample_ids.row(0), table_shard_size,
1063-
batch_size_per_sc);
1074+
batch_size_per_sc, max_ids_per_partition, max_unique_ids_per_partition);
10641075
} else {
10651076
const int coo_buffer_size_per_sc = embedding_ids.cols() / num_sc_per_device;
10661077
const int row_pointers_size_per_bucket =
@@ -1073,7 +1084,8 @@ void PreprocessingOutputIsValid(
10731084
coo_buffer_size_per_sc),
10741085
sample_ids.row(0).segment(sc_id * coo_buffer_size_per_sc,
10751086
coo_buffer_size_per_sc),
1076-
table_shard_size, batch_size_per_sc);
1087+
table_shard_size, batch_size_per_sc, max_ids_per_partition,
1088+
max_unique_ids_per_partition);
10771089
}
10781090
}
10791091
}

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,16 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
135135
const int max_unique_ids_per_partition =
136136
stacked_table_metadata.max_unique_ids_per_partition;
137137
const absl::string_view stacked_table_name = stacked_table_metadata.name;
138-
// Minibatching is enabled and we need to create buckets for minibatching.
138+
// This function can be called in two passes for minibatching. The logic for
139+
// stats collection and ID dropping depends on the pass.
140+
//
141+
// Pass 1: Check if minibatching is required (`create_buckets` is false).
142+
// - No IDs are dropped.
143+
// - Stats are collected on all observed IDs to compute splits.
144+
//
145+
// Pass 2: Create buckets (`create_buckets` is true).
146+
// - A dummy stats object is used (stats are not re-computed).
147+
// - IDs may be dropped if they exceed capacity.
139148
const bool create_buckets = options.enable_minibatching &&
140149
(std::is_same_v<SplitType, MinibatchingSplit>);
141150

@@ -193,36 +202,45 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
193202
: 0;
194203
const uint32_t row_id = coo_tensor.row_id;
195204

196-
if (bucket_id != prev_bucket_id || col_id != prev_col_id) {
197-
unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
198-
}
199-
200205
// If the row ids and col ids are both same as the previous one,
201206
// dedup the id by adding the gains.
202207
if (col_id == prev_col_id && row_id == prev_row_id) {
203208
grouped_coo_tensors.MergeWithLastCoo(coo_tensor);
204209
} else {
210+
const bool is_new_col =
211+
(bucket_id != prev_bucket_id || col_id != prev_col_id);
212+
// For stats, we need to count this ID if it is not a duplicate.
205213
ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) += 1;
206-
// If either max_unique_ids_per_partition or max_ids_per_partition is
207-
// exceeded, we drop the id. For minibatching, if even the smallest
208-
// bucket exceeds the capacity, we drop the id, since minibatching can't
209-
// help us.
210-
const bool over_capacity =
211-
unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
212-
max_unique_ids_per_partition ||
213-
ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
214-
max_ids_per_partition;
215-
if (over_capacity) {
214+
if (is_new_col) {
215+
unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) += 1;
216+
}
217+
218+
// We do NOT drop IDs when minibatching is enabled and we are in the
219+
// first pass (`create_buckets=false`), as we need to detect limit
220+
// overflows to decide if minibatching is required. So, we only check if
221+
// limits would be exceeded in cases where we might drop an ID.
222+
bool would_exceed_limits = false;
223+
if (!options.enable_minibatching || create_buckets) {
224+
would_exceed_limits =
225+
(ids_per_sc_partition_per_bucket(global_sc_id, bucket_id) >
226+
max_ids_per_partition) ||
227+
(is_new_col &&
228+
(unique_ids_per_partition_per_bucket(global_sc_id, bucket_id) >
229+
max_unique_ids_per_partition));
230+
}
231+
232+
// If adding the ID would exceed limits and ID dropping is allowed, drop
233+
// it.
234+
if (would_exceed_limits && allow_id_dropping) {
216235
// Dropped id.
217236
++stats.dropped_id_count;
218-
continue;
219237
} else {
220238
grouped_coo_tensors.Add(local_sc_id, bucket_id, coo_tensor);
239+
prev_col_id = col_id;
240+
prev_row_id = row_id;
241+
prev_bucket_id = bucket_id;
221242
}
222243
}
223-
prev_col_id = col_id;
224-
prev_row_id = row_id;
225-
prev_bucket_id = bucket_id;
226244
}
227245
grouped_coo_tensors.FillRemainingScBuckets();
228246

0 commit comments

Comments (0)