Skip to content

Commit 55d5a01

Browse files
Magpietts 2508 Attention mask bug fix (#14836)
* bug fix in attention mask
  Signed-off-by: Paarth Neekhara <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: paarthneekhara <[email protected]>
* handle None as well
  Signed-off-by: Paarth Neekhara <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: paarthneekhara <[email protected]>
* Added tests and handled masking in convolutional layer
  Signed-off-by: Paarth Neekhara <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: paarthneekhara <[email protected]>
---------
Signed-off-by: Paarth Neekhara <[email protected]>
Signed-off-by: paarthneekhara <[email protected]>
Co-authored-by: paarthneekhara <[email protected]>
1 parent f3878d7 commit 55d5a01

File tree

2 files changed

+82
-10
lines changed

2 files changed

+82
-10
lines changed

nemo/collections/tts/modules/transformer_2501.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,11 +82,15 @@ def __init__(
8282
bias=bias,
8383
)
8484

85-
def forward(self, signal):
85+
def forward(self, signal, signal_mask):
86+
# signal: (B, C, T)
87+
# signal_mask: (B, T)
88+
signal = signal * signal_mask.unsqueeze(1)
8689
if self.is_causal: # TODO: maybe replace with identity rather than keep conditional if in forward
8790
signal = F.pad(signal, self.causal_padding)
8891

8992
conv_signal = self.conv(signal)
93+
conv_signal = conv_signal * signal_mask.unsqueeze(1)
9094

9195
return conv_signal
9296

@@ -126,12 +130,13 @@ def __init__(
126130
self.o_net = ConvolutionLayer(d_ffn, d_model, bias=bias, kernel_size=kernel_size, is_causal=is_causal)
127131
self.dropout = torch.nn.Dropout(p_dropout)
128132

129-
def forward(self, x):
133+
def forward(self, x, x_mask):
130134
"""
131135
x (B, T, C)
136+
x_mask (B, T)
132137
"""
133-
x = self.non_linearity(self.proj(x.transpose(1, 2)))
134-
x = self.dropout(self.o_net(x).transpose(1, 2))
138+
x = self.non_linearity(self.proj(x.transpose(1, 2), x_mask))
139+
x = self.dropout(self.o_net(x, x_mask).transpose(1, 2))
135140
return x
136141

137142

@@ -350,7 +355,14 @@ def compute_qkv_and_mask(
350355
v = torch.cat([self.cache['self_v'], v], dim=1)
351356
self.cache['self_k'] = k
352357
self.cache['self_v'] = v
353-
mask = query_mask[:, None, :, None] if query_mask is not None else None
358+
359+
mask = None
360+
if query_mask is not None:
361+
# query_mask is a boolean mask of shape (B, T)
362+
# mask should be of shape (B, 1, T, T) where mask[b, 0, i, j] == query_mask[b, i] * query_mask[b, j] (outer product of query_mask with itself)
363+
mask = query_mask.unsqueeze(1) * query_mask.unsqueeze(2)
364+
mask = mask.unsqueeze(1)
365+
354366
return q, k, v, mask
355367

356368

@@ -551,7 +563,7 @@ def forward(
551563
x = x + x_res
552564

553565
# mlp final projection
554-
x = x + self.pos_ff(self.norm_pos_ff(x))
566+
x = x + self.pos_ff(self.norm_pos_ff(x), x_mask)
555567
x = x * x_mask.unsqueeze(-1)
556568

557569
return {

tests/collections/tts/modules/test_transformer_2501.py

Lines changed: 64 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,9 @@ def set_seed(seed):
3737
random.seed(seed)
3838

3939

40+
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths
41+
42+
4043
@pytest.mark.unit
4144
class TestConvolutionLayer:
4245
@classmethod
@@ -53,6 +56,7 @@ def setup_class(cls):
5356
[-1.0317, 1.6818, 1.4257, -0.5003, -1.7254, 0.8830, -0.4541, -0.4631, -0.0986, 0.5083],
5457
[-0.3231, -1.0899, 0.5774, 0.1661, 0.9620, -2.3307, -0.6158, -0.3663, 1.2469, -1.0208]]]
5558
)
59+
cls.input_mask = torch.ones(1, cls.input_tensor.shape[2])
5660
# fmt:on
5761

5862
def test_non_causal_forward(self):
@@ -68,7 +72,7 @@ def test_non_causal_forward(self):
6872
)
6973

7074
with torch.no_grad():
71-
output_tensor = layer(self.input_tensor)
75+
output_tensor = layer(self.input_tensor, self.input_mask)
7276

7377
# fmt:off
7478
expected_output_tensor = torch.Tensor(
@@ -96,7 +100,7 @@ def test_causal_forward(self):
96100
)
97101

98102
with torch.no_grad():
99-
output_tensor = layer(self.input_tensor)
103+
output_tensor = layer(self.input_tensor, self.input_mask)
100104

101105
# fmt:off
102106
expected_output_tensor = torch.Tensor(
@@ -133,6 +137,7 @@ def setup_class(cls):
133137
[-0.1543, 0.3365, 1.7475],
134138
[-0.1753, 0.4115, 0.0772]]]
135139
)
140+
cls.input_mask = torch.ones(1, cls.input_tensor.shape[1])
136141
# fmt:on
137142

138143
def test_causal_forward(self):
@@ -142,7 +147,7 @@ def test_causal_forward(self):
142147
)
143148

144149
with torch.no_grad():
145-
output_tensor = layer(self.input_tensor)
150+
output_tensor = layer(self.input_tensor, self.input_mask)
146151

147152
# fmt:off
148153
expected_output_tensor = torch.Tensor(
@@ -168,7 +173,7 @@ def test_non_causal_forward(self):
168173
)
169174

170175
with torch.no_grad():
171-
output_tensor = layer(self.input_tensor)
176+
output_tensor = layer(self.input_tensor, self.input_mask)
172177

173178
# fmt:off
174179
expected_output_tensor = torch.Tensor(
@@ -795,3 +800,58 @@ def test_forward_causal_self_attn_and_has_xattn(self):
795800
expected_output["attn_probabilities"][i]["cross_attn_probabilities"][0],
796801
atol=1e-4,
797802
)
803+
804+
805+
@pytest.mark.unit
806+
class TestTransformerBatchedInference:
807+
@classmethod
808+
def setup_class(cls):
809+
cls.n_layers = 3
810+
cls.d_model = 4
811+
cls.d_ffn = 16
812+
cls.sa_n_heads = 2
813+
cls.p_dropout = 0.0
814+
cls.p_dropout_out = 0.0
815+
cls.max_length_causal_mask = 10
816+
cls.short_length = 4
817+
cls.long_length = 10
818+
819+
def test_forward(self):
820+
set_seed(0)
821+
query_tensor1 = torch.randn(1, self.long_length, self.d_model)
822+
query_tensor2 = torch.randn(1, self.short_length, self.d_model)
823+
824+
padding_tensor = torch.randn(1, self.long_length - self.short_length, self.d_model)
825+
query_tensor2_padded = torch.cat([query_tensor2, padding_tensor], dim=1)
826+
lengths = torch.tensor([self.long_length, self.short_length])
827+
mask_batched = get_mask_from_lengths(lengths)
828+
829+
query_batched = torch.cat([query_tensor1, query_tensor2_padded], dim=0)
830+
831+
mask_bs1_1 = torch.ones(1, self.long_length)
832+
mask_bs1_2 = torch.ones(1, self.short_length)
833+
834+
for is_causal in [True, False]:
835+
for kernel_size in [1, 3]:
836+
model = Transformer(
837+
n_layers=self.n_layers,
838+
d_model=self.d_model,
839+
d_ffn=self.d_ffn,
840+
sa_n_heads=self.sa_n_heads,
841+
kernel_size=kernel_size,
842+
p_dropout=self.p_dropout,
843+
p_dropout_out=self.p_dropout_out,
844+
is_causal=is_causal,
845+
max_length_causal_mask=self.max_length_causal_mask,
846+
)
847+
848+
output_batched = model(query_batched, mask_batched)
849+
output_bs1_1 = model(query_tensor1, mask_bs1_1)
850+
output_bs1_2 = model(query_tensor2, mask_bs1_2)
851+
852+
assert torch.allclose(
853+
output_batched['output'][0][: self.long_length, :], output_bs1_1['output'], atol=1e-4
854+
)
855+
assert torch.allclose(
856+
output_batched['output'][1][: self.short_length, :], output_bs1_2['output'], atol=1e-4
857+
)

0 commit comments

Comments
 (0)