@@ -33,7 +33,7 @@
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
 from functools import partial, reduce
-
+from test_mps_utils import LoggingTensor, capture_logs, tracefunc
 from torch.testing._internal.common_methods_invocations import (
     op_db,
     UnaryUfuncInfo,
@@ -9466,6 +9466,7 @@ class TestConsistency(TestCaseMPS):
         'nonzero': ['b8', 'u8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'norm': ['f32', 'f16'],
         'normal': ['f16', 'f32'],
+        'normal_': ['f16', 'f32'],
         'ones': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'ones_like': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'ormqr': ['f32'],
@@ -10543,6 +10544,8 @@ class TestConsistency(TestCaseMPS):
         # Failures due to unsupported data types on MPS backend
         'bfloat16': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'chalf': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        # Byte tests are failing
+        'byte': [torch.float16, torch.float32],
         'nn.functional.conv1d': [torch.int64],
         'nn.functional.conv2d': [torch.int64],
         'nn.functional.conv_transpose1d': [torch.int64],
@@ -10626,12 +10629,14 @@ class TestConsistency(TestCaseMPS):
         # Failures due to random output that they generate using
         # Philox engine causing mismatch with CPU results
         'uniform': [torch.float16, torch.float32],
+        'randn': [torch.float16, torch.float32],
         'rand_like': [torch.float16, torch.float32],
         'randint_like': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'randn_like': [torch.float16, torch.float32],
         'bernoulli': [torch.float32],
         'nn.functional.feature_alpha_dropoutwith_train': [torch.float32],
         'normal': [torch.float16, torch.float32, torch.float16, torch.float32],
+        'normal_': [torch.float16, torch.float32],
         'normalnumber_mean': [torch.float16, torch.float32],
         'nn.functional.alpha_dropout': [torch.float32],
         'nn.functional.dropout': [torch.float32],
@@ -10723,6 +10728,7 @@ def compare_with_CUDA(self, op, mps_out, atol, rtol):
 
     @ops(op_db, allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
+        # sys.setprofile(tracefunc)
        self.assertEqual(device, "cpu")
        if not torch.backends.mps.is_available():
            self.skipTest("MPS is not available")
@@ -10777,6 +10783,10 @@ def get_samples():
 
             # TODO: This checks only the function variant. We should also check the method and inplace version
             # when they exist
+
+            if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                mps_sample.input = LoggingTensor(mps_sample.input)
+
             cpu_args = [cpu_sample.input] + list(cpu_sample.args)
             cpu_kwargs = cpu_sample.kwargs
             mps_args = [mps_sample.input] + list(mps_sample.args)
@@ -10786,8 +10796,20 @@ def get_samples():
             if (op.name == "tensor_split" and isinstance(mps_args[1], torch.Tensor)):
                 mps_args[1] = cpu_args[1]
 
-            cpu_out = op(*cpu_args, **cpu_kwargs)
-            mps_out = op(*mps_args, **mps_kwargs)
+            # Skip running the tests to generate the full list
+            if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
+                continue
+
+            if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                with capture_logs() as logs:
+                    cpu_out = op(*cpu_args, **cpu_kwargs)
+                    mps_out = op(*mps_args, **mps_kwargs)
+                print("Forward logs:")
+                print("\n".join(logs))
+            else:
+                cpu_out = op(*cpu_args, **cpu_kwargs)
+                mps_out = op(*mps_args, **mps_kwargs)
+
 
             if op.name == "nn.functional.conv2d" or op.name == "linalg.multi_dot" and dtype == torch.float32:
                 atol = 1e-4
@@ -10867,8 +10889,15 @@ def req_grad(t):
             # Compare computed gradients with cpu given random grad_output vector
             # Sometimes when the derivative is 0, we just don't bother creating the graph
             # allow_unused is needed in those cases.
-            cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
-            mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
+            if os.environ.get("DUMP_MPS_OPS", None) == "1":
+                with capture_logs() as logs:
+                    cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
+                    mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
+                print("Backward logs:")
+                print("\n".join(logs))
+            else:
+                cpu_grad_inputs = torch.autograd.grad(diff_cpu_out, diff_cpu_arg, grad_outputs=cpu_grad_outputs, allow_unused=True)
+                mps_grad_inputs = torch.autograd.grad(diff_mps_out, diff_mps_arg, grad_outputs=mps_grad_outputs, allow_unused=True)
 
             self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
         except Exception as e:
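Note on the helpers this diff relies on: LoggingTensor, capture_logs, and tracefunc are imported from a test_mps_utils module that is not shown here. PyTorch ships a similar pair in torch.testing._internal.logging_tensor, and the sketch below is a minimal, assumed reconstruction along those lines of how a __torch_dispatch__ wrapper can record every ATen op dispatched through a tensor. It is an illustration of the technique, not the PR's actual implementation.

    import contextlib
    import logging

    import torch
    from torch.utils._pytree import tree_map

    _logger = logging.getLogger("LoggingTensor")

    class LoggingTensor(torch.Tensor):
        # Wrapper subclass: keeps the real tensor in `elem` and logs every
        # op dispatched through it before delegating to the real kernel.
        @staticmethod
        def __new__(cls, elem):
            r = torch.Tensor._make_wrapper_subclass(
                cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
                device=elem.device, requires_grad=elem.requires_grad)
            r.elem = elem.detach() if elem.requires_grad else elem
            return r

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            def unwrap(t):
                return t.elem if isinstance(t, cls) else t

            def wrap(t):
                return cls(t) if isinstance(t, torch.Tensor) else t

            # Run the real op on the unwrapped tensors, record its name,
            # then rewrap the outputs so later ops are logged too.
            out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
            _logger.info("%s", func)
            return tree_map(wrap, out)

    @contextlib.contextmanager
    def capture_logs():
        # Collect the op names emitted by LoggingTensor inside the block.
        logs = []

        class _Collector(logging.Handler):
            def emit(self, record):
                logs.append(record.getMessage())

        handler = _Collector()
        _logger.addHandler(handler)
        _logger.setLevel(logging.INFO)
        try:
            yield logs
        finally:
            _logger.removeHandler(handler)

With helpers along these lines, the two environment variables the diff reads work as follows (the test path below assumes PyTorch's usual layout and may differ):

    # Print the ATen ops dispatched in forward and backward for each sample:
    DUMP_MPS_OPS=1 python test/test_mps.py -k test_output_match
    # Skip executing the ops entirely, e.g. to regenerate the allow/block lists:
    EXPECTTEST_ACCEPT=1 python test/test_mps.py -k test_output_match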