
Commit 6358c94

adding depformer norm (#379)
1 parent: b6f674f

File tree

2 files changed (+9, -3 lines):

moshi/moshi/models/lm.py
scripts/import_pytorch.py

moshi/moshi/models/lm.py

Lines changed: 8 additions & 2 deletions
@@ -95,6 +95,7 @@ def __init__(
         depformer_weights_per_step_schedule: list[int] | None = None,
         depformer_low_rank_embeddings: int | None = None,
         depformer_pos_emb: str = "sin",
+        depformer_norm: str | None = None,
         existing_text_padding_id: int = 3,
         existing_text_end_padding_id: int = 0,
         extra_heads_num_heads: int = 0,
@@ -193,6 +194,11 @@ def __init__(
             depformer_dim,
             demux_second_stream=demux_second_text_stream,
         )
+        if depformer_norm is None:
+            self.depformer_norms = nn.ModuleList([nn.Identity() for _ in range(dep_q)])
+        else:
+            self.depformer_norms = nn.ModuleList(
+                [create_norm_fn(depformer_norm, depformer_dim) for _ in range(dep_q)])
         if depformer_dim_feedforward is None:
             depformer_dim_feedforward = int(hidden_scale * depformer_dim)
         self.depformer = StreamingTransformer(
@@ -435,7 +441,7 @@ def forward_depformer_training(
         depformer_output = self.depformer(depformer_input)
         all_logits = []
         for cb_index in range(Ka):
-            logits = self.linears[cb_index](depformer_output[:, cb_index])
+            logits = self.linears[cb_index](self.depformer_norms[cb_index](depformer_output[:, cb_index]))
             all_logits.append(logits.view(B, T, -1))
         logits = torch.stack(all_logits, 1)
         assert logits.dim() == 4, logits.shape  # [B, Ka, T, card]
@@ -481,7 +487,7 @@ def forward_depformer(
         # depformer_input is [B, 1, depformer_dim].
         # The streaming state of the depformer ensures that the proper layer is run.
         dep_output = self.depformer(depformer_input)
-        logits = self.linears[depformer_cb_index](dep_output)
+        logits = self.linears[depformer_cb_index](self.depformer_norms[depformer_cb_index](dep_output))
         logits = logits[:, None]
         assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
         return logits
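
In effect, the commit adds an optional per-codebook normalization layer between the depformer output and each output linear head, both in training (forward_depformer_training) and in streaming inference (forward_depformer). Below is a minimal standalone sketch of that pattern, not the model's actual code: nn.LayerNorm is used as a stand-in for whatever module create_norm_fn builds from the depformer_norm string, and the sizes (dep_q, depformer_dim, card) are placeholder values.

import torch
from torch import nn

# Placeholder sizes; the real values come from the model config.
dep_q, depformer_dim, card = 8, 1024, 2048
depformer_norm = "layer_norm"  # or None to keep the previous behaviour

# One norm per codebook; Identity when the option is unset, so old checkpoints
# behave exactly as before.
if depformer_norm is None:
    depformer_norms = nn.ModuleList([nn.Identity() for _ in range(dep_q)])
else:
    # Stand-in for create_norm_fn(depformer_norm, depformer_dim).
    depformer_norms = nn.ModuleList(
        [nn.LayerNorm(depformer_dim) for _ in range(dep_q)])
linears = nn.ModuleList([nn.Linear(depformer_dim, card) for _ in range(dep_q)])

# Fake depformer output: [B, Ka, depformer_dim].
depformer_output = torch.randn(2, dep_q, depformer_dim)
all_logits = [
    linears[cb](depformer_norms[cb](depformer_output[:, cb]))  # norm, then head
    for cb in range(dep_q)
]
logits = torch.stack(all_logits, 1)  # [B, Ka, card]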

scripts/import_pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def import_model(
         'dim', 'text_card', 'existing_text_padding_id', 'num_heads', 'num_layers', 'hidden_scale', 'causal',
         'layer_scale', 'context', 'max_period', 'gating', 'norm', 'positional_embedding',
         'depformer_dim', 'depformer_num_heads', 'depformer_num_layers', 'depformer_dim_feedforward',
-        'depformer_layer_scale', 'depformer_multi_linear',
+        'depformer_layer_scale', 'depformer_multi_linear', 'depformer_norm',
         'depformer_max_period', 'depformer_gating', 'depformer_pos_emb', 'depformer_weights_per_step',
         'depformer_low_rank_embeddings', 'demux_second_stream',
         'text_card_out']
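
The list touched here is a whitelist of hyperparameter keys that import_model carries over from the training configuration; the commit only adds 'depformer_norm' so the new setting survives checkpoint import. A small illustration of how such a key filter behaves is shown below; the config dict, its values, and the surrounding code are illustrative only and not taken from scripts/import_pytorch.py.

# Illustrative only: a key-whitelist filter of the kind import_model uses.
ALLOWED_KEYS = [
    'depformer_layer_scale', 'depformer_multi_linear', 'depformer_norm',
    'depformer_pos_emb',  # ...the script lists many more keys
]

config = {
    'depformer_multi_linear': True,
    'depformer_norm': 'rms_norm',       # hypothetical value for the new option
    'unrelated_training_flag': True,    # dropped by the filter
}

exported = {k: v for k, v in config.items() if k in ALLOWED_KEYS}
print(exported)  # {'depformer_multi_linear': True, 'depformer_norm': 'rms_norm'}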
