nemo/collections/tts/models/magpietts.py: 37 changes (31 additions, 6 deletions)
@@ -734,15 +734,17 @@ def code_to_str(code)
output_str += c
logging.debug(output_str)

def clear_forbidden_logits(self, logits):
def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
"""
Sets logits of forbidden tokens to `-inf` so they will never be sampled.
Specifically, we forbid sampling of all special tokens except AUDIO_EOS, which is also masked when `forbid_audio_eos` is True.
Args:
logits: (B, C, num_audio_tokens_per_codebook)
forbid_audio_eos: if True, additionally mask the AUDIO_EOS token so termination cannot be sampled at this step.
"""
logits[
:, :, SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=False)
:,
:,
SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos),
] = float('-inf')
return logits

@@ -760,6 +762,7 @@ def local_transformer_sample_maskgit(
fixed_schedule=None,
dynamic_cfg_scale=False,
sampling_type=None,
forbid_audio_eos=False,
):
"""
Sample codes for one timestep from the local transformer using MaskGit.
@@ -868,7 +871,7 @@ def local_transformer_sample_maskgit(
logits[:actual_batch_size] = cfg_logits

# Disallow generation of special tokens (audio EOS is also masked when forbid_audio_eos is set; otherwise it is handled separately)
logits = self.clear_forbidden_logits(logits)
logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)

# handle unfinished and finished items
for item_idx in unfinished_items:
@@ -937,6 +940,7 @@ def local_transformer_sample_autoregressive(
use_cfg=False,
cfg_scale=1.0,
use_kv_cache=True,
forbid_audio_eos=False,
):
# dec_output: (B, E)
self.local_transformer.reset_cache(use_cache=use_kv_cache)
@@ -964,7 +968,9 @@ def local_transformer_sample_autoregressive(
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0

codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1)
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
@@ -998,7 +1004,13 @@ def local_transformer_sample_autoregressive(
return all_preds

def sample_codes_from_logits(
self, all_code_logits_t, temperature=0.7, topk=80, unfinished_items={}, finished_items={}
self,
all_code_logits_t,
temperature=0.7,
topk=80,
unfinished_items={},
finished_items={},
forbid_audio_eos=False,
):
# all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep
all_preds = [[] for _ in range(self.frame_stacking_factor)]
@@ -1013,7 +1025,9 @@ def sample_codes_from_logits(
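# Items already marked finished are forced to terminate: every logit is masked to -inf except
# AUDIO_EOS, so the sampling step below deterministically emits the end-of-audio token.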
for item_idx in finished_items:
codebook_logits[item_idx, :] = float('-inf')
codebook_logits[item_idx, self.audio_eos_id] = 0.0
codebook_logits = self.clear_forbidden_logits(codebook_logits.unsqueeze(1)).squeeze(1)
codebook_logits = self.clear_forbidden_logits(
codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
).squeeze(1)
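# The unsqueeze/squeeze above adds and then removes a singleton codebook dimension, since
# clear_forbidden_logits expects logits shaped (B, C, num_audio_tokens_per_codebook).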
codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk)
indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
-1
@@ -2090,6 +2104,9 @@ def infer_batch(
maskgit_sampling_type=None,
ignore_finished_sentence_tracking=False,
eos_detection_method="argmax_or_multinomial_any",
# Setting this greater than 0 prevents rare cases of first-frame termination. Any value between 1 and 4
# should work, but 4 lines up with the codec's minimum frame requirement.
min_generated_frames=4,
):
eos_detection_method = EOSDetectionMethod(eos_detection_method)
with torch.no_grad():
@@ -2257,6 +2274,10 @@ def infer_batch(
} # Items that have been close to the end for at least 20 timesteps
unfinished_items = {k: v for k, v in unfinished_texts.items() if v}

# Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor)
# This guards against rare cases of termination right at the start of generation.
forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames
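# Illustrative example: with frame_stacking_factor=2 and min_generated_frames=4, EOS stays
# masked at idx 0 and 1 (frames 0-3) and is first allowed at idx 2, once 4 frames exist.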

all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook)
if use_local_transformer_for_inference:
if self.local_transformer_type == LocalTransformerType.AR:
@@ -2270,6 +2291,7 @@ def infer_batch(
use_cfg=use_cfg,
cfg_scale=cfg_scale,
use_kv_cache=use_LT_kv_cache,
forbid_audio_eos=forbid_audio_eos,
)
elif self.local_transformer_type == LocalTransformerType.MASKGIT:
audio_codes_next = self.local_transformer_sample_maskgit(
@@ -2285,6 +2307,7 @@ def infer_batch(
fixed_schedule=maskgit_fixed_schedule,
dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
sampling_type=maskgit_sampling_type,
forbid_audio_eos=forbid_audio_eos,
)
else:
raise ValueError(
@@ -2298,12 +2321,14 @@ def infer_batch(
topk=topk,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)
all_codes_next_argmax = self.sample_codes_from_logits(
all_code_logits_t,
temperature=0.01,
unfinished_items=unfinished_items,
finished_items=finished_items,
forbid_audio_eos=forbid_audio_eos,
) # (B, num_codebooks, frame_stacking_factor)

for item_idx in range(all_codes_next_argmax.size(0)):