13 changes: 13 additions & 0 deletions .github/workflows/_mac-test-mps.yml
@@ -110,6 +110,19 @@ jobs:
# TODO(https://github.com/pytorch/pytorch/issues/79293)
${CONDA_RUN} python3 test/test_nn.py -k mps --verbose

- name: Run MPS Test Ops
id: test_3
env:
ENV_NAME: conda-test-env-${{ github.run_id }}
shell: arch -arch arm64 bash {0}
# During bring up of MPS op support don't show this as an error.
continue-on-error: true
run: |
# shellcheck disable=SC1090
set -ex
# TODO(https://github.com/pytorch/pytorch/issues/79293)
${CONDA_RUN} PYTORCH_TEST_WITH_SLOW=1 python3 test/test_ops.py -k mps --verbose

- name: Print remaining test logs
shell: bash
if: always()
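For local debugging, the new "Run MPS Test Ops" step above can be reproduced outside CI. A minimal sketch, assuming an arm64 macOS machine with an MPS-enabled PyTorch build already active in the current shell (the CI-specific ${CONDA_RUN} wrapper and conda env name are omitted):

# Run the operator tests against the MPS backend, mirroring the CI step above.
# PYTORCH_TEST_WITH_SLOW=1 additionally enables tests decorated with @slowTest.
set -ex
PYTORCH_TEST_WITH_SLOW=1 python3 test/test_ops.py -k mps --verbose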
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -69,7 +69,7 @@ jobs:
# shellcheck disable=SC1090
set -ex
set +e
if ! ${CONDA_RUN} lintrunner --force-color test/*.py aten/src/ATen/native/mps/*.h aten/src/ATen/native/mps/*.mm aten/src/ATen/native/mps/operations/*; then
if ! ${CONDA_RUN} lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py test/test_modules.py test/test_ops.py; then
echo ""
echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m"
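The lint job's hint above ("You can reproduce these results locally by using lintrunner") applies to the updated file list as well. A minimal sketch, assuming lintrunner is installed and initialized as described on the linked wiki page:

# Lint the same MPS-related files that the updated CI job now checks.
lintrunner --force-color aten/src/ATen/native/mps/operations/* test/test_mps.py test/test_modules.py test/test_ops.py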
141 changes: 132 additions & 9 deletions test/test_ops.py
@@ -18,6 +18,7 @@
from torch.testing._internal.common_dtype import (
floating_and_complex_types_and,
all_types_and_complex_and,
get_all_dtypes,
)

from torch.testing._internal.common_utils import (
@@ -110,6 +111,17 @@

aten = torch.ops.aten

MPS_DTYPES = get_all_dtypes()
for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
    MPS_DTYPES.remove(t)

def _get_mps_error_msg(device, dtype, op, mps_blocklist):
    # Only produce a skip message when actually running on an MPS device;
    # otherwise the blocklist would also skip these ops on CPU/CUDA.
    if not (torch.backends.mps.is_available() and device == "mps"):
        return None
    if dtype not in MPS_DTYPES:
        return f"MPS doesn't support {dtype} datatype"
    if op.name.startswith(tuple(mps_blocklist)):
        return f"MPS doesn't support op {op.name}"
    return None

# Tests that apply to all operators and aren't related to any particular
# system
class TestCommon(TestCase):
@@ -256,12 +268,18 @@ def test_numpy_ref(self, device, dtype, op):
)

# Tests that the cpu and gpu results are consistent
@onlyCUDA
@suppress_warnings
@slowTest
@ops(_ops_and_refs_with_no_numpy_ref, dtypes=OpDTypes.any_common_cpu_cuda_one)
def test_compare_cpu(self, device, dtype, op):

MPS_BLOCKLIST = [
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)

def to_cpu(arg):
if isinstance(arg, torch.Tensor):
return arg.to(device='cpu')
@@ -271,20 +289,20 @@ def to_cpu(arg):

for sample in samples:
cpu_sample = sample.transform(to_cpu)
cuda_results = op(sample.input, *sample.args, **sample.kwargs)
gpu_results = op(sample.input, *sample.args, **sample.kwargs)
cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs)

# output_process_fn_grad has a very unfortunate name
# We use this function in linalg extensively to postprocess the inputs of functions
# that are not completely well-defined. Think svd and multiplying the singular vectors by -1.
# CPU and CUDA implementations of the SVD can return valid SVDs that are different.
# We use this function to compare them.
cuda_results = sample.output_process_fn_grad(cuda_results)
gpu_results = sample.output_process_fn_grad(gpu_results)
cpu_results = cpu_sample.output_process_fn_grad(cpu_results)

# Lower tolerance because we are running this as a `@slowTest`
# Don't want the periodic tests to fail frequently
self.assertEqual(cuda_results, cpu_results, atol=1e-3, rtol=1e-3)
self.assertEqual(gpu_results, cpu_results, atol=1e-3, rtol=1e-3)

# Tests that experimental Python References can propagate shape, dtype,
# and device metadata properly.
@@ -479,11 +497,24 @@ def test_python_ref_torch_fallback(self, device, dtype, op):
self._ref_test_helper(contextlib.nullcontext, device, dtype, op)

@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyCUDA
@ops(python_ref_db)
@parametrize('executor', ['aten', 'nvfuser'])
@skipIfTorchInductor("Takes too long for inductor")
def test_python_ref_executor(self, device, dtype, op, executor):
if device == "mps" and executor == 'nvfuser':
return
MPS_BLOCKLIST = [
"_refs.fft.fft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.rfft2", # hard crash on unsupoorted ComplexFloat
"_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat
"_refs.floor_divide", # hard crash on unsupoorted ComplexFloat
"_refs.where", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
# TODO: Not all dtypes are supported with nvfuser
from torch._prims_common import _torch_dtype_to_nvfuser_dtype_map
if executor == "nvfuser" and dtype not in _torch_dtype_to_nvfuser_dtype_map:
@@ -663,6 +694,12 @@ def test_noncontiguous_samples(self, device, dtype, op):
@ops(_ops_and_refs, dtypes=OpDTypes.none)
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_out_warning(self, device, op):
MPS_BLOCKLIST = [
"_refs.fft.fft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat
]
# Prefers running in float32 but has a fallback for the first listed supported dtype
supported_dtypes = op.supported_dtypes(self.device_type)
if len(supported_dtypes) == 0:
@@ -673,6 +710,9 @@ def test_out_warning(self, device, op):
else list(supported_dtypes)[0]
)

msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
samples = op.sample_inputs(device, dtype)
for sample in samples:
# calls it normally to get the expected result
@@ -716,7 +756,7 @@ def _extract_strides(out):
# NOTE: only extracts on the CPU, CUDA and MPS device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != "cpu" and self.device_type != "cuda":
if self.device_type != "cpu" and self.device_type != "cuda" and self.device_type != "mps":
return ()

if isinstance(out, torch.Tensor):
@@ -792,6 +832,23 @@ def _any_nonempty(out):
@ops(_ops_and_refs, dtypes=OpDTypes.any_one)
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_out(self, device, dtype, op):
MPS_BLOCKLIST = [
"_refs._conversions.complex", # hard crash on unsupoorted ComplexFloat
"_refs.fft.fft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ifft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.ihfft", # hard crash on unsupoorted ComplexFloat
"_refs.fft.rfft2", # hard crash on unsupoorted ComplexFloat
"_refs.fft.rfft", # hard crash on unsupoorted ComplexFloat
"bitwise_not", # hard crash on unsupoorted ComplexFloat
"fft.fft", # hard crash on unsupoorted ComplexFloat
"fft.ifft", # hard crash on unsupoorted ComplexFloat
"fft.ihfft", # hard crash on unsupoorted ComplexFloat
"fft.rfft2", # hard crash on unsupoorted ComplexFloat
"fft.rfft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
# Prefers running in float32 but has a fallback for the first listed supported dtype
samples = op.sample_inputs(device, dtype)
for sample in samples:
@@ -836,7 +893,7 @@ def _extract_strides(out):
# NOTE: only extracts on the CPU, CUDA and MPS device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != "cpu" and self.device_type != "cuda":
if self.device_type != "cpu" and self.device_type != "cuda" and self.device_type != "mps":
return ()

if isinstance(out, torch.Tensor):
@@ -980,7 +1037,18 @@ def _case_four_transform(t):
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_variant_consistency_eager(self, device, dtype, op):
# Acquires variants (method variant, inplace variant, operator variant, inplace_operator variant, aliases)

MPS_BLOCKLIST = [
"fft.fft", # hard crash on unsupoorted ComplexFloat
"fft.ifft", # hard crash on unsupoorted ComplexFloat
"fft.ihfft", # hard crash on unsupoorted ComplexFloat
"fft.rfft2", # hard crash on unsupoorted ComplexFloat
"fft.rfft", # hard crash on unsupoorted ComplexFloat
"nn.functional.max_pool2d", # hard crash: buffer is not large enough
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
method = op.method_variant
inplace = op.inplace_variant
operator = op.operator_variant
@@ -1163,6 +1231,9 @@ def _test_inplace_preserve_storage(samples, variants):
@ops(op_db, allowed_dtypes=(torch.complex32,))
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_complex_half_reference_testing(self, device, dtype, op):
msg = _get_mps_error_msg(device, dtype, op, [])
if msg is not None:
self.skipTest(msg)
if not op.supports_dtype(torch.complex32, device):
    self.skipTest("Does not support complex32")

@@ -1431,6 +1502,12 @@ class TestCompositeCompliance(TestCase):
)
@ops(op_db, allowed_dtypes=(torch.float,))
def test_operator(self, device, dtype, op):
MPS_BLOCKLIST = [
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
samples = op.sample_inputs(device, dtype, requires_grad=False)

for sample in samples:
@@ -1444,6 +1521,18 @@ def test_operator(self, device, dtype, op):
)
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
def test_backward(self, device, dtype, op):
MPS_BLOCKLIST = [
"fft.fft", # hard crash on unsupoorted ComplexFloat
"fft.ifft", # hard crash on unsupoorted ComplexFloat
"fft.ihfft", # hard crash on unsupoorted ComplexFloat
"fft.rfft2", # hard crash on unsupoorted ComplexFloat
"fft.rfft", # hard crash on unsupoorted ComplexFloat
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)

samples = op.sample_inputs(device, dtype, requires_grad=True)

for sample in samples:
Expand All @@ -1461,6 +1550,12 @@ def test_backward(self, device, dtype, op):
)
@ops(op_db, allowed_dtypes=(torch.float,))
def test_forward_ad(self, device, dtype, op):
MPS_BLOCKLIST = [
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
if torch.float not in op.supported_backward_dtypes(device):
raise unittest.SkipTest("Does not support autograd")

@@ -1594,6 +1689,10 @@ def clone_and_perform_view(input, **kwargs):
@ops(ops_and_refs, allowed_dtypes=(torch.cfloat,))
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_conj_view(self, device, dtype, op):
msg = _get_mps_error_msg(device, dtype, op, [])
if msg is not None:
self.skipTest(msg)

if not op.test_conjugated_samples:
self.skipTest("Operation doesn't support conjugated inputs.")
math_op_physical = torch.conj_physical
@@ -1637,6 +1736,9 @@ def test_neg_view(self, device, dtype, op):
@ops(ops_and_refs, allowed_dtypes=(torch.cdouble,))
@skipIfTorchInductor("Inductor does not support complex dtype yet")
def test_neg_conj_view(self, device, dtype, op):
msg = _get_mps_error_msg(device, dtype, op, [])
if msg is not None:
self.skipTest(msg)
if not op.test_neg_view:
self.skipTest("Operation not tested with tensors with negative bit.")
if not op.test_conjugated_samples:
@@ -2012,6 +2114,17 @@ def test_refs_are_in_decomp_table(self, op):

class TestFakeTensor(TestCase):
def _test_fake_helper(self, device, dtype, op, context):
if(device == "cpu"):
return
MPS_BLOCKLIST = [
"bfloat16", # hard crash on unsupoorted type byte size
"cdouble", # hard crash on unsupoorted type byte size
"cfloat", # hard crash on unsupoorted type byte size
"chalf", # hard crash on unsupoorted type byte size
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
name = op.name
if op.variant_test_name:
name += "." + op.variant_test_name
@@ -2163,10 +2276,20 @@ def _test_fake_crossref_helper(self, device, dtype, op, context):
op.gradcheck_wrapper)

@skipIfRocm
@onlyCUDA
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
@skipOps('TestFakeTensor', 'test_fake_crossref_backward_no_amp', fake_backward_xfails)
def test_fake_crossref_backward_no_amp(self, device, dtype, op):
MPS_BLOCKLIST = [
"fft.fft", # hard crash on unsupoorted ComplexFloat
"fft.ifft", # hard crash on unsupoorted ComplexFloat
"fft.ihfft", # hard crash on unsupoorted ComplexFloat
"fft.rfft2", # hard crash on unsupoorted ComplexFloat
"fft.rfft", # hard crash on unsupoorted ComplexFloat
"stft", # hard crash on unsupoorted ComplexFloat
]
msg = _get_mps_error_msg(device, dtype, op, MPS_BLOCKLIST)
if msg is not None:
self.skipTest(msg)
self._test_fake_crossref_helper(device, dtype, op, contextlib.nullcontext)

@skipIfRocm