From 1782d67a85c73faf68e0203fb08cfd2d7610748e Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 3 Sep 2025 18:00:38 -0700 Subject: [PATCH 01/11] Don't allow EOS until 4 frames have been generated The number of frames is configurable via a parameter to infer_batch(). This is a workaround to the observation that when CFG is on we sometimes terminate after zero tokens. It appears to be an artifact of CFG, since the EOS logit is not particularly large for the conditional logits; only post-CFG. Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 26 +++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index fb9a060e6561..138fc75ae185 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -718,7 +718,7 @@ def code_to_str(code): output_str += c logging.debug(output_str) - def clear_forbidden_logits(self, logits): + def clear_forbidden_logits(self, logits, forbid_audio_eos=False): """ Sets logits of forbidden tokens to `-inf` so they will never be sampled. Specifically, we forbid sampling of all special tokens except AUDIO_EOS. @@ -726,7 +726,9 @@ def clear_forbidden_logits(self, logits): logits: (B, C, num_audio_tokens_per_codebook) """ logits[ - :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False) + :, + :, + SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits @@ -744,6 +746,7 @@ def local_transformer_sample_maskgit( fixed_schedule=None, dynamic_cfg_scale=False, sampling_type=None, + forbid_audio_eos=False, ): """ Sample codes for one timestep from the local transformer using MaskGit. 
@@ -852,7 +855,7 @@ def local_transformer_sample_maskgit( logits[:actual_batch_size] = cfg_logits # Disallow generation of special tokens (except audio EOS which is handled separately) - logits = self.clear_forbidden_logits(logits) + logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos) # handle unfinished and finished items for item_idx in unfinished_items: @@ -921,6 +924,7 @@ def local_transformer_sample_autoregressive( use_cfg=False, cfg_scale=1.0, use_kv_cache=True, + forbid_audio_eos=False ): # dec_output: (B, E) self.local_transformer.reset_cache(use_cache=use_kv_cache) @@ -948,7 +952,7 @@ def local_transformer_sample_autoregressive( codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) + codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -983,7 +987,7 @@ def local_transformer_sample_autoregressive( def sample_codes_from_logits( self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={} - ): + , forbid_audio_eos=False): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): @@ -997,7 +1001,7 @@ def sample_codes_from_logits( for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1) + codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos).squeeze(1) codebook_logits_topk = 
torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -2037,6 +2041,7 @@ def infer_batch( maskgit_fixed_schedule=None, maskgit_dynamic_cfg_scale=False, maskgit_sampling_type=None, + min_generated_frames=4, ): with torch.no_grad(): start_time = time.time() @@ -2199,7 +2204,10 @@ def infer_batch( } # Items that have been close to the end for atleast 20 timesteps unfinished_items = {k: v for k, v in unfinished_texts.items() if v} - all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) + # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) + forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames + + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: if self.local_transformer_type == LocalTransformerType.AR: # Autoregressive sampling with local transformer @@ -2212,6 +2220,7 @@ def infer_batch( use_cfg=use_cfg, cfg_scale=cfg_scale, use_kv_cache=use_LT_kv_cache, + forbid_audio_eos=forbid_audio_eos ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( @@ -2227,6 +2236,7 @@ def infer_batch( fixed_schedule=maskgit_fixed_schedule, dynamic_cfg_scale=maskgit_dynamic_cfg_scale, sampling_type=maskgit_sampling_type, + forbid_audio_eos=forbid_audio_eos ) else: raise ValueError( @@ -2240,12 +2250,14 @@ def infer_batch( topk=topk, unfinished_items=unfinished_items, finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, frame_stacking_factor) all_codes_next_argmax = self.sample_codes_from_logits( all_code_logits_t, temperature=0.01, unfinished_items=unfinished_items, finished_items=finished_items, + forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, 
frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): From a0e58d0ca0173aca5d3fec5c863094ecb89b8a79 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Fri, 19 Sep 2025 11:11:50 -0700 Subject: [PATCH 02/11] Formatting Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 138fc75ae185..570a10cf08d3 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -726,9 +726,7 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False): logits: (B, C, num_audio_tokens_per_codebook) """ logits[ - :, - :, - SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), + :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits @@ -986,8 +984,14 @@ def local_transformer_sample_autoregressive( return all_preds def sample_codes_from_logits( - self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={} - , forbid_audio_eos=False): + self, + all_code_logits_t, + temperature=0.7, + topk=80, + unfinished_items={}, + finished_items={}, + forbid_audio_eos=False, + ): # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): From 57df90c4e764e442c0a395bea2602d51d0bb71ea Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Fri, 19 Sep 2025 11:12:04 -0700 Subject: [PATCH 03/11] Command line option to set minimum number of frames to generate Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/infer_and_evaluate.py 
b/scripts/magpietts/infer_and_evaluate.py index b84f8e1a0a71..183c0cb87b23 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -293,6 +293,7 @@ def run_inference( log_exp_name=False, compute_fcd=False, violin_plot_metrics=['cer', 'pred_context_ssim'], + min_frames=0, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -358,7 +359,8 @@ def run_inference( checkpoint_name += ( f"LT_{use_local_transformer}_" f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" - f"SV_{sv_model}" + f"SV_{sv_model}_" + f"MinFrames_{min_frames}" ) dataset_meta_info = evalset_config.dataset_meta_info @@ -488,6 +490,7 @@ def run_inference( maskgit_noise_scale=maskgit_noise_scale, maskgit_fixed_schedule=maskgit_fixed_schedule, maskgit_sampling_type=maskgit_sampling_type, + min_generated_frames=min_frames, ) all_rtf_metrics.append(rtf_metrics) @@ -675,6 +678,7 @@ def main(): default=['cer', 'pred_context_ssim'], help="Which metrics to add the violin plot.", ) + parser.add_argument('--min_frames', type=int, default=0, help="Minimum number of frames to generate") args = parser.parse_args() if args.datasets is None: @@ -722,6 +726,7 @@ def main(): log_exp_name=args.log_exp_name, compute_fcd=compute_fcd, violin_plot_metrics=args.violin_plot_metrics, + min_frames=args.min_frames, ) # Mode 1: Run inference from provided hparams and checkpoint files From 1005d1b3bfd6adc7007a3f161f9c24b854dec9b1 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Tue, 23 Sep 2025 20:10:05 -0700 Subject: [PATCH 04/11] formatting Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 570a10cf08d3..7066a114a59e 100644 --- 
a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -726,7 +726,9 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False): logits: (B, C, num_audio_tokens_per_codebook) """ logits[ - :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), + :, + :, + SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits @@ -922,7 +924,7 @@ def local_transformer_sample_autoregressive( use_cfg=False, cfg_scale=1.0, use_kv_cache=True, - forbid_audio_eos=False + forbid_audio_eos=False, ): # dec_output: (B, E) self.local_transformer.reset_cache(use_cache=use_kv_cache) @@ -950,7 +952,9 @@ def local_transformer_sample_autoregressive( codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos).squeeze(1) + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -1005,7 +1009,9 @@ def sample_codes_from_logits( for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 - codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos).squeeze(1) + codebook_logits = self.clear_forbidden_logits( + codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos + ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 @@ -2211,7 +2217,7 @@ def infer_batch( # Don't allow termination until we have 
generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames - all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: if self.local_transformer_type == LocalTransformerType.AR: # Autoregressive sampling with local transformer @@ -2224,7 +2230,7 @@ def infer_batch( use_cfg=use_cfg, cfg_scale=cfg_scale, use_kv_cache=use_LT_kv_cache, - forbid_audio_eos=forbid_audio_eos + forbid_audio_eos=forbid_audio_eos, ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( @@ -2240,7 +2246,7 @@ def infer_batch( fixed_schedule=maskgit_fixed_schedule, dynamic_cfg_scale=maskgit_dynamic_cfg_scale, sampling_type=maskgit_sampling_type, - forbid_audio_eos=forbid_audio_eos + forbid_audio_eos=forbid_audio_eos, ) else: raise ValueError( From b99620acb5ca51871f9bfd06f5b7aa9a8761f348 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 10:04:02 -0700 Subject: [PATCH 05/11] Show extreme values in violin plots (to aid in debugging rare issues) Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index d37630e037c4..1e8c6fe61fb8 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -149,7 +149,7 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: assert column in df # Create empty lists to store the parts objects for each DataFrame # Plot the violin plots for each DataFrame - axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5) + 
axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5, showextrema=True) axs[i].set_title(column) axs[i].set_xticks([i]) From 378a852727a9f6cf75527c4481a6f7544f754836 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 11:24:04 -0700 Subject: [PATCH 06/11] Fix merge issues Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 1e8c6fe61fb8..08b89a582631 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -293,9 +293,9 @@ def run_inference( log_exp_name=False, compute_fcd=False, violin_plot_metrics=['cer', 'pred_context_ssim'], - min_frames=0, eos_detection_method=None, ignore_finished_sentence_tracking=False, + min_frames=0, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -362,6 +362,9 @@ def run_inference( f"LT_{use_local_transformer}_" f"MaskGit_{maskgit_n_steps}_{maskgit_sampling_type}_{''.join([str(l) for l in maskgit_fixed_schedule]) if maskgit_fixed_schedule is not None else 'None'}_" f"SV_{sv_model}" + f"EOS_{eos_detection_method}" + f"IgnoreFST_{ignore_finished_sentence_tracking}" + f"MinFrames_{min_frames}" ) dataset_meta_info = evalset_config.dataset_meta_info @@ -491,6 +494,9 @@ def run_inference( maskgit_noise_scale=maskgit_noise_scale, maskgit_fixed_schedule=maskgit_fixed_schedule, maskgit_sampling_type=maskgit_sampling_type, + ignore_finished_sentence_tracking=ignore_finished_sentence_tracking, + eos_detection_method=eos_detection_method, + min_frames=min_frames, ) all_rtf_metrics.append(rtf_metrics) @@ -740,6 +746,9 @@ def main(): log_exp_name=args.log_exp_name, compute_fcd=compute_fcd, violin_plot_metrics=args.violin_plot_metrics, + eos_detection_method=args.eos_detection_method, + ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, 
+ min_frames=args.min_frames, ) # Mode 1: Run inference from provided hparams and checkpoint files From 279d7bdeaa65b113ecda7a43b0bae17b24d613e0 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 11:26:30 -0700 Subject: [PATCH 07/11] More merge fixes Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 2 +- scripts/magpietts/infer_and_evaluate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 237d70e7f9ee..389117ca3971 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -2102,9 +2102,9 @@ def infer_batch( maskgit_fixed_schedule=None, maskgit_dynamic_cfg_scale=False, maskgit_sampling_type=None, - min_generated_frames=4, ignore_finished_sentence_tracking=False, eos_detection_method="argmax_or_multinomial_any", + min_generated_frames=4, ): eos_detection_method = EOSDetectionMethod(eos_detection_method) with torch.no_grad(): diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 08b89a582631..074a8bb346d4 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -496,7 +496,7 @@ def run_inference( maskgit_sampling_type=maskgit_sampling_type, ignore_finished_sentence_tracking=ignore_finished_sentence_tracking, eos_detection_method=eos_detection_method, - min_frames=min_frames, + min_generated_frames=min_frames, ) all_rtf_metrics.append(rtf_metrics) From f9a68ac4bf9b79c4d2ce8249d3d8c4184172a0db Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 17:55:20 -0700 Subject: [PATCH 08/11] Remove temporary changes in infer_and_evaluate.py Signed-off-by: Fejgin, Roy --- scripts/magpietts/infer_and_evaluate.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/scripts/magpietts/infer_and_evaluate.py b/scripts/magpietts/infer_and_evaluate.py index 
074a8bb346d4..a22c178bcec4 100755 --- a/scripts/magpietts/infer_and_evaluate.py +++ b/scripts/magpietts/infer_and_evaluate.py @@ -149,7 +149,7 @@ def create_violin_plots(metrics: List[dict], metric_keys: List[str], output_png: assert column in df # Create empty lists to store the parts objects for each DataFrame # Plot the violin plots for each DataFrame - axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5, showextrema=True) + axs[i].violinplot(df[column], showmedians=True, positions=[i], widths=0.5) axs[i].set_title(column) axs[i].set_xticks([i]) @@ -295,7 +295,6 @@ def run_inference( violin_plot_metrics=['cer', 'pred_context_ssim'], eos_detection_method=None, ignore_finished_sentence_tracking=False, - min_frames=0, ): # Load model if hparams_file is not None and checkpoint_file is not None: @@ -364,7 +363,6 @@ def run_inference( f"SV_{sv_model}" f"EOS_{eos_detection_method}" f"IgnoreFST_{ignore_finished_sentence_tracking}" - f"MinFrames_{min_frames}" ) dataset_meta_info = evalset_config.dataset_meta_info @@ -496,7 +494,6 @@ def run_inference( maskgit_sampling_type=maskgit_sampling_type, ignore_finished_sentence_tracking=ignore_finished_sentence_tracking, eos_detection_method=eos_detection_method, - min_generated_frames=min_frames, ) all_rtf_metrics.append(rtf_metrics) @@ -698,7 +695,6 @@ def main(): default=['cer', 'pred_context_ssim'], help="Which metrics to add the violin plot.", ) - parser.add_argument('--min_frames', type=int, default=0, help="Minimum number of frames to generate") args = parser.parse_args() if args.datasets is None: @@ -748,7 +744,6 @@ def main(): violin_plot_metrics=args.violin_plot_metrics, eos_detection_method=args.eos_detection_method, ignore_finished_sentence_tracking=args.ignore_finished_sentence_tracking, - min_frames=args.min_frames, ) # Mode 1: Run inference from provided hparams and checkpoint files From 4148282e5e92741de2c1833a6ef9a1fa55ea69b0 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 
2025 18:09:21 -0700 Subject: [PATCH 09/11] Comments Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 389117ca3971..8ea6ceb980e4 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -2273,6 +2273,7 @@ def infer_batch( unfinished_items = {k: v for k, v in unfinished_texts.items() if v} # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) + # This guards against rare cases of termination right at the start of generation. forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) From 2419155b1fb75c1602eab461f2a3ec95341b2bd1 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 18:19:51 -0700 Subject: [PATCH 10/11] Comments Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 8ea6ceb980e4..65f1cb2bd614 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -2104,6 +2104,8 @@ def infer_batch( maskgit_sampling_type=None, ignore_finished_sentence_tracking=False, eos_detection_method="argmax_or_multinomial_any", + # setting this greater than 0 prevents rare cases of first-frame termination. Any number greater between 1 and 4 should work, but 4 + # lines up with the codec's minimum frame requirement. 
min_generated_frames=4, ): eos_detection_method = EOSDetectionMethod(eos_detection_method) From 699fb57393e9d2fb616f424f588c078e9d9f99ef Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 24 Sep 2025 18:21:32 -0700 Subject: [PATCH 11/11] Fix typo Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/magpietts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 65f1cb2bd614..c719d5f9b1ab 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -2104,7 +2104,7 @@ def infer_batch( maskgit_sampling_type=None, ignore_finished_sentence_tracking=False, eos_detection_method="argmax_or_multinomial_any", - # setting this greater than 0 prevents rare cases of first-frame termination. Any number greater between 1 and 4 should work, but 4 + # Setting this greater than 0 prevents rare cases of first-frame termination. Any number greater between 1 and 4 should work, but 4 # lines up with the codec's minimum frame requirement. min_generated_frames=4, ):