@@ -37,6 +37,9 @@ def set_seed(seed):
37
37
random .seed (seed )
38
38
39
39
40
+ from nemo .collections .tts .parts .utils .helpers import get_mask_from_lengths
41
+
42
+
40
43
@pytest .mark .unit
41
44
class TestConvolutionLayer :
42
45
@classmethod
@@ -53,6 +56,7 @@ def setup_class(cls):
53
56
[- 1.0317 , 1.6818 , 1.4257 , - 0.5003 , - 1.7254 , 0.8830 , - 0.4541 , - 0.4631 , - 0.0986 , 0.5083 ],
54
57
[- 0.3231 , - 1.0899 , 0.5774 , 0.1661 , 0.9620 , - 2.3307 , - 0.6158 , - 0.3663 , 1.2469 , - 1.0208 ]]]
55
58
)
59
+ cls .input_mask = torch .ones (1 , cls .input_tensor .shape [2 ])
56
60
# fmt:on
57
61
58
62
def test_non_causal_forward (self ):
@@ -68,7 +72,7 @@ def test_non_causal_forward(self):
68
72
)
69
73
70
74
with torch .no_grad ():
71
- output_tensor = layer (self .input_tensor )
75
+ output_tensor = layer (self .input_tensor , self . input_mask )
72
76
73
77
# fmt:off
74
78
expected_output_tensor = torch .Tensor (
@@ -96,7 +100,7 @@ def test_causal_forward(self):
96
100
)
97
101
98
102
with torch .no_grad ():
99
- output_tensor = layer (self .input_tensor )
103
+ output_tensor = layer (self .input_tensor , self . input_mask )
100
104
101
105
# fmt:off
102
106
expected_output_tensor = torch .Tensor (
@@ -133,6 +137,7 @@ def setup_class(cls):
133
137
[- 0.1543 , 0.3365 , 1.7475 ],
134
138
[- 0.1753 , 0.4115 , 0.0772 ]]]
135
139
)
140
+ cls .input_mask = torch .ones (1 , cls .input_tensor .shape [1 ])
136
141
# fmt:on
137
142
138
143
def test_causal_forward (self ):
@@ -142,7 +147,7 @@ def test_causal_forward(self):
142
147
)
143
148
144
149
with torch .no_grad ():
145
- output_tensor = layer (self .input_tensor )
150
+ output_tensor = layer (self .input_tensor , self . input_mask )
146
151
147
152
# fmt:off
148
153
expected_output_tensor = torch .Tensor (
@@ -168,7 +173,7 @@ def test_non_causal_forward(self):
168
173
)
169
174
170
175
with torch .no_grad ():
171
- output_tensor = layer (self .input_tensor )
176
+ output_tensor = layer (self .input_tensor , self . input_mask )
172
177
173
178
# fmt:off
174
179
expected_output_tensor = torch .Tensor (
@@ -795,3 +800,58 @@ def test_forward_causal_self_attn_and_has_xattn(self):
795
800
expected_output ["attn_probabilities" ][i ]["cross_attn_probabilities" ][0 ],
796
801
atol = 1e-4 ,
797
802
)
803
+
804
+
805
+ @pytest .mark .unit
806
+ class TestTransformerBatchedInference :
807
+ @classmethod
808
+ def setup_class (cls ):
809
+ cls .n_layers = 3
810
+ cls .d_model = 4
811
+ cls .d_ffn = 16
812
+ cls .sa_n_heads = 2
813
+ cls .p_dropout = 0.0
814
+ cls .p_dropout_out = 0.0
815
+ cls .max_length_causal_mask = 10
816
+ cls .short_length = 4
817
+ cls .long_length = 10
818
+
819
+ def test_forward (self ):
820
+ set_seed (0 )
821
+ query_tensor1 = torch .randn (1 , self .long_length , self .d_model )
822
+ query_tensor2 = torch .randn (1 , self .short_length , self .d_model )
823
+
824
+ padding_tensor = torch .randn (1 , self .long_length - self .short_length , self .d_model )
825
+ query_tensor2_padded = torch .cat ([query_tensor2 , padding_tensor ], dim = 1 )
826
+ lengths = torch .tensor ([self .long_length , self .short_length ])
827
+ mask_batched = get_mask_from_lengths (lengths )
828
+
829
+ query_batched = torch .cat ([query_tensor1 , query_tensor2_padded ], dim = 0 )
830
+
831
+ mask_bs1_1 = torch .ones (1 , self .long_length )
832
+ mask_bs1_2 = torch .ones (1 , self .short_length )
833
+
834
+ for is_causal in [True , False ]:
835
+ for kernel_size in [1 , 3 ]:
836
+ model = Transformer (
837
+ n_layers = self .n_layers ,
838
+ d_model = self .d_model ,
839
+ d_ffn = self .d_ffn ,
840
+ sa_n_heads = self .sa_n_heads ,
841
+ kernel_size = kernel_size ,
842
+ p_dropout = self .p_dropout ,
843
+ p_dropout_out = self .p_dropout_out ,
844
+ is_causal = is_causal ,
845
+ max_length_causal_mask = self .max_length_causal_mask ,
846
+ )
847
+
848
+ output_batched = model (query_batched , mask_batched )
849
+ output_bs1_1 = model (query_tensor1 , mask_bs1_1 )
850
+ output_bs1_2 = model (query_tensor2 , mask_bs1_2 )
851
+
852
+ assert torch .allclose (
853
+ output_batched ['output' ][0 ][: self .long_length , :], output_bs1_1 ['output' ], atol = 1e-4
854
+ )
855
+ assert torch .allclose (
856
+ output_batched ['output' ][1 ][: self .short_length , :], output_bs1_2 ['output' ], atol = 1e-4
857
+ )
0 commit comments