
Commit 0ad7bcb

Docstrings and comments on EOS handling. (#14847)
Docstrings and comments. This is a follow-up to: #14761

Also changed the default value of `forbid_audio_eos` to False in the `get_forbidden_tokens` function, since permitting EOS is the more common use case. The default value was not being used anywhere, so this doesn't change behavior.

Signed-off-by: Fejgin, Roy <[email protected]>

* Add docstrings for sampling methods

Signed-off-by: Fejgin, Roy <[email protected]>

* Fix isort issue

Signed-off-by: Fejgin, Roy <[email protected]>

---------

Signed-off-by: Fejgin, Roy <[email protected]>
1 parent 0b48537 commit 0ad7bcb


2 files changed: +168 -43 lines changed


nemo/collections/tts/models/magpietts.py

Lines changed: 166 additions & 38 deletions
@@ -16,7 +16,7 @@
 import random
 import time
 from functools import partial
-from typing import List, Union
+from typing import Dict, List, Optional, Union

 import numpy as np
 import soundfile as sf
@@ -757,12 +757,15 @@ def code_to_str(code):
             output_str += c
         logging.debug(output_str)

-    def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
+    def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor:
         """
         Sets logits of forbidden tokens to `-inf` so they will never be sampled.
-        Specifically, we forbid sampling of all special tokens except AUDIO_EOS.
+        Specifically, we forbid sampling of all special tokens except AUDIO_EOS,
+        which is allowed by default.
         Args:
             logits: (B, C, num_audio_tokens_per_codebook)
+            forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens
+                from being sampled. Default: False.
         """
         logits[
             :,
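
The effect of `clear_forbidden_logits` is easy to see in isolation: setting a logit to `-inf` drives its softmax probability to exactly zero, so the token can never be drawn by multinomial sampling. The following minimal sketch is not part of this commit; `forbidden_ids` stands in for the indices the real method derives from the special-token table:

import torch

def clear_forbidden_logits_sketch(logits: torch.Tensor, forbidden_ids: list[int]) -> torch.Tensor:
    """Set logits of forbidden token ids to -inf so softmax assigns them zero probability.

    logits: (B, C, num_audio_tokens_per_codebook); forbidden_ids indexes the last dimension.
    """
    logits[:, :, forbidden_ids] = float('-inf')
    return logits

# Example: the forbidden ids end up with sampling probability exactly 0.
logits = torch.randn(2, 8, 1030)  # 2 items, 8 codebooks, 1030 tokens per codebook (sizes are illustrative)
logits = clear_forbidden_logits_sketch(logits, forbidden_ids=[1025, 1026, 1027])
probs = torch.softmax(logits, dim=-1)
assert torch.all(probs[:, :, [1025, 1026, 1027]] == 0)
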
@@ -773,22 +776,77 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False):

     def local_transformer_sample_maskgit(
         self,
-        dec_output,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        use_cfg=False,
-        cfg_scale=1.0,
-        n_steps=3,
-        noise_scale=0.0,
-        fixed_schedule=None,
-        dynamic_cfg_scale=False,
-        sampling_type=None,
-        forbid_audio_eos=False,
-    ):
+        dec_output: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        use_cfg: bool = False,
+        cfg_scale: float = 1.0,
+        n_steps: int = 3,
+        noise_scale: float = 0.0,
+        fixed_schedule: Optional[List[int]] = None,
+        dynamic_cfg_scale: bool = False,
+        sampling_type: Optional[str] = None,
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
         """
-        Sample codes for one timestep from the local transformer using MaskGit.
+        Sample audio codes for the current timestep using MaskGit-like iterative
+        prediction with the local transformer. If frame-stacking is enabled, the
+        codes for all frames in the stack are sampled, treated as one long sequence.
+
+        The MaskGit process starts with all positions masked and iteratively unmasks the
+        most confident positions over multiple steps. By "masked" we mean that a
+        dedicated MASK token is used (as opposed to attention masking). The LT in this
+        case is a non-causal transformer decoder. At each step the model predicts all
+        positions at once. Of those predictions, a subset of the most confident
+        previously-masked positions is kept and unmasked in the next step. The number of
+        positions that are unmasked at each step is determined by the unmasking
+        schedule. We support a cosine schedule and a fixed schedule provided by the
+        user.
+
+        Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS for all items in the batch.
+          This is useful early in the generation process.
+        * supports different unmasking methods, see `sampling_type` argument for details.
+
+        Args:
+            dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
+                and E is the primary decoder's embedding dimension.
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items, audio EOS
+                sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
+                to be doubled with conditional and unconditional outputs from the primary decoder.
+            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
+            n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling.
+            noise_scale (float, optional): Scale factor for noise added to confidence scores
+                during sampling (experimental).
+            fixed_schedule (list, optional): Fixed schedule for the number of tokens to unmask at each step.
+                If None, uses a cosine schedule.
+            dynamic_cfg_scale (bool, optional): Whether to dynamically adjust the CFG scale during
+                sampling (experimental).
+            sampling_type (str, optional): Type of sampling strategy. Options are:
+                ["default", "causal", "purity_causal", "purity_default"].
+                * Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity"
+                  is not specified, confidence sampling is used as in the original MaskGit paper.
+                * "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled.
+                  If "causal" is specified, frames are unmasked in causal order; "default"
+                  doesn't impose any constraints on the unmasking order.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
         """
         # dec_output: (B, E)
         device = dec_output.device
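
The docstring above describes the unmasking schedule only in words. As a rough illustration of how a cosine schedule and confidence-based unmasking can be realized, here is a generic sketch of the technique (an assumption-laden reimplementation of standard MaskGit sampling, not the model's actual code; `codes`, `still_masked`, and the tensor shapes are illustrative):

import math
import torch

def cosine_num_to_unmask(seq_len: int, n_steps: int) -> list[int]:
    """How many positions to reveal at each step; the masked fraction follows cos(pi/2 * t)."""
    remaining = [int(round(seq_len * math.cos(math.pi / 2 * (i + 1) / n_steps))) for i in range(n_steps)]
    counts, prev = [], seq_len
    for r in remaining:
        counts.append(prev - r)  # positions unmasked at this step
        prev = r
    return counts  # sums to seq_len, e.g. cosine_num_to_unmask(12, 3) == [2, 4, 6]

def maskgit_unmask_step(logits, codes, still_masked, k):
    """Reveal the k most confident currently-masked positions (confidence sampling).

    logits: (B, T, V) model predictions; codes: (B, T) long; still_masked: (B, T) bool.
    """
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs.flatten(0, 1), 1).view(codes.shape)  # (B, T) candidate codes
    conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)             # confidence of each candidate
    conf = conf.masked_fill(~still_masked, float('-inf'))                  # only rank masked positions
    idx = conf.topk(k, dim=-1).indices                                     # (B, k) positions to reveal
    codes.scatter_(1, idx, sampled.gather(1, idx))
    still_masked.scatter_(1, idx, False)
    return codes, still_masked
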
@@ -893,7 +951,7 @@ def local_transformer_sample_maskgit(
                 cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits
                 logits[:actual_batch_size] = cfg_logits

-            # Disallow generation of special tokens (except audio EOS which is handled separately)
+            # Disallow generation of special tokens
             logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)

             # handle unfinished and finished items
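
The context lines above also show the classifier-free guidance combination used during sampling: the conditional and unconditional halves of the doubled batch are blended as `scale * cond + (1 - scale) * uncond`. A small standalone sketch of that blending, assuming the conditional half comes first (as the assignment to `logits[:actual_batch_size]` suggests):

import torch

def apply_cfg(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    """Blend conditional/unconditional logits from a doubled batch (2*B, ...) into (B, ...)."""
    actual_batch_size = logits.shape[0] // 2
    conditional_logits = logits[:actual_batch_size]
    unconditional_logits = logits[actual_batch_size:]
    # Same linear combination as in the diff: scale * cond + (1 - scale) * uncond.
    return cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits

blended = apply_cfg(torch.randn(4, 8, 1030), cfg_scale=2.5)  # -> shape (2, 8, 1030)
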
@@ -955,17 +1013,56 @@ def local_transformer_sample_maskgit(

     def local_transformer_sample_autoregressive(
         self,
-        dec_output,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        use_cfg=False,
-        cfg_scale=1.0,
-        use_kv_cache=True,
-        forbid_audio_eos=False,
-    ):
-        # dec_output: (B, E)
+        dec_output: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        use_cfg: bool = False,
+        cfg_scale: float = 1.0,
+        use_kv_cache: bool = True,
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
+        """
+        Sample audio codes autoregressively across codebooks using the local
+        transformer. Uses multinomial sampling with temperature, top-k, and
+        classifier-free guidance (CFG).
+
+        The sequence is initialized with the primary decoder's hidden output as the only
+        input and is extended by one code (for one codebook) at a time, appending each
+        sampled code to the input sequence for the next step. At the last step the
+        sequence is `num_codebooks` codes long. If frame stacking is enabled, codes for
+        all frames in the stack are sampled as one long sequence and the final sequence
+        is `num_codebooks * frame_stacking_factor` codes long.
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS (useful early in the generation process)
+
+        Args:
+            dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size
+                and E is the primary decoder's embedding dimension.
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items, audio EOS
+                sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size
+                to be doubled with conditional and unconditional outputs from the primary decoder.
+            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
+            use_kv_cache (bool, optional): Whether to use key-value caching in the transformer.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor)
+                where B is batch size (or actual_batch_size if use_cfg=True).
+        """
+
         self.local_transformer.reset_cache(use_cache=use_kv_cache)
         dec_output = dec_output.unsqueeze(1)  # (B, 1, E)
         local_transformer_input = self.local_transformer_in_projection(dec_output)  # (B, 1, 128)
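
The new docstring explains the codebook-by-codebook loop in prose; the schematic below restates it as code. It is a simplified sketch, and `lt`, `in_proj`, `out_heads`, and `embeds` are stand-ins for the model's local transformer, input projection, per-codebook output heads, and code embeddings (only `local_transformer` and `local_transformer_in_projection` are visible in this diff):

import torch

def sample_codebooks_autoregressively(dec_output, lt, in_proj, out_heads, embeds,
                                      num_codebooks, temperature=0.7, topk=80):
    """Sample codebook i from the last position, then feed its embedding back as input i+1."""
    seq = in_proj(dec_output.unsqueeze(1))                       # (B, 1, D): decoder output is the only input
    codes = []
    for i in range(num_codebooks):
        hidden = lt(seq)                                         # (B, seq_len, D)
        logits = out_heads[i](hidden[:, -1]) / temperature       # (B, V) logits for codebook i
        kth = torch.topk(logits, topk, dim=-1).values[:, -1:]    # top-k filtering threshold
        logits = logits.masked_fill(logits < kth, float('-inf'))
        code = torch.multinomial(torch.softmax(logits, dim=-1), 1)  # (B, 1)
        codes.append(code)
        seq = torch.cat([seq, embeds[i](code)], dim=1)           # append sampled code for the next step
    return torch.cat(codes, dim=1)                               # (B, num_codebooks)
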
@@ -991,9 +1088,11 @@ def local_transformer_sample_autoregressive(
                 codebook_logits[item_idx, :] = float('-inf')
                 codebook_logits[item_idx, self.audio_eos_id] = 0.0

+            # Disallow generation of special tokens
             codebook_logits = self.clear_forbidden_logits(
                 codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
             ).squeeze(1)
+
             codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
             indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                 -1
@@ -1028,14 +1127,40 @@ def local_transformer_sample_autoregressive(

     def sample_codes_from_logits(
         self,
-        all_code_logits_t,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        forbid_audio_eos=False,
-    ):
-        # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep
+        all_code_logits_t: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
+        """
+        Sample codes for all codebooks at a given timestep. Uses multinomial sampling
+        with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor
+        > 1`), this function will sample across the entire frame stack.
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS (useful early in the generation process)
+
+        Args:
+            all_code_logits_t (torch.Tensor): Logits at a given timestep with shape
+                (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor).
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items, audio EOS
+                sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
+        """
         all_preds = [[] for _ in range(self.frame_stacking_factor)]
         for fs_index in range(self.frame_stacking_factor):
             for idx in range(self.num_audio_codebooks):
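
The per-codebook sampling step is partly visible in the surrounding context lines (the top-k filter); the rest follows the usual pattern of drawing from a temperature-scaled softmax. A self-contained sketch of that step for a single codebook slice (the exact point at which the real method applies the temperature is not shown in this hunk):

import torch

def sample_one_codebook(codebook_logits: torch.Tensor, temperature: float = 0.7, topk: int = 80) -> torch.Tensor:
    """Top-k + temperature multinomial sampling for one codebook. codebook_logits: (B, V) -> (B,) codes."""
    codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]              # (B, topk)
    indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(-1)  # below the k-th largest logit
    codebook_logits = codebook_logits.masked_fill(indices_to_remove, float('-inf'))
    probs = torch.softmax(codebook_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
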
@@ -1048,9 +1173,12 @@ def sample_codes_from_logits(
                 for item_idx in finished_items:
                     codebook_logits[item_idx, :] = float('-inf')
                     codebook_logits[item_idx, self.audio_eos_id] = 0.0
+
+                # Disallow generation of special tokens
                 codebook_logits = self.clear_forbidden_logits(
                     codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
                 ).squeeze(1)
+
                 codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
                 indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                     -1

nemo/collections/tts/modules/magpietts_modules.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,12 @@ def get_index(token: SpecialAudioToken, base_codebook_size: int):
103103
return base_codebook_size + token.value
104104

105105
@staticmethod
106-
def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = True) -> list[int]:
106+
def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = False) -> list[int]:
107107
"""
108108
Returns a list of token indices that should not be sampled or returned to user.
109109
Args:
110110
base_codebook_size (int): The size of the codec codebook (which is the first part of the embedding table).
111-
forbid_audio_eos (bool): Whether to forbid the AUDIO_EOS token to be sampled.
112-
* Set to `False` when internally generating tokens in MagpieTTS sampling
113-
* Set to `True` when checking validity of tokens to be returned to user
114-
or given to the codec for decoding
111+
forbid_audio_eos (bool): Whether AUDIO_EOS should be forbidden. Default: False (i.e. allowed).
115112
"""
116113
all_special_tokens = list(SpecialAudioToken)
117114
if not forbid_audio_eos:
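
With the new default, callers get a forbidden-token list that permits AUDIO_EOS unless they explicitly ask otherwise. The behavior can be mirrored with a tiny stand-in enum (the real `SpecialAudioToken` members and values may differ; this only illustrates the documented logic):

from enum import Enum

class _Tok(Enum):  # hypothetical stand-in for SpecialAudioToken
    AUDIO_BOS = 0
    AUDIO_EOS = 1
    AUDIO_CONTEXT_EOS = 2

def get_forbidden_tokens(base_codebook_size: int, forbid_audio_eos: bool = False) -> list[int]:
    """All special tokens are forbidden; AUDIO_EOS is only forbidden when explicitly requested."""
    tokens = list(_Tok)
    if not forbid_audio_eos:
        tokens.remove(_Tok.AUDIO_EOS)
    return [base_codebook_size + t.value for t in tokens]

print(get_forbidden_tokens(1024))                         # [1024, 1026] -- EOS (1025) allowed by default
print(get_forbidden_tokens(1024, forbid_audio_eos=True))  # [1024, 1025, 1026]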

0 commit comments
