diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc index cbf73b63..27b5546f 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc @@ -963,14 +963,13 @@ void ValidateMinibatchOrSparseCoreSlice( int ids_count = 0; absl::flat_hash_set unique_ids; for (int j = start_index; j < end_index; ++j) { - if (embedding_ids_slice(j) != INT_MAX) { - ValidateCooId(embedding_ids_slice(j), sample_ids_slice(j), - table_shard_size, batch_size_per_sc); - ids_count++; - unique_ids.insert(embedding_ids_slice(j)); - // ASSERT_LE(ids_count, max_ids_per_partition); - // ASSERT_LE(unique_ids.size(), max_unique_ids_per_partition); - } + ASSERT_NE(embedding_ids_slice(j), INT_MAX); + ValidateCooId(embedding_ids_slice(j), sample_ids_slice(j), + table_shard_size, batch_size_per_sc); + ids_count++; + unique_ids.insert(embedding_ids_slice(j)); + ASSERT_LE(ids_count, max_ids_per_partition); + ASSERT_LE(unique_ids.size(), max_unique_ids_per_partition); } start_index = xla::RoundUpTo(end_index, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE); } diff --git a/jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h b/jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h index 6d194822..398b6ca4 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h +++ b/jax_tpu_embedding/sparsecore/lib/core/minibatching_splits_impl.h @@ -14,6 +14,7 @@ #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_SPLITS_IMPL_H_ #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_MINIBATCHING_SPLITS_IMPL_H_ +#include #include #include #include @@ -72,6 +73,7 @@ std::bitset ComputeMinibatchingSplit( const int val_right = unique_ids_per_bucket[i + subtree_size / 2]; if (val_left + val_right > max_unique_ids_per_partition) { split.set(split_index); + unique_ids_per_bucket[i] = std::max(val_left, val_right); } else { unique_ids_per_bucket[i] += val_right; }