@@ -95,6 +95,7 @@ def __init__(
9595 depformer_weights_per_step_schedule : list [int ] | None = None ,
9696 depformer_low_rank_embeddings : int | None = None ,
9797 depformer_pos_emb : str = "sin" ,
98+ depformer_norm : str | None = None ,
9899 existing_text_padding_id : int = 3 ,
99100 existing_text_end_padding_id : int = 0 ,
100101 extra_heads_num_heads : int = 0 ,
@@ -193,6 +194,11 @@ def __init__(
193194 depformer_dim ,
194195 demux_second_stream = demux_second_text_stream ,
195196 )
197+ if depformer_norm is None :
198+ self .depformer_norms = nn .ModuleList ([nn .Identity () for _ in range (dep_q )])
199+ else :
200+ self .depformer_norms = nn .ModuleList (
201+ [create_norm_fn (depformer_norm , depformer_dim ) for _ in range (dep_q )])
196202 if depformer_dim_feedforward is None :
197203 depformer_dim_feedforward = int (hidden_scale * depformer_dim )
198204 self .depformer = StreamingTransformer (
@@ -435,7 +441,7 @@ def forward_depformer_training(
435441 depformer_output = self .depformer (depformer_input )
436442 all_logits = []
437443 for cb_index in range (Ka ):
438- logits = self .linears [cb_index ](depformer_output [:, cb_index ])
444+ logits = self .linears [cb_index ](self . depformer_norms [ cb_index ]( depformer_output [:, cb_index ]) )
439445 all_logits .append (logits .view (B , T , - 1 ))
440446 logits = torch .stack (all_logits , 1 )
441447 assert logits .dim () == 4 , logits .shape # [B, Ka, T, card]
@@ -481,7 +487,7 @@ def forward_depformer(
481487 # depformer_input is [B, 1, depformer_dim].
482488 # The streaming state of the depformer ensures that the proper layer is run.
483489 dep_output = self .depformer (depformer_input )
484- logits = self .linears [depformer_cb_index ](dep_output )
490+ logits = self .linears [depformer_cb_index ](self . depformer_norms [ depformer_cb_index ]( dep_output ) )
485491 logits = logits [:, None ]
486492 assert logits .dim () == 4 , logits .shape # [B, Ka, S, card]
487493 return logits
0 commit comments