 from nemo.collections.tts.models import AudioCodecModel
 from nemo.collections.tts.modules import transformer_2501
 from nemo.collections.tts.modules.aligner import AlignmentEncoder
+from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter
 from nemo.collections.tts.modules.magpietts_modules import (
     CharAwareSubwordEncoder,
     EOSDetectionMethod,
@@ -95,17 +96,32 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

         # load codec
         codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False)
+
         self.sample_rate = codec_model.sample_rate
+        self.codec_model_samples_per_frame = codec_model.samples_per_frame
         # del codec discriminator to free memory
         del codec_model.discriminator

-        # Set up codebook configuration
-        self.num_audio_codebooks = codec_model.num_codebooks
-        self.codec_model_samples_per_frame = codec_model.samples_per_frame
+        # When using FSQ tokens, the codebook structure can be changed at any time.
+        # An FSQ definition can be provided in the `vector_quantizer` config to train with a codebook
+        # structure different from the one in the audio codec checkpoint.
+        vector_quantizer = cfg.get('vector_quantizer')
+        if vector_quantizer is not None:
+            vector_quantizer = instantiate(vector_quantizer)
+            self.num_audio_codebooks = vector_quantizer.num_codebooks
+            self.codebook_size = vector_quantizer.codebook_size
+            codec_converter = VectorQuantizerIndexConverter(
+                vector_quantizer_original=codec_model.vector_quantizer,
+                vector_quantizer_new=vector_quantizer,
+            )
+        else:
+            self.num_audio_codebooks = codec_model.num_codebooks
+            self.codebook_size = codec_model.codebook_size
+            codec_converter = None
+
         # Our codebooks start with actual audio codec tokens, followed by special tokens.
         # The `forced_*` options are for backward compatibility for models trained with older code.
-        num_audio_tokens = codec_model.codebook_size
-        get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=num_audio_tokens)
+        get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size)
         self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS))
         self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS))
         self.context_audio_bos_id = cfg.get(
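
A minimal sketch of how the new `vector_quantizer` key might be populated, assuming a Hydra-instantiable FSQ quantizer; the `_target_` path and parameter names below are illustrative assumptions, not the confirmed schema. Any instantiable object exposing `num_codebooks` and `codebook_size` satisfies the code above:

```python
from omegaconf import OmegaConf

# Hypothetical config sketch: class path and parameter names are assumptions.
cfg = OmegaConf.create(
    {
        'vector_quantizer': {
            '_target_': 'nemo.collections.tts.modules.audio_codec_modules.GroupFiniteScalarQuantizer',
            'num_groups': 8,                       # assumed to determine num_codebooks
            'num_levels_per_group': [8, 5, 5, 5],  # assumed per-dimension FSQ levels
        }
    }
)
```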
@@ -116,7 +132,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         )
         self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN))
         self.num_all_tokens_per_codebook = cfg.get(
-            'forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken)
+            'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken)
         )
         self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False)

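
For intuition, the per-codebook token layout implied by `get_token_index` can be sketched as below, assuming `SpecialAudioToken.get_index` returns `base_codebook_size` plus the token's enum position (the member names and ordering here are illustrative):

```python
# Indices [0, codebook_size) are real codec tokens; special tokens follow.
codebook_size = 1000  # hypothetical, e.g. an FSQ with levels 8*5*5*5
special_tokens = ['AUDIO_BOS', 'AUDIO_EOS', 'CONTEXT_AUDIO_BOS', 'CONTEXT_AUDIO_EOS', 'MASK_TOKEN']

token_ids = {name: codebook_size + i for i, name in enumerate(special_tokens)}
num_all_tokens_per_codebook = codebook_size + len(special_tokens)

print(token_ids['AUDIO_BOS'])       # 1000
print(num_all_tokens_per_codebook)  # 1005
```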
@@ -201,6 +217,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         # This needs to happen after super().__init__()
         self._codec_model = codec_model
         self._codec_model.freeze()  # Lightning does requires_grad = False and self.eval()
+        self._codec_converter = codec_converter

         audio_embeddings = []
         for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
@@ -450,6 +467,8 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'):
         self._codec_model.eval()
         with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32):
             codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len)
+            if self._codec_converter is not None:
+                codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
         # Add a timestep to beginning and end of codes tensor
         bos_tensor = torch.full(
             (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device
@@ -478,6 +497,10 @@ def codes_to_audio(self, codes, codes_len):
         codes_copy[codes == self.audio_bos_id] = 0  # zero is the padding token
         codes_copy[codes == self.audio_eos_id] = 0
         # Pass the modified integer token IDs
+        if self._codec_converter is not None:
+            codes_copy = self._codec_converter.convert_new_to_original(
+                audio_tokens=codes_copy, audio_lens=codes_len
+            )
         audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len)
         # audio: (B, T)
         # audio_len: (B,)
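
Why a lossless round trip (`convert_original_to_new` on encode, `convert_new_to_original` on decode) is possible at all: an FSQ index is a mixed-radix encoding of per-dimension quantization levels, so regrouping the same dimensions into a different number of codebooks only re-factorizes the index. A generic sketch of that idea, not NeMo's `VectorQuantizerIndexConverter` implementation:

```python
def index_to_levels(idx, bases):
    """Decode a codebook index into per-dimension FSQ levels (mixed radix)."""
    levels = []
    for base in bases:
        levels.append(idx % base)
        idx //= base
    return levels

def levels_to_index(levels, bases):
    """Encode per-dimension FSQ levels back into a codebook index."""
    idx, mult = 0, 1
    for level, base in zip(levels, bases):
        idx += level * mult
        mult *= base
    return idx

# Example: one 4x4-level codebook (size 16) regrouped as two 4-level
# codebooks -- same underlying levels, different index factorization.
levels = index_to_levels(11, [4, 4])                 # [3, 2]
split = [levels_to_index([v], [4]) for v in levels]  # [3, 2]
assert levels_to_index(levels, [4, 4]) == 11         # lossless round trip
```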
@@ -744,7 +767,7 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
         logits[
             :,
             :,
-            SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos),
+            SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos),
         ] = float('-inf')
         return logits

@@ -1276,6 +1299,10 @@ def prepare_context_tensors(self, batch):
         if 'context_audio_codes' in batch:
             context_audio_codes = batch['context_audio_codes']
             context_audio_codes_lens = batch['context_audio_codes_lens']
+            if self._codec_converter is not None:
+                context_audio_codes = self._codec_converter.convert_original_to_new(
+                    audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens
+                ).long()
         else:
             context_audio_codes, context_audio_codes_lens = self.audio_to_codes(
                 batch['context_audio'], batch['context_audio_lens'], audio_type='context'
@@ -1498,6 +1525,10 @@ def process_batch(self, batch, mode="train"):
         else:
             audio_codes = batch['audio_codes']
             audio_codes_lens = batch['audio_codes_lens']
+            if self._codec_converter is not None:
+                audio_codes = self._codec_converter.convert_original_to_new(
+                    audio_tokens=audio_codes, audio_lens=audio_codes_lens
+                ).long()
         if self.frame_stacking_factor > 1:
             # Repeat the BOS token frame_stacking_factor times. This is necessary since at inference
             # we need to start autoregressive generation from a full stack indicating BOS.
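
A small illustration of the BOS-repetition comment above (assumed semantics: with frame stacking, each autoregressive step covers `frame_stacking_factor` codec frames, so generation must begin from a full stack of BOS frames; shapes and values here are hypothetical):

```python
import torch

frame_stacking_factor = 2
num_audio_codebooks, audio_bos_id = 8, 1000

# (B, C, frame_stacking_factor): one full stack of BOS frames per batch item.
bos_stack = torch.full((1, num_audio_codebooks, frame_stacking_factor), audio_bos_id)
```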
@@ -2326,6 +2357,7 @@ def infer_batch(
             all_codes_next_argmax = self.sample_codes_from_logits(
                 all_code_logits_t,
                 temperature=0.01,
+                topk=1,
                 unfinished_items=unfinished_items,
                 finished_items=finished_items,
                 forbid_audio_eos=forbid_audio_eos,
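
The added `topk=1` makes the near-greedy intent of `temperature=0.01` explicit: once only one candidate survives the top-k cut, the draw is deterministic regardless of temperature. A generic top-k sampling sketch under that assumption, not the model's `sample_codes_from_logits`:

```python
import torch

def sample_topk(logits: torch.Tensor, temperature: float = 1.0, topk: int = 1) -> torch.Tensor:
    """Generic top-k sampling over the last dim of `logits` with shape (B, V)."""
    vals, idx = torch.topk(logits / temperature, k=topk, dim=-1)
    probs = torch.softmax(vals, dim=-1)       # renormalize over the k survivors
    choice = torch.multinomial(probs, num_samples=1)
    return idx.gather(-1, choice)             # topk=1 -> always the argmax token
```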