 import random
 import time
 from functools import partial
-from typing import List, Union
+from typing import Dict, List, Optional, Union

 import numpy as np
 import soundfile as sf
@@ -757,12 +757,15 @@ def code_to_str(code):
             output_str += c
         logging.debug(output_str)

-    def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
+    def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor:
         """
         Sets logits of forbidden tokens to `-inf` so they will never be sampled.
-        Specifically, we forbid sampling of all special tokens except AUDIO_EOS.
+        Specifically, we forbid sampling of all special tokens except AUDIO_EOS,
+        which is allowed by default.

         Args:
             logits: (B, C, num_audio_tokens_per_codebook)
+            forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens
+                from being sampled. Default: False.
         """
         logits[
             :,
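
As a self-contained sketch of the masking this method performs (the `forbidden_ids` list here is illustrative; the real implementation indexes the model's actual special-token ids):

    import torch

    def clear_forbidden_logits_sketch(
        logits: torch.Tensor, forbidden_ids: list, audio_eos_id: int, forbid_audio_eos: bool = False
    ) -> torch.Tensor:
        # logits: (B, C, num_audio_tokens_per_codebook)
        ids = list(forbidden_ids) + ([audio_eos_id] if forbid_audio_eos else [])
        logits[:, :, ids] = float('-inf')  # forbidden tokens can never be sampled
        return logits
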
@@ -773,22 +776,77 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
     def local_transformer_sample_maskgit(
         self,
-        dec_output,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        use_cfg=False,
-        cfg_scale=1.0,
-        n_steps=3,
-        noise_scale=0.0,
-        fixed_schedule=None,
-        dynamic_cfg_scale=False,
-        sampling_type=None,
-        forbid_audio_eos=False,
-    ):
+        dec_output: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        use_cfg: bool = False,
+        cfg_scale: float = 1.0,
+        n_steps: int = 3,
+        noise_scale: float = 0.0,
+        fixed_schedule: Optional[List[int]] = None,
+        dynamic_cfg_scale: bool = False,
+        sampling_type: Optional[str] = None,
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
         """
-        Sample codes for one timestep from the local transformer using MaskGit.
+        Sample audio codes for the current timestep using MaskGit-like iterative
+        prediction with the local transformer. If frame-stacking is enabled, the
+        codes for all frames in the stack are sampled, treated as one long sequence.
+
+        The MaskGit process starts with all positions masked and iteratively unmasks the
+        most confident positions over multiple steps. By "masked" we mean that a
+        dedicated MASK token is used (as opposed to attention masking). The LT in this
+        case is a non-causal transformer decoder. At each step, the model predicts all
+        positions at once; of those predictions, a subset of the most confident
+        previously-masked positions is kept and unmasked in the next step. The number of
+        positions unmasked at each step is determined by the unmasking schedule. We
+        support a cosine schedule and a fixed schedule provided by the user.
+
+        Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG).
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS for all items in the batch.
+          This is useful early in the generation process.
+        * supports different unmasking methods; see the `sampling_type` argument for details.
+
+        Args:
+            dec_output (torch.Tensor): Decoder output tensor with shape (B, E), where B is the batch size
+                and E is the primary decoder's embedding dimension.
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items,
+                audio EOS sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects the batch
+                size to be doubled, with conditional and unconditional outputs from the primary decoder.
+            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
+            n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling.
+            noise_scale (float, optional): Scale factor for noise added to confidence scores
+                during sampling (experimental).
+            fixed_schedule (list, optional): Fixed schedule for the number of tokens to unmask at each step.
+                If None, uses the cosine schedule.
+            dynamic_cfg_scale (bool, optional): Whether to dynamically adjust the CFG scale during
+                sampling (experimental).
+            sampling_type (str, optional): Type of sampling strategy. Options are
+                ["default", "causal", "purity_causal", "purity_default"].
+                * "Purity" refers to purity sampling from https://arxiv.org/abs/2304.01515. If "purity"
+                  is not specified, confidence sampling is used, as in the original MaskGit paper.
+                * "default"/"causal" controls the order of unmasking across frames when frame-stacking
+                  is enabled. "causal" unmasks frames in causal order; "default" imposes no constraints
+                  on the unmasking order.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
         """
         # dec_output: (B, E)
         device = dec_output.device
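
The cosine unmasking schedule mentioned in the docstring can be sketched as follows; this is the usual MaskGit-style formula and may differ in detail from the actual implementation:

    import math

    def cosine_unmask_counts(total_positions: int, n_steps: int) -> list:
        # The fraction still masked after step s follows cos(pi/2 * s / n_steps), so
        # few positions are revealed in early steps and more toward the end.
        cumulative = [round(total_positions * (1 - math.cos(math.pi / 2 * s / n_steps))) for s in range(1, n_steps + 1)]
        cumulative[-1] = total_positions  # ensure everything is unmasked at the last step
        return [cumulative[0]] + [cumulative[i] - cumulative[i - 1] for i in range(1, n_steps)]

For example, `cosine_unmask_counts(8, 3)` yields `[1, 3, 4]`, matching the intuition that confidence-based unmasking should commit to only a few tokens early on.
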
@@ -893,7 +951,7 @@ def local_transformer_sample_maskgit(
             cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits
             logits[:actual_batch_size] = cfg_logits

-            # Disallow generation of special tokens (except audio EOS which is handled separately)
+            # Disallow generation of special tokens
             logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos)

             # handle unfinished and finished items
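
The first line of this hunk is the standard classifier-free guidance combination. Isolated as a helper, assuming the doubled-batch layout described in the docstring (conditional items in the first half):

    import torch

    def apply_cfg(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
        # Batch layout: [conditional; unconditional]. A scale > 1.0 extrapolates
        # away from the unconditional prediction.
        half = logits.shape[0] // 2
        cond, uncond = logits[:half], logits[half:]
        return cfg_scale * cond + (1.0 - cfg_scale) * uncond
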
@@ -955,17 +1013,56 @@ def local_transformer_sample_maskgit(
     def local_transformer_sample_autoregressive(
         self,
-        dec_output,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        use_cfg=False,
-        cfg_scale=1.0,
-        use_kv_cache=True,
-        forbid_audio_eos=False,
-    ):
-        # dec_output: (B, E)
+        dec_output: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        use_cfg: bool = False,
+        cfg_scale: float = 1.0,
+        use_kv_cache: bool = True,
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
+        """
+        Sample audio codes autoregressively across codebooks using the local
+        transformer. Uses multinomial sampling with temperature, top-k, and
+        classifier-free guidance (CFG).
+
+        The sequence is initialized with the primary decoder's hidden output as the only
+        input and is extended by one codebook's code at a time, with each sampled code
+        appended to the input sequence for the next step. At the last step the sequence
+        is `num_codebooks` long. If frame stacking is enabled, codes for all frames in
+        the stack are sampled as one long sequence, and the final sequence is
+        `num_codebooks * frame_stacking_factor` codes long.
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS (useful early in the generation process)
+
+        Args:
+            dec_output (torch.Tensor): Decoder output tensor with shape (B, E), where B is the batch size
+                and E is the primary decoder's embedding dimension.
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items,
+                audio EOS sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects the batch
+                size to be doubled, with conditional and unconditional outputs from the primary decoder.
+            cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True.
+            use_kv_cache (bool, optional): Whether to use key-value caching in the transformer.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor),
+                where B is the batch size (or actual_batch_size if use_cfg=True).
+        """
+
         self.local_transformer.reset_cache(use_cache=use_kv_cache)
         dec_output = dec_output.unsqueeze(1)  # (B, 1, E)
         local_transformer_input = self.local_transformer_in_projection(dec_output)  # (B, 1, 128)
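
A toy, self-contained illustration of the loop described in the docstring above, with generic `torch.nn` modules standing in for the repo's local transformer, projection, and code embedding (all names and sizes here are made up):

    import torch
    import torch.nn as nn

    B, E, H, V, num_codebooks = 2, 16, 16, 32, 4
    in_projection = nn.Linear(E, H)                    # stand-in for local_transformer_in_projection
    local_transformer = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(H, nhead=4, batch_first=True), num_layers=1
    )
    head = nn.Linear(H, V)                             # per-codebook output head (shared here for brevity)
    code_embedding = nn.Embedding(V, H)

    dec_output = torch.randn(B, E)                     # primary decoder output
    seq = in_projection(dec_output).unsqueeze(1)       # (B, 1, H): the initial one-element sequence
    codes = []
    for _ in range(num_codebooks):
        hidden = local_transformer(seq)[:, -1]         # predict from the last position
        probs = torch.softmax(head(hidden) / 0.7, dim=-1)
        code = torch.multinomial(probs, 1).squeeze(-1) # (B,)
        codes.append(code)
        seq = torch.cat([seq, code_embedding(code).unsqueeze(1)], dim=1)
    sampled = torch.stack(codes, dim=1)                # (B, num_codebooks)
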
@@ -991,9 +1088,11 @@ def local_transformer_sample_autoregressive(
                 codebook_logits[item_idx, :] = float('-inf')
                 codebook_logits[item_idx, self.audio_eos_id] = 0.0

+            # Disallow generation of special tokens
             codebook_logits = self.clear_forbidden_logits(
                 codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
             ).squeeze(1)
+
             codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
             indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                 -1
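
The top-k filter in this hunk, combined with the temperature-scaled multinomial draw that follows in the function body, is equivalent to a small standalone helper along these lines (names are illustrative):

    import torch

    def sample_with_topk(logits: torch.Tensor, temperature: float = 0.7, topk: int = 80) -> torch.Tensor:
        # Remove everything below the k-th largest logit in each row.
        kth_value = torch.topk(logits, topk, dim=-1)[0][:, -1].unsqueeze(-1)
        logits = logits.masked_fill(logits < kth_value, float('-inf'))
        # Temperature < 1.0 sharpens the distribution; > 1.0 flattens it.
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)  # (B,)
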
@@ -1028,14 +1127,40 @@ def local_transformer_sample_autoregressive(
     def sample_codes_from_logits(
         self,
-        all_code_logits_t,
-        temperature=0.7,
-        topk=80,
-        unfinished_items={},
-        finished_items={},
-        forbid_audio_eos=False,
-    ):
-        # all_code_logits_t: (B, num_codebooks * num_tokens_per_codebook), logits at a given timestep
+        all_code_logits_t: torch.Tensor,
+        temperature: float = 0.7,
+        topk: int = 80,
+        unfinished_items: Dict[int, bool] = {},
+        finished_items: Dict[int, bool] = {},
+        forbid_audio_eos: bool = False,
+    ) -> torch.Tensor:
+        """
+        Sample codes for all codebooks at a given timestep. Uses multinomial sampling
+        with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor > 1`),
+        this function samples across the entire frame stack.
+
+        Special handling:
+        * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled
+        * forces / forbids EOS for finished / unfinished items respectively
+        * optionally, globally forbids audio EOS (useful early in the generation process)
+
+        Args:
+            all_code_logits_t (torch.Tensor): Logits at a given timestep with shape
+                (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor).
+            temperature (float, optional): Sampling temperature.
+            topk (int, optional): Number of top-probability tokens to consider in sampling.
+            unfinished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident have not completed generation. For these items,
+                audio EOS sampling is forbidden.
+            finished_items (dict, optional): Dictionary containing indices of batch
+                items that we are confident are completed. For these items, audio EOS sampling
+                is forced.
+            forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire
+                batch.
+
+        Returns:
+            torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor).
+        """
         all_preds = [[] for _ in range(self.frame_stacking_factor)]
         for fs_index in range(self.frame_stacking_factor):
             for idx in range(self.num_audio_codebooks):
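
Given the flattened layout documented in Args, the per-codebook slice taken inside this double loop presumably follows index arithmetic like the below (an assumption based on the documented shape, not code visible in this diff):

    def codebook_slice(all_code_logits_t, fs_index, idx, num_codebooks, tokens_per_codebook):
        # One contiguous block of tokens_per_codebook logits per (frame, codebook) pair.
        start = (fs_index * num_codebooks + idx) * tokens_per_codebook
        return all_code_logits_t[:, start:start + tokens_per_codebook]  # (B, tokens_per_codebook)
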
@@ -1048,9 +1173,12 @@ def sample_codes_from_logits(
             for item_idx in finished_items:
                 codebook_logits[item_idx, :] = float('-inf')
                 codebook_logits[item_idx, self.audio_eos_id] = 0.0
+
+            # Disallow generation of special tokens
             codebook_logits = self.clear_forbidden_logits(
                 codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos
             ).squeeze(1)
+
             codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0]  # (B, topk)
             indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze(
                 -1