diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py
index 973e75fb9f33..c719d5f9b1ab 100644
--- a/nemo/collections/tts/models/magpietts.py
+++ b/nemo/collections/tts/models/magpietts.py
@@ -734,7 +734,7 @@ def code_to_str(code):
                 output_str += c
             logging.debug(output_str)
 
-    def clear_forbidden_logits(self, logits):
+    def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
         """
         Sets logits of forbidden tokens to `-inf` so they will never be sampled.
         Specifically, we forbid sampling of all special tokens except AUDIO_EOS.
@@ -742,7 +742,9 @@ def clear_forbidden_logits(self, logits):
             logits: (B, C, num_audio_tokens_per_codebook)
         """
         logits[
-            :, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False)
+            :,
+            :,
+            SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos),
         ] = float('-inf')
         return logits
 
@@ -760,6 +762,7 @@ def local_transformer_sample_maskgit(
         fixed_schedule=None,
         dynamic_cfg_scale=False,
         sampling_type=None,
+        forbid_audio_eos=False,
     ):
         """
         Sample codes for one timestep from the local transformer using MaskGit.
@@ -868,7 +871,7 @@ def local_transformer_sample_maskgit(
                 logits[:actual_batch_size] = cfg_logits
 
             # Disallow generation of special tokens (except audio EOS which is handled separately)
-            logits = self.clear_forbidden_logits(logits)
+            logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)
 
             # handle unfinished and finished items
             for item_idx in unfinished_items:
@@ -937,6 +940,7 @@ def local_transformer_sample_autoregressive(
         use_cfg=False,
         cfg_scale=1.0,
         use_kv_cache=True,
+        forbid_audio_eos=False,
     ):
         # dec_output: (B, E)
         self.local_transformer.reset_cache(use_cache=use_kv_cache)
@@ -964,7 +968,9 @@ def local_transformer_sample_autoregressive(
                 codebook_logits[item_idx, :] = float('-inf')
                 codebook_logits[item_idx, self.audio_eos_id] = 0.0
 
-            codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1)
+            codebook_logits = self.clear_forbidden_logits(
+                codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
+            ).squeeze(1)
             codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
             indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                 -1
@@ -998,7 +1004,13 @@ def local_transformer_sample_autoregressive(
         return all_preds
 
     def sample_codes_from_logits(
-        self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}
+        self,
+        all_code_logits_t,
+        temperature=0.7,
+        topk=80,
+        unfinished_items={},
+        finished_items={},
+        forbid_audio_eos=False,
     ):
         # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep
         all_preds = [[] for _ in range(self.frame_stacking_factor)]
@@ -1013,7 +1025,9 @@ def sample_codes_from_logits(
             for item_idx in finished_items:
                 codebook_logits[item_idx, :] = float('-inf')
                 codebook_logits[item_idx, self.audio_eos_id] = 0.0
-            codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1)
+            codebook_logits = self.clear_forbidden_logits(
+                codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
+            ).squeeze(1)
             codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
             indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                 -1
@@ -2090,6 +2104,9 @@ def infer_batch(
         maskgit_sampling_type=None,
         ignore_finished_sentence_tracking=False,
         eos_detection_method="argmax_or_multinomial_any",
+        # Setting this greater than 0 prevents rare cases of first-frame termination.
+        # Any number between 1 and 4 should work, but 4 lines up with the codec's minimum frame requirement.
+        min_generated_frames=4,
     ):
         eos_detection_method = EOSDetectionMethod(eos_detection_method)
         with torch.no_grad():
@@ -2257,6 +2274,10 @@ def infer_batch(
                 }
                 # Items that have been close to the end for atleast 20 timesteps
                 unfinished_items = {k: v for k, v in unfinished_texts.items() if v}
+                # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor).
+                # This guards against rare cases of termination right at the start of generation.
+                forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames
+
                 all_code_logits_t = all_code_logits[:, -1, :]  # (B, num_codebooks * num_tokens_per_codebook)
                 if use_local_transformer_for_inference:
                     if self.local_transformer_type == LocalTransformerType.AR:
@@ -2270,6 +2291,7 @@ def infer_batch(
                             use_cfg=use_cfg,
                             cfg_scale=cfg_scale,
                             use_kv_cache=use_LT_kv_cache,
+                            forbid_audio_eos=forbid_audio_eos,
                         )
                     elif self.local_transformer_type == LocalTransformerType.MASKGIT:
                         audio_codes_next = self.local_transformer_sample_maskgit(
@@ -2285,6 +2307,7 @@ def infer_batch(
                             fixed_schedule=maskgit_fixed_schedule,
                             dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
                             sampling_type=maskgit_sampling_type,
+                            forbid_audio_eos=forbid_audio_eos,
                         )
                     else:
                         raise ValueError(
@@ -2298,12 +2321,14 @@ def infer_batch(
                         topk=topk,
                         unfinished_items=unfinished_items,
                         finished_items=finished_items,
+                        forbid_audio_eos=forbid_audio_eos,
                     )  # (B, num_codebooks, frame_stacking_factor)
                     all_codes_next_argmax = self.sample_codes_from_logits(
                         all_code_logits_t,
                         temperature=0.01,
                         unfinished_items=unfinished_items,
                         finished_items=finished_items,
+                        forbid_audio_eos=forbid_audio_eos,
                     )  # (B, num_codebooks, frame_stacking_factor)
                     for item_idx in range(all_codes_next_argmax.size(0)):
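
Note: the following is an illustrative, self-contained sketch (not part of the patch) of the EOS-gating idea that the new `forbid_audio_eos` / `min_generated_frames` logic implements. The token id, tensor shapes, and helper name below are made up for the example; in the actual patch the equivalent check is `idx * self.frame_stacking_factor < min_generated_frames`, and the masking happens inside `clear_forbidden_logits(..., forbid_audio_eos=True)`.

import torch

AUDIO_EOS_ID = 1024  # hypothetical EOS token id, not taken from the model

def mask_eos_until_min_frames(logits, frames_generated, min_generated_frames=4):
    # logits: (B, num_codebooks, num_tokens_per_codebook)
    # Until at least `min_generated_frames` frames have been generated, make EOS
    # unsampleable by setting its logit to -inf, mirroring forbid_audio_eos=True.
    if frames_generated < min_generated_frames:
        logits[:, :, AUDIO_EOS_ID] = float('-inf')
    return logits

# Example: at the very first decoding step, EOS can never be sampled.
logits = torch.randn(2, 8, 1025)
masked = mask_eos_until_min_frames(logits, frames_generated=0)
assert torch.isinf(masked[:, :, AUDIO_EOS_ID]).all()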