
Commit 6b5d25b

[magpie][context audio] add speaker items limit to compute similarity matrix to avoid OOM. (#14833)
* [step2] add speaker items limit to compute similarity matrix to avoid OOM of CPU RAM.
* defined a constant MASKED_SIMILARITY_VALUE to cover the magic number -2.0 for better maintainability.

Signed-off-by: Xuesong Yang <[email protected]>
1 parent 55d5a01 commit 6b5d25b
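
For scale, here is a back-of-envelope check of the OOM this commit addresses, using the figures from the NOTE the diff removes (50,000 segments, float32) and the 20,000-item cap shown in the usage example. The snippet is illustrative only, not part of the commit:

# Rough memory footprint of the per-speaker similarity matrix (float32 = 4 bytes).
n = 50_000  # segments for one very large speaker (from the removed NOTE)
m = 20_000  # --max-speaker-items cap (from the usage example)
print(f"N x N matrix: {n * n * 4 / 1e9:.1f} GB")  # ~10.0 GB -> the OOM case
print(f"N x M matrix: {n * m * 4 / 1e9:.1f} GB")  # ~4.0 GB with the cap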

File tree: 1 file changed (+79, −21)

scripts/magpietts/extend_nemo_manifest_with_context_audio.py

Lines changed: 79 additions & 21 deletions
@@ -16,6 +16,7 @@
 import json
 import logging
 import os
+import random
 import re
 import time
 from collections import defaultdict
@@ -35,6 +36,10 @@
 
 logger = logging.getLogger(__name__)
 
+# Constant for masking identical items in similarity matrix
+# Set below valid cosine similarity range [-1, 1] to ensure masked items are never selected
+MASKED_SIMILARITY_VALUE = -2.0
+
 """
 Usage:
 python scripts/magpietts/extend_manifest_with_context_audio.py
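
Why -2.0 is a safe sentinel: cosine similarities of L2-normalized embeddings always lie in [-1, 1], so a masked entry sorts behind every real candidate, and a descending scan can stop at the first one it meets. A minimal standalone illustration (toy values, not from the script):

import torch

MASKED_SIMILARITY_VALUE = -2.0  # strictly below the valid cosine range [-1, 1]
row = torch.tensor([0.91, MASKED_SIMILARITY_VALUE, 0.40, -0.35])
sims, idxs = torch.sort(row, descending=True)
for ssim, j in zip(sims.tolist(), idxs.tolist()):
    if ssim <= MASKED_SIMILARITY_VALUE:
        break  # masked entries are guaranteed to come last
    print(f"candidate {j}: ssim={ssim:.2f}")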
@@ -48,10 +53,16 @@
     --num-workers 4
     --context-min-duration 3.0
     --context-min-ssim 0.6
+    --max-speaker-items 20000  # optional, prevents OOM for large speakers
 
 This script distributes speakers across DDP ranks. Each rank processes its assigned speakers
 and writes a partial manifest. Rank 0 then merges these into a final manifest.
 
+The --max-speaker-items parameter limits the size of the context pool per speaker to prevent OOM
+when computing similarity matrices. If a speaker has more items than this limit, a random
+sample will be used as the context pool, but all items will still be processed to find
+their best context from this pool.
+
 Input manifest example entry:
 {
     "audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav",
@@ -183,6 +194,7 @@ def __init__(
         context_min_ssim: float,
         speaker_expected_counts_map: dict,
         initial_assigned_count: int,
+        max_speaker_items: int = None,
     ):
         super().__init__()
         self.sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
@@ -197,6 +209,7 @@ def __init__(
         self.context_min_ssim = context_min_ssim
         self.speaker_expected_counts = speaker_expected_counts_map
         self.initial_assigned_count = initial_assigned_count
+        self.max_speaker_items = max_speaker_items
 
         # Per-rank attributes
         self.output_file_path = None
@@ -216,6 +229,10 @@ def setup(self, stage: str):
         self.output_dir.mkdir(parents=True, exist_ok=True)
         self.output_manifest_file = open(self.output_file_path, "w", encoding="utf-8")
         logger.info(f"Writing partial manifest to: `{self.output_file_path}`")
+        if self.max_speaker_items:
+            logger.info(f"Max speaker items limit set to: {self.max_speaker_items}")
+        else:
+            logger.info("No max speaker items limit set (potential OOM risk for very large speakers)")
         logger.debug(f"Expected speaker counts for model: {self.speaker_expected_counts}")
 
     def forward(self, batch):
@@ -295,46 +312,80 @@ def _process_and_flush_speakers_local(self):
         self.total_accumulated_items -= len(speaker_items)
         self.processed_speakers_set.add(speaker_id)
 
-        # NOTE: Potential OOM (Out Of Memory) risk if a single speaker has an extremely large
-        # number of segments (e.g., tens of thousands). The N x N similarity matrix calculated below
-        # (where N = len(speaker_items)) can consume significant CPU RAM.
-        # For example, 50,000 segments for one speaker could lead to a float32 similarity matrix
-        # requiring approximately 10 GB of RAM. Consider this if processing datasets with
-        # speakers having a very high number of utterances.
-        embeddings = torch.stack([item['embedding'] for item in speaker_items])
-        embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1)
-        similarity_matrix = torch.matmul(embeddings_norm, embeddings_norm.transpose(0, 1))
-        similarity_matrix.fill_diagonal_(-2.0)  # cosine similarity range is [-1, 1]
+        # Apply speaker size limit to prevent OOM while processing all items
+        all_items_to_process = speaker_items  # We want to process ALL items
+
+        # Create context pool with original indices for easy identification
+        if self.max_speaker_items and len(speaker_items) > self.max_speaker_items:
+            logger.warning(
+                f"Speaker {speaker_id} has {len(speaker_items)} items, exceeding max limit of {self.max_speaker_items}. "
+                f"Using random sample of {self.max_speaker_items} items as context pool, but processing all {len(speaker_items)} items."
+            )
+            # Randomly sample with original indices preserved
+            random.seed(12345)  # For reproducibility
+            indexed_items = [(idx, item) for idx, item in enumerate(speaker_items)]
+            sampled_indexed_items = random.sample(indexed_items, self.max_speaker_items)
+            context_pool_items = [item for _, item in sampled_indexed_items]
+            context_pool_original_indices = [idx for idx, _ in sampled_indexed_items]
+        else:
+            # Use all items as context pool
+            context_pool_items = speaker_items
+            context_pool_original_indices = list(range(len(speaker_items)))
+
+        # NOTE: Now we compute similarities between ALL items and the context pool.
+        # This limits the similarity matrix to N×M instead of N×N where M <= max_speaker_items.
+        # Memory usage: N×M×4 bytes instead of N×N×4 bytes.
+        all_embeddings = torch.stack([item['embedding'] for item in all_items_to_process])
+        context_embeddings = torch.stack([item['embedding'] for item in context_pool_items])
+
+        all_embeddings_norm = torch.nn.functional.normalize(all_embeddings, p=2, dim=1)
+        context_embeddings_norm = torch.nn.functional.normalize(context_embeddings, p=2, dim=1)
+
+        # Compute N×M similarity matrix: each row is similarities for one item against all context candidates
+        similarity_matrix = torch.matmul(all_embeddings_norm, context_embeddings_norm.transpose(0, 1))
+
+        # Mask positions where items are identical (same item appearing in both N and M sets),
+        # using original indices as identifiers. This prevents an item from being selected as its own context.
+        # Create a mapping from original indices to context pool positions
+        original_index_to_context_position = {}
+        for context_pos, original_idx in enumerate(context_pool_original_indices):
+            original_index_to_context_position[original_idx] = context_pos
+
+        # Mask similarities for identical items
+        for n_idx in range(len(all_items_to_process)):
+            if n_idx in original_index_to_context_position:
+                context_pos = original_index_to_context_position[n_idx]
+                similarity_matrix[n_idx, context_pos] = MASKED_SIMILARITY_VALUE
 
         # Sort all similarities for each item to iterate through candidates
-        # best_similarities_tensor will contain sorted similarities for each row (original item)
-        # best_indices_tensor will contain original indices of these sorted items
+        # sorted_similarities_tensor will contain sorted similarities for each row (original item)
+        # sorted_indices_tensor will contain indices in the context_pool
         sorted_similarities_tensor, sorted_indices_tensor = torch.sort(similarity_matrix, dim=1, descending=True)
 
-        record_preparation_start_time = time.time()
         num_records_written_for_speaker = 0
         # Initialize a counter for items discarded for this specific speaker
         num_discarded_for_this_speaker_no_context = 0
 
-        for i, current_item_data in enumerate(speaker_items):
+        for i, current_item_data in enumerate(all_items_to_process):
             output_record = current_item_data['metadata'].copy()
             write_this_record = False
 
-            # Iterate through potential candidates, sorted by similarity
+            # Iterate through potential candidates from context pool, sorted by similarity
            for candidate_rank in range(sorted_indices_tensor.size(1)):
                 candidate_ssim = sorted_similarities_tensor[i, candidate_rank].item()
-                original_candidate_idx = sorted_indices_tensor[i, candidate_rank].item()
+                context_pool_idx = sorted_indices_tensor[i, candidate_rank].item()
 
-                # Skip if candidate is the item itself (safeguard)
-                if original_candidate_idx == i:
-                    continue
+                # If a candidate's similarity is <= MASKED_SIMILARITY_VALUE, all later ones are too,
+                # since similarities are sorted in descending order, so we can break early.
+                if candidate_ssim <= MASKED_SIMILARITY_VALUE:
+                    break
 
-                # If SSIM is below threshold, stop searching for this item (since candidates are sorted)
+                # If SSIM is below threshold, stop searching for this item
                 if candidate_ssim < self.context_min_ssim:
                     break
 
                 # Check duration if SSIM is acceptable
-                best_meta_dict = speaker_items[original_candidate_idx]['metadata']
+                best_meta_dict = context_pool_items[context_pool_idx]['metadata']
                 candidate_duration = best_meta_dict["duration"]
 
                 if candidate_duration >= self.context_min_duration:
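
Condensed, the new hunk does four things: sample an M-item context pool while preserving original indices, build the N×M cosine-similarity matrix, mask self-matches, and scan each row's candidates in descending order. The sketch below reproduces that flow on toy data; pick_contexts and its arguments are hypothetical names of mine, and the duration check is omitted:

import random
import torch

MASKED_SIMILARITY_VALUE = -2.0

def pick_contexts(embeddings, max_items, min_ssim, seed=12345):
    """Toy N x M context selection mirroring the diff. Returns, per item,
    the original index of its best context, or None if nothing qualifies."""
    n = embeddings.size(0)
    indices = list(range(n))
    if max_items and n > max_items:
        random.seed(seed)  # fixed seed, as in the commit, for reproducibility
        indices = random.sample(indices, max_items)
    pool = embeddings[indices]  # M x D context pool
    a = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    b = torch.nn.functional.normalize(pool, p=2, dim=1)
    sim = a @ b.T  # N x M instead of N x N
    for pos, orig in enumerate(indices):  # mask self-matches
        sim[orig, pos] = MASKED_SIMILARITY_VALUE
    sims, order = torch.sort(sim, dim=1, descending=True)
    best = []
    for i in range(n):
        choice = None
        for rank in range(order.size(1)):
            s = sims[i, rank].item()
            if s <= MASKED_SIMILARITY_VALUE or s < min_ssim:
                break  # rows are sorted descending, so stop early
            choice = indices[order[i, rank].item()]
            break  # the script's duration check would go here
        best.append(choice)
    return best

print(pick_contexts(torch.randn(8, 16), max_items=4, min_ssim=0.0))

One design note: the commit masks self-matches with a Python loop over all N items; the same effect is available in a single vectorized assignment, e.g. sim[torch.tensor(indices), torch.arange(len(indices))] = MASKED_SIMILARITY_VALUE (a suggestion of mine, not something the diff does).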
@@ -563,6 +614,12 @@ def main():
     parser.add_argument(
         "--context-min-ssim", type=float, default=0.6, help="Minimum cosine similarity for a context audio segment."
     )
+    parser.add_argument(
+        "--max-speaker-items",
+        type=int,
+        default=None,
+        help="Maximum size of context pool per speaker to prevent OOM. If a speaker has more items, a random sample will be used as context pool, but all items will still be processed. Default: None (no limit, potential OOM risk).",
+    )
     parser.add_argument("--devices", type=int, default=-1)
     parser.add_argument("--num-nodes", type=int, default=1)
     parser.add_argument("--batch-size", type=int, default=16)
@@ -837,6 +894,7 @@ def main():
         context_min_ssim=args.context_min_ssim,
         speaker_expected_counts_map=my_speaker_expected_counts,
         initial_assigned_count=len(assigned_records_for_this_rank),
+        max_speaker_items=args.max_speaker_items,
     )
     logger.info(
         f"Starting prediction with {len(assigned_records_for_this_rank)} records ({len(my_speaker_expected_counts)} unique speakers for this rank according to counts)."
