
Commit 0b48537

rlangman and blisc authored
Add new spectral codec definition (#14794)
* [TTS] Add new spectral codec definition
  Signed-off-by: Ryan <[email protected]>

* Add codec MMD loss definitions
  Signed-off-by: Ryan <[email protected]>

* Apply isort and black reformatting
  Signed-off-by: rlangman <[email protected]>

---------

Signed-off-by: Ryan <[email protected]>
Signed-off-by: rlangman <[email protected]>
Signed-off-by: Jason <[email protected]>
Co-authored-by: Jason <[email protected]>
Co-authored-by: rlangman <[email protected]>
1 parent 6b5d25b commit 0b48537


5 files changed: +529 -32 lines


nemo/collections/common/parts/utils.py

Lines changed: 5 additions & 1 deletion
@@ -159,12 +159,16 @@ def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor):
 
 class ClampActivation(nn.Module):
 
-    def __init__(self, min_value: float = -1.0, max_value: float = 1.0):
+    def __init__(self, min_value: float = -1.0, max_value: float = 1.0, clamp_training: bool = True):
         super().__init__()
         self.min_value = min_value
         self.max_value = max_value
+        self.clamp_training = clamp_training
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.training and not self.clamp_training:
+            return input
+
         return torch.clamp(input, min=self.min_value, max=self.max_value)
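A minimal sketch of the new flag's effect (tensor values here are illustrative, not from the commit):

    import torch
    from nemo.collections.common.parts.utils import ClampActivation

    act = ClampActivation(min_value=-1.0, max_value=1.0, clamp_training=False)
    x = torch.tensor([-2.0, 0.5, 3.0])

    act.train()    # training mode with clamp_training=False: clamping is skipped
    print(act(x))  # tensor([-2.0000,  0.5000,  3.0000])

    act.eval()     # eval mode: output is always clamped to [min_value, max_value]
    print(act(x))  # tensor([-1.0000,  0.5000,  1.0000])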

nemo/collections/tts/losses/audio_codec_loss.py

Lines changed: 173 additions & 0 deletions
@@ -512,3 +512,176 @@ def forward(self, disc_scores_real, disc_scores_gen):
         loss /= len(disc_scores_real)
 
         return loss
+
+
+class MMDLoss(Loss):
+    """
+    Maximum mean discrepancy (MMD) loss, as defined in https://arxiv.org/abs/2406.02315
+
+    Args:
+        kernel_radii: List of radii for Gaussian kernels
+        loss_scale: Constant to multiply loss by
+    """
+
+    def __init__(self, kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0):
+        super().__init__()
+        self.kernel_radii = kernel_radii
+        self.loss_scale = loss_scale
+
+    @staticmethod
+    def _exp_kernel(dxx, r):
+        return torch.exp((-0.5 / r) * dxx).sum()
+
+    @staticmethod
+    def _shuffle_codebooks(x):
+        B, C, _ = x.size()
+        x_shuffled = torch.zeros_like(x)
+        for c in range(C):
+            batch_perm = torch.randperm(B, device=x.device)
+            x_shuffled[:, c, :] = x[batch_perm, c, :]
+        return x_shuffled
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": [NeuralType(('B', 'C', 'D'), VoidType())],
+        }
+
+    @property
+    def output_types(self):
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    @typecheck()
+    def forward(self, inputs):
+        B, C, D = inputs.size()
+
+        x = inputs
+        x_mean = x.mean(dim=(0,), keepdim=True)
+        x_stdev = torch.sqrt(x.var(dim=(0,), keepdim=True) + 1e-8)
+        x = (x - x_mean) / x_stdev
+        y = self._shuffle_codebooks(x)
+
+        # [B, C * D]
+        x = x.reshape([B, C * D])
+        y = y.reshape([B, C * D])
+
+        # [B, B]
+        xx = torch.mm(x, x.t())
+        yy = torch.mm(y, y.t())
+        zz = torch.mm(x, y.t())
+
+        rx = xx.diag().unsqueeze(0).expand_as(xx)
+        ry = yy.diag().unsqueeze(0).expand_as(yy)
+
+        dxx = rx.t() + rx - 2.0 * xx
+        dyy = ry.t() + ry - 2.0 * yy
+        dxy = rx.t() + ry - 2.0 * zz
+
+        loss = 0.0
+        coeff = -2.0 / B**2
+        denom = B * (B - 1)
+        for r in self.kernel_radii:
+            loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dxx, r) - B) / denom
+            loss += coeff * torch.utils.checkpoint.checkpoint(self._exp_kernel, dxy, r)
+            loss += (torch.utils.checkpoint.checkpoint(self._exp_kernel, dyy, r) - B) / denom
+
+        loss = loss.clamp(min=0)
+        loss = self.loss_scale * loss
+        return loss
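For readers skimming the kernel arithmetic above: `_shuffle_codebooks` builds `y` by independently permuting each codebook of `x` across the batch, so `y` keeps the per-codebook marginals of `x` while making its codebooks mutually independent. The loop over `kernel_radii` then accumulates the empirical squared MMD between `x` and `y` under a sum of Gaussian kernels; as a sketch of the math (not part of the diff):

    \widehat{\mathrm{MMD}}^2
    = \sum_{r} \left[
        \frac{1}{B(B-1)} \sum_{i \ne j} k_r(x_i, x_j)
      + \frac{1}{B(B-1)} \sum_{i \ne j} k_r(y_i, y_j)
      - \frac{2}{B^2} \sum_{i, j} k_r(x_i, y_j)
      \right],
    \qquad
    k_r(u, v) = \exp\!\left( -\frac{\lVert u - v \rVert^2}{2r} \right)

The `- B` corrections remove the diagonal terms k_r(x_i, x_i) = 1, and each kernel sum runs under `torch.utils.checkpoint.checkpoint`, so the exponentiated B-by-B matrices are recomputed during backward rather than stored. Because the cross term uses the biased 1/B² normalization while the within-sample terms are unbiased, the estimate can dip slightly below zero, which the final `clamp(min=0)` guards against.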
+class MMDCodebookLoss(Loss):
+    """
+    MMD loss which incentivizes independence between codebooks within each timestep.
+
+    Args:
+        num_codebooks: Number of codebooks.
+        codebook_dim: Dimension of a single codebook code.
+        loss_fn: MMDLoss instance.
+    """
+
+    def __init__(self, num_codebooks, codebook_dim, loss_fn):
+        super().__init__()
+        self.num_codebooks = num_codebooks
+        self.codebook_dim = codebook_dim
+        self.loss_fn = loss_fn
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": [NeuralType(('B', 'D', 'T'), VoidType())],
+        }
+
+    @property
+    def output_types(self):
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    @typecheck()
+    def forward(self, inputs):
+        B, D, T = inputs.size()
+
+        # [B, C, D / C, T]
+        x = inputs.reshape(B, self.num_codebooks, self.codebook_dim, T)
+        # [B*T, C, D / C]
+        x = rearrange(x, 'B C D T -> (B T) C D')
+        loss = self.loss_fn(inputs=x)
+        return loss
+
+
+class MMDEmbeddingLoss(Loss):
+    """
+    MMD loss which incentivizes independence between embedding values within each timestep.
+
+    Args:
+        loss_fn: MMDLoss instance.
+    """
+
+    def __init__(self, loss_fn):
+        super().__init__()
+        self.loss_fn = loss_fn
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": [NeuralType(('B', 'D', 'T'), VoidType())],
+        }
+
+    @property
+    def output_types(self):
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    @typecheck()
+    def forward(self, inputs):
+        # [B*T, D, 1]
+        x = rearrange(inputs, 'B D T -> (B T) D 1')
+        loss = self.loss_fn(inputs=x)
+        return loss
+
+
+class MMDTimeLoss(Loss):
+    """
+    MMD loss which incentivizes independence between different timesteps.
+
+    Args:
+        loss_fn: MMDLoss instance.
+    """
+
+    def __init__(self, loss_fn):
+        super().__init__()
+        self.loss_fn = loss_fn
+
+    @property
+    def input_types(self):
+        return {
+            "inputs": [NeuralType(('B', 'D', 'T'), VoidType())],
+        }
+
+    @property
+    def output_types(self):
+        return {"loss": NeuralType(elements_type=LossType())}
+
+    @typecheck()
+    def forward(self, inputs):
+        x = rearrange(inputs, 'B D T -> B T D')
+        loss = self.loss_fn(inputs=x)
+        return loss
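A minimal usage sketch of these wrappers (shapes and values below are illustrative, not from the commit; in the model code further down they are built from Hydra configs instead):

    import torch
    from nemo.collections.tts.losses.audio_codec_loss import MMDCodebookLoss, MMDLoss, MMDTimeLoss

    mmd = MMDLoss(kernel_radii=(0.1, 1, 5, 10, 20, 50), loss_scale=1.0)

    # Treat a 32-dim latent as 8 codebooks of dim 4 and penalize dependence between codebooks.
    codebook_loss = MMDCodebookLoss(num_codebooks=8, codebook_dim=4, loss_fn=mmd)
    latents = torch.randn(4, 32, 100)  # [B, D, T] continuous codec latents
    print(codebook_loss(inputs=latents))  # scalar loss tensor

    # Penalize dependence between timesteps; each timestep acts as one "codebook".
    time_loss = MMDTimeLoss(loss_fn=mmd)
    print(time_loss(inputs=latents))

MMDEmbeddingLoss follows the same pattern, treating every embedding dimension as its own size-1 codebook.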

nemo/collections/tts/models/audio_codec.py

Lines changed: 32 additions & 3 deletions
@@ -143,6 +143,22 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.gen_loss_fn = instantiate(cfg.generator_loss)
         self.disc_loss_fn = instantiate(cfg.discriminator_loss)
 
+        self.mmd_loss_start_epoch = cfg.get("mmd_loss_start_epoch", 0)
+
+        if "mmd_loss" in cfg:
+            self.mmd_loss_fn = instantiate(cfg.mmd_loss)
+            self.mmd_loss_scale = cfg.get("mmd_loss_scale", 1.0)
+        else:
+            self.mmd_loss_fn = None
+            self.mmd_loss_scale = None
+
+        if "mmd_time_loss" in cfg:
+            self.mmd_time_loss_fn = instantiate(cfg.mmd_time_loss)
+            self.mmd_time_loss_scale = cfg.get("mmd_time_loss_scale", 1.0)
+        else:
+            self.mmd_time_loss_fn = None
+            self.mmd_time_loss_scale = None
+
         feature_loss_type = cfg.get("feature_loss_type", "relative")
         if feature_loss_type == "relative":
             self.feature_loss_fn = RelativeFeatureMatchingLoss()
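A hedged sketch of the config keys read above (the `_target_` paths point at the loss classes added in this commit; the specific values are made up for illustration):

    from omegaconf import OmegaConf

    # Hypothetical fragment of an AudioCodecModel training config.
    mmd_cfg = OmegaConf.create(
        {
            "mmd_loss": {
                "_target_": "nemo.collections.tts.losses.audio_codec_loss.MMDCodebookLoss",
                "num_codebooks": 8,
                "codebook_dim": 4,
                "loss_fn": {"_target_": "nemo.collections.tts.losses.audio_codec_loss.MMDLoss"},
            },
            "mmd_loss_scale": 1.0,
            "mmd_time_loss": {
                "_target_": "nemo.collections.tts.losses.audio_codec_loss.MMDTimeLoss",
                "loss_fn": {"_target_": "nemo.collections.tts.losses.audio_codec_loss.MMDLoss"},
            },
            "mmd_time_loss_scale": 1.0,
            # MMD metrics are logged from the start, but only added to the generator
            # loss once training reaches this epoch (see training_step below).
            "mmd_loss_start_epoch": 5,
        }
    )

Omitting `mmd_loss` / `mmd_time_loss` leaves the corresponding scale at None, so `training_step` skips those terms entirely.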
@@ -497,7 +513,7 @@ def _process_batch(self, batch):
         encoded = encoded.to(self.dtype)  # make sure vector quantizer output is in the model dtype
         audio_gen, _ = self.audio_decoder(inputs=encoded, input_len=encoded_len)
 
-        return audio, audio_len, audio_gen, commit_loss
+        return audio, audio_len, audio_gen, commit_loss, encoded
 
     @property
     def disc_update_prob(self) -> float:
@@ -514,7 +530,7 @@ def should_update_disc(self, batch_idx) -> bool:
     def training_step(self, batch, batch_idx):
         optim_gen, optim_disc = self.optimizers()
 
-        audio, audio_len, audio_gen, commit_loss = self._process_batch(batch)
+        audio, audio_len, audio_gen, commit_loss, codes = self._process_batch(batch)
 
         metrics = {
             "global_step": self.global_step,
@@ -578,6 +594,19 @@ def training_step(self, batch, batch_idx):
             metrics["g_loss_commit"] = commit_loss
             generator_losses.append(self.commit_loss_scale * commit_loss)
 
+        if self.mmd_loss_scale:
+            loss_mmd = self.mmd_loss_fn(inputs=codes)
+            metrics["g_loss_mmd"] = loss_mmd
+
+            if self.current_epoch >= self.mmd_loss_start_epoch:
+                generator_losses.append(self.mmd_loss_scale * loss_mmd)
+
+        if self.mmd_time_loss_scale:
+            loss_mmd_time = self.mmd_time_loss_fn(inputs=codes)
+            metrics["g_loss_mmd_time"] = loss_mmd_time
+            if self.current_epoch >= self.mmd_loss_start_epoch:
+                generator_losses.append(self.mmd_time_loss_scale * loss_mmd_time)
+
         # compute embeddings for speaker consistency loss
         if self.use_scl_loss:
             # concate generated and GT waveforms
@@ -623,7 +652,7 @@ def on_train_epoch_end(self):
         self.update_lr("epoch")
 
     def validation_step(self, batch, batch_idx):
-        audio, audio_len, audio_gen, _ = self._process_batch(batch)
+        audio, audio_len, audio_gen, _, _ = self._process_batch(batch)
 
         loss_mel_l1, loss_mel_l2 = self.mel_loss_fn(
             audio_real=audio.float(), audio_gen=audio_gen.float(), audio_len=audio_len

nemo/collections/tts/models/magpietts.py

Lines changed: 39 additions & 7 deletions
@@ -36,6 +36,7 @@
 from nemo.collections.tts.models import AudioCodecModel
 from nemo.collections.tts.modules import transformer_2501
 from nemo.collections.tts.modules.aligner import AlignmentEncoder
+from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter
 from nemo.collections.tts.modules.magpietts_modules import (
     CharAwareSubwordEncoder,
     EOSDetectionMethod,
@@ -95,17 +96,32 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
 
         # load codec
         codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path'), strict=False)
+
         self.sample_rate = codec_model.sample_rate
+        self.codec_model_samples_per_frame = codec_model.samples_per_frame
         # del codec discriminator to free memory
         del codec_model.discriminator
 
-        # Set up codebook configuration
-        self.num_audio_codebooks = codec_model.num_codebooks
-        self.codec_model_samples_per_frame = codec_model.samples_per_frame
+        # When using FSQ tokens, the codebook structure can be changed at any time.
+        # An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure
+        # that is different than in the audio codec checkpoint.
+        vector_quantizer = cfg.get('vector_quantizer')
+        if vector_quantizer is not None:
+            vector_quantizer = instantiate(vector_quantizer)
+            self.num_audio_codebooks = vector_quantizer.num_codebooks
+            self.codebook_size = vector_quantizer.codebook_size
+            codec_converter = VectorQuantizerIndexConverter(
+                vector_quantizer_original=codec_model.vector_quantizer,
+                vector_quantizer_new=vector_quantizer,
+            )
+        else:
+            self.num_audio_codebooks = codec_model.num_codebooks
+            self.codebook_size = codec_model.codebook_size
+            codec_converter = None
+
         # Our codebooks start with actual audio codec tokens, followed by special tokens.
         # The `forced_*` options are for backward compatibility for models trained with older code.
-        num_audio_tokens = codec_model.codebook_size
-        get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=num_audio_tokens)
+        get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size)
         self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS))
         self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS))
         self.context_audio_bos_id = cfg.get(
@@ -116,7 +132,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         )
         self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN))
         self.num_all_tokens_per_codebook = cfg.get(
-            'forced_num_all_tokens_per_codebook', num_audio_tokens + len(SpecialAudioToken)
+            'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken)
         )
         self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False)
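To make the converter's role concrete, a sketch of the intended round trip, using only the calls that appear in this diff (assume `codec_model`, `codec_converter`, `audio`, and `audio_len` are set up as in `__init__` above):

    # Hypothetical flow: the converter re-indexes FSQ tokens between the checkpoint's
    # codebook layout and the layout chosen for TTS training.
    codes, codes_len = codec_model.encode(audio=audio, audio_len=audio_len)  # checkpoint layout
    codes = codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)

    # ...the model trains and generates in the new layout, then converts back so the
    # frozen codec can decode tokens produced under the new codebook structure.
    codes = codec_converter.convert_new_to_original(audio_tokens=codes, audio_lens=codes_len)
    audio_gen, audio_gen_len = codec_model.decode(tokens=codes, tokens_len=codes_len)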

@@ -201,6 +217,7 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
         # This needs to happen after super().__init__()
         self._codec_model = codec_model
         self._codec_model.freeze()  # Lightning does requires_grad = False and self.eval()
+        self._codec_converter = codec_converter
 
         audio_embeddings = []
         for _ in range(self.num_audio_codebooks * self.frame_stacking_factor):
@@ -450,6 +467,8 @@ def audio_to_codes(self, audio, audio_len, audio_type='target'):
         self._codec_model.eval()
         with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32):
             codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len)
+            if self._codec_converter is not None:
+                codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len)
             # Add a timestep to begining and end of codes tensor
             bos_tensor = torch.full(
                 (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device
@@ -478,6 +497,10 @@ def codes_to_audio(self, codes, codes_len):
         codes_copy[codes == self.audio_bos_id] = 0  # zero is the padding token
         codes_copy[codes == self.audio_eos_id] = 0
         # Pass the modified integer token IDs
+        if self._codec_converter is not None:
+            codes_copy = self._codec_converter.convert_new_to_original(
+                audio_tokens=codes_copy, audio_lens=codes_len
+            )
         audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len)
         # audio: (B, T)
         # audio_len: (B,)
@@ -744,7 +767,7 @@ def clear_forbidden_logits(self, logits, forbid_audio_eos=False):
         logits[
             :,
             :,
-            SpecialAudioToken.get_forbidden_tokens(self._codec_model.codebook_size, forbid_audio_eos=forbid_audio_eos),
+            SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos),
         ] = float('-inf')
         return logits
@@ -1276,6 +1299,10 @@ def prepare_context_tensors(self, batch):
         if 'context_audio_codes' in batch:
             context_audio_codes = batch['context_audio_codes']
             context_audio_codes_lens = batch['context_audio_codes_lens']
+            if self._codec_converter is not None:
+                context_audio_codes = self._codec_converter.convert_original_to_new(
+                    audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens
+                ).long()
         else:
             context_audio_codes, context_audio_codes_lens = self.audio_to_codes(
                 batch['context_audio'], batch['context_audio_lens'], audio_type='context'
@@ -1498,6 +1525,10 @@ def process_batch(self, batch, mode="train"):
         else:
             audio_codes = batch['audio_codes']
             audio_codes_lens = batch['audio_codes_lens']
+            if self._codec_converter:
+                audio_codes = self._codec_converter.convert_original_to_new(
+                    audio_tokens=audio_codes, audio_lens=audio_codes_lens
+                ).long()
             if self.frame_stacking_factor > 1:
                 # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference
                 # we need to start autoregressive generation from a full stack indicating BOS.
@@ -2326,6 +2357,7 @@ def infer_batch(
             all_codes_next_argmax = self.sample_codes_from_logits(
                 all_code_logits_t,
                 temperature=0.01,
+                topk=1,
                 unfinished_items=unfinished_items,
                 finished_items=finished_items,
                 forbid_audio_eos=forbid_audio_eos,
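The variable name `all_codes_next_argmax` suggests this call is meant to be a greedy pick; a small illustration (names and values here are illustrative, not the model's internals) of why `topk=1` makes sampling equivalent to argmax even before the near-zero temperature:

    import torch

    logits = torch.tensor([2.0, 0.5, -1.0])
    top_vals, top_idx = logits.topk(1)               # keep only the highest-logit token
    probs = torch.softmax(top_vals / 0.01, dim=-1)   # a single candidate -> probability 1.0
    token = top_idx[torch.multinomial(probs, 1)]     # always the argmax token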
