16
16
import json
17
17
import logging
18
18
import os
19
+ import random
19
20
import re
20
21
import time
21
22
from collections import defaultdict
35
36
36
37
logger = logging .getLogger (__name__ )
37
38
39
+ # Constant for masking identical items in similarity matrix
40
+ # Set below valid cosine similarity range [-1, 1] to ensure masked items are never selected
41
+ MASKED_SIMILARITY_VALUE = - 2.0
42
+
38
43
"""
39
44
Usage:
40
45
python scripts/magpietts/extend_manifest_with_context_audio.py
48
53
--num-workers 4
49
54
--context-min-duration 3.0
50
55
--context-min-ssim 0.6
56
+ --max-speaker-items 20000 # optional, prevents OOM for large speakers
51
57
52
58
This script distributes speakers across DDP ranks. Each rank processes its assigned speakers
53
59
and writes a partial manifest. Rank 0 then merges these into a final manifest.
54
60
61
+ The --max-speaker-items parameter limits the size of the context pool per speaker to prevent OOM
62
+ when computing similarity matrices. If a speaker has more items than this limit, a random
63
+ sample will be used as the context pool, but all items will still be processed to find
64
+ their best context from this pool.
65
+
55
66
Input manifest example entry:
56
67
{
57
68
"audio_filepath": "NVYT_40K_audios_wav/_8Kirz57BTY.wav",
@@ -183,6 +194,7 @@ def __init__(
183
194
context_min_ssim : float ,
184
195
speaker_expected_counts_map : dict ,
185
196
initial_assigned_count : int ,
197
+ max_speaker_items : int = None ,
186
198
):
187
199
super ().__init__ ()
188
200
self .sv_model = nemo_asr .models .EncDecSpeakerLabelModel .from_pretrained (
@@ -197,6 +209,7 @@ def __init__(
197
209
self .context_min_ssim = context_min_ssim
198
210
self .speaker_expected_counts = speaker_expected_counts_map
199
211
self .initial_assigned_count = initial_assigned_count
212
+ self .max_speaker_items = max_speaker_items
200
213
201
214
# Per-rank attributes
202
215
self .output_file_path = None
@@ -216,6 +229,10 @@ def setup(self, stage: str):
216
229
self .output_dir .mkdir (parents = True , exist_ok = True )
217
230
self .output_manifest_file = open (self .output_file_path , "w" , encoding = "utf-8" )
218
231
logger .info (f"Writing partial manifest to: `{ self .output_file_path } `" )
232
+ if self .max_speaker_items :
233
+ logger .info (f"Max speaker items limit set to: { self .max_speaker_items } " )
234
+ else :
235
+ logger .info ("No max speaker items limit set (potential OOM risk for very large speakers)" )
219
236
logger .debug (f"Expected speaker counts for model: { self .speaker_expected_counts } " )
220
237
221
238
def forward (self , batch ):
@@ -295,46 +312,80 @@ def _process_and_flush_speakers_local(self):
295
312
self .total_accumulated_items -= len (speaker_items )
296
313
self .processed_speakers_set .add (speaker_id )
297
314
298
- # NOTE: Potential OOM (Out Of Memory) risk if a single speaker has an extremely large
299
- # number of segments (e.g., tens of thousands). The N x N similarity matrix calculated below
300
- # (where N = len(speaker_items)) can consume significant CPU RAM.
301
- # For example, 50,000 segments for one speaker could lead to a float32 similarity matrix
302
- # requiring approximately 10 GB of RAM. Consider this if processing datasets with
303
- # speakers having a very high number of utterances.
304
- embeddings = torch .stack ([item ['embedding' ] for item in speaker_items ])
305
- embeddings_norm = torch .nn .functional .normalize (embeddings , p = 2 , dim = 1 )
306
- similarity_matrix = torch .matmul (embeddings_norm , embeddings_norm .transpose (0 , 1 ))
307
- similarity_matrix .fill_diagonal_ (- 2.0 ) # cosine similarity range is [-1, 1]
315
+ # Apply speaker size limit to prevent OOM while processing all items
316
+ all_items_to_process = speaker_items # We want to process ALL items
317
+
318
+ # Create context pool with original indices for easy identification
319
+ if self .max_speaker_items and len (speaker_items ) > self .max_speaker_items :
320
+ logger .warning (
321
+ f"Speaker { speaker_id } has { len (speaker_items )} items, exceeding max limit of { self .max_speaker_items } . "
322
+ f"Using random sample of { self .max_speaker_items } items as context pool, but processing all { len (speaker_items )} items."
323
+ )
324
+ # Randomly sample with original indices preserved
325
+ random .seed (12345 ) # For reproducibility
326
+ indexed_items = [(idx , item ) for idx , item in enumerate (speaker_items )]
327
+ sampled_indexed_items = random .sample (indexed_items , self .max_speaker_items )
328
+ context_pool_items = [item for _ , item in sampled_indexed_items ]
329
+ context_pool_original_indices = [idx for idx , _ in sampled_indexed_items ]
330
+ else :
331
+ # Use all items as context pool
332
+ context_pool_items = speaker_items
333
+ context_pool_original_indices = list (range (len (speaker_items )))
334
+
335
+ # NOTE: Now we compute similarities between ALL items and the context pool.
336
+ # This limits the similarity matrix to N×M instead of N×N where M <= max_speaker_items.
337
+ # Memory usage: N×M×4 bytes instead of N×N×4 bytes.
338
+ all_embeddings = torch .stack ([item ['embedding' ] for item in all_items_to_process ])
339
+ context_embeddings = torch .stack ([item ['embedding' ] for item in context_pool_items ])
340
+
341
+ all_embeddings_norm = torch .nn .functional .normalize (all_embeddings , p = 2 , dim = 1 )
342
+ context_embeddings_norm = torch .nn .functional .normalize (context_embeddings , p = 2 , dim = 1 )
343
+
344
+ # Compute N×M similarity matrix: each row is similarities for one item against all context candidates
345
+ similarity_matrix = torch .matmul (all_embeddings_norm , context_embeddings_norm .transpose (0 , 1 ))
346
+
347
+ # Mask positions where items are identical (same item appearing in both N and M sets)
348
+ # Using original indices as identifiers. This prevents an item from being selected as its own context.
349
+ # Create a mapping from original indices to context pool positions
350
+ original_index_to_context_position = {}
351
+ for context_pos , original_idx in enumerate (context_pool_original_indices ):
352
+ original_index_to_context_position [original_idx ] = context_pos
353
+
354
+ # Mask similarities for identical items
355
+ for n_idx in range (len (all_items_to_process )):
356
+ if n_idx in original_index_to_context_position :
357
+ context_pos = original_index_to_context_position [n_idx ]
358
+ similarity_matrix [n_idx , context_pos ] = MASKED_SIMILARITY_VALUE
308
359
309
360
# Sort all similarities for each item to iterate through candidates
310
- # best_similarities_tensor will contain sorted similarities for each row (original item)
311
- # best_indices_tensor will contain original indices of these sorted items
361
+ # sorted_similarities_tensor will contain sorted similarities for each row (original item)
362
+ # sorted_indices_tensor will contain indices in the context_pool
312
363
sorted_similarities_tensor , sorted_indices_tensor = torch .sort (similarity_matrix , dim = 1 , descending = True )
313
364
314
- record_preparation_start_time = time .time ()
315
365
num_records_written_for_speaker = 0
316
366
# Initialize a counter for items discarded for this specific speaker
317
367
num_discarded_for_this_speaker_no_context = 0
318
368
319
- for i , current_item_data in enumerate (speaker_items ):
369
+ for i , current_item_data in enumerate (all_items_to_process ):
320
370
output_record = current_item_data ['metadata' ].copy ()
321
371
write_this_record = False
322
372
323
- # Iterate through potential candidates, sorted by similarity
373
+ # Iterate through potential candidates from context pool , sorted by similarity
324
374
for candidate_rank in range (sorted_indices_tensor .size (1 )):
325
375
candidate_ssim = sorted_similarities_tensor [i , candidate_rank ].item ()
326
- original_candidate_idx = sorted_indices_tensor [i , candidate_rank ].item ()
376
+ context_pool_idx = sorted_indices_tensor [i , candidate_rank ].item ()
327
377
328
- # Skip if candidate is the item itself (safeguard)
329
- if original_candidate_idx == i :
330
- continue
378
+ # Candidates are sorted descending, so once one reaches MASKED_SIMILARITY_VALUE every later one is masked too
379
+ # since similarities are sorted in descending order, we can break early
380
+ if candidate_ssim <= MASKED_SIMILARITY_VALUE :
381
+ break
331
382
332
- # If SSIM is below threshold, stop searching for this item (since candidates are sorted)
383
- # If SSIM is below threshold, stop searching for this item (candidates are sorted descending, so no later one can pass)
333
384
if candidate_ssim < self .context_min_ssim :
334
385
break
335
386
336
387
# Check duration if SSIM is acceptable
337
- best_meta_dict = speaker_items [ original_candidate_idx ]['metadata' ]
388
+ best_meta_dict = context_pool_items [ context_pool_idx ]['metadata' ]
338
389
candidate_duration = best_meta_dict ["duration" ]
339
390
340
391
if candidate_duration >= self .context_min_duration :
@@ -563,6 +614,12 @@ def main():
563
614
parser .add_argument (
564
615
"--context-min-ssim" , type = float , default = 0.6 , help = "Minimum cosine similarity for a context audio segment."
565
616
)
617
+ parser .add_argument (
618
+ "--max-speaker-items" ,
619
+ type = int ,
620
+ default = None ,
621
+ help = "Maximum size of context pool per speaker to prevent OOM. If a speaker has more items, a random sample will be used as context pool, but all items will still be processed. Default: None (no limit, potential OOM risk)." ,
622
+ )
566
623
parser .add_argument ("--devices" , type = int , default = - 1 )
567
624
parser .add_argument ("--num-nodes" , type = int , default = 1 )
568
625
parser .add_argument ("--batch-size" , type = int , default = 16 )
@@ -837,6 +894,7 @@ def main():
837
894
context_min_ssim = args .context_min_ssim ,
838
895
speaker_expected_counts_map = my_speaker_expected_counts ,
839
896
initial_assigned_count = len (assigned_records_for_this_rank ),
897
+ max_speaker_items = args .max_speaker_items ,
840
898
)
841
899
logger .info (
842
900
f"Starting prediction with { len (assigned_records_for_this_rank )} records ({ len (my_speaker_expected_counts )} unique speakers for this rank according to counts)."
0 commit comments