@@ -801,26 +801,65 @@ def call(x, td):
801
801
802
802
803
803
@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
@pytest.mark.parametrize("strict", [True, False])
class TestExport:
    """Tests that tensordict modules survive a ``torch.export.export`` round-trip.

    The class is parametrized over ``strict`` so every test runs under both the
    strict (TorchDynamo-traced) and non-strict export modes.
    """

    def test_export_module(self, strict):
        """A single ``Mod`` exports and matches eager execution exactly."""
        torch._dynamo.reset_code_caches()
        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
        x = torch.randn(3)
        y = torch.randn(3)
        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
        # Exported graph module must reproduce the eager module's output.
        assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

    def test_export_seq(self, strict):
        """A two-stage ``Seq`` pipeline exports and matches eager execution."""
        torch._dynamo.reset_code_caches()
        tdm = Seq(
            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
        )
        x = torch.randn(3)
        y = torch.randn(3)
        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
        # assert_close rather than exact equality: z + x goes through float ops.
        torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

    @pytest.mark.parametrize(
        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
    )
    def test_td_output(self, strict, same_shape, dynamic_shape):
        """Export a module that *returns* a TensorDict and check the result.

        ``same_shape`` controls whether the example inputs used at export time
        already have the batch size used at call time; ``dynamic_shape``
        additionally marks both dims as dynamic via ``torch.export.Dim`` so the
        exported program accepts the new shape. The ``[False, False]`` combo is
        deliberately absent: a static export at the wrong shape cannot work.
        """
        # This will only work when the tensordict is pytree-able
        class Test(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return TensorDict(
                    {
                        "x": x,
                        "y": y,
                    },
                    batch_size=x.shape[0],
                )

        test = Test()
        if same_shape:
            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
        else:
            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
        if dynamic_shape:
            kwargs = {
                "dynamic_shapes": {
                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
                }
            }
        else:
            kwargs = {}

        # NOTE(review): the class-level ``strict`` parameter is intentionally
        # not forwarded here — exporting a module whose output is a TensorDict
        # appears to require non-strict mode. Confirm whether strict=True is
        # expected to work and, if so, pass ``strict=strict`` instead.
        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
        export_mod = result.module()
        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
        export_test = export_mod(x_new, y_new)
        eager_test = test(x_new, y_new)
        assert eager_test.batch_size == export_test.batch_size
        assert (export_test == eager_test).all()
864
@pytest .mark .skipif (not _has_onnx , reason = "ONNX is not available" )
826
865
class TestONNXExport :
0 commit comments