@@ -774,26 +774,65 @@ def call(x, td):
 
 
 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+@pytest.mark.parametrize("strict", [True, False])
 class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
 
-    def test_export_seq(self):
+    def test_export_seq(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Seq(
             Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
             Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
         )
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
+    @pytest.mark.parametrize(
+        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
+    )
+    def test_td_output(self, strict, same_shape, dynamic_shape):
+        # This will only work when the tensordict is pytree-able
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        if same_shape:
+            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
+        else:
+            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        if dynamic_shape:
+            kwargs = {
+                "dynamic_shapes": {
+                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                }
+            }
+        else:
+            kwargs = {}
+
+        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert eager_test.batch_size == export_test.batch_size
+        assert (export_test == eager_test).all()
+
 
 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
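For context on what the new test_td_output case exercises: because TensorDict is registered as a pytree node, torch.export can trace a module that returns a TensorDict and rebuild one when the exported module is called. Below is a minimal, standalone sketch of that pattern, not part of the diff; it assumes PyTorch >= 2.5 and a recent tensordict release, and the module/variable names are illustrative only.

# Sketch: exporting a module that returns a TensorDict, with dynamic
# batch/time dimensions (mirrors test_td_output above; names are illustrative).
import torch
from tensordict import TensorDict


class PairToTD(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # Works because TensorDict participates in pytree flattening/unflattening
        return TensorDict({"x": x, "y": y}, batch_size=x.shape[0])


mod = PairToTD()
x, y = torch.zeros(5, 100), torch.zeros(5, 100)
batch, time = torch.export.Dim("batch"), torch.export.Dim("time")
exported = torch.export.export(
    mod,
    args=(x, y),
    strict=False,  # non-strict mode, as in the test above
    dynamic_shapes={"x": {0: batch, 1: time}, "y": {0: batch, 1: time}},
)
out = exported.module()(x, y)  # a TensorDict matching the eager output
assert (out == mod(x, y)).all()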