Skip to content

Commit 298b2d6

Browse files
Add support for masked_scatter (#361)
* Add support for masked_scatter on MPS
* Fix lintrunner errors
1 parent 5f928e8 commit 298b2d6

File tree

3 files changed

+73
-2
lines changed

3 files changed

+73
-2
lines changed

aten/src/ATen/native/mps/operations/Indexing.mm

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,56 @@ Tensor embedding_dense_backward_mps(
942942
return masked_fill__mps(self, mask, value.item());
943943
}
944944

945+
// In-place masked_scatter_ for MPS: copies elements of `source` into `self`
// at the positions where `mask` is true, consuming `source` in logical
// (row-major) order. Returns `self`.
//
// Implemented on top of advanced indexing: the boolean mask is converted into
// long index tensors and the copy is performed with index_put_.
Tensor & masked_scatter__mps(Tensor& self, const Tensor& mask, const Tensor& source) {
  at::assert_no_internal_overlap(self);
  TORCH_CHECK(
      self.scalar_type() == source.scalar_type(),
      "masked_scatter: expected self and source to have same dtypes but got",
      self.scalar_type(),
      " and ",
      source.scalar_type());

  // Nothing to scatter into.
  if (self.numel() == 0) {
    return self;
  }

  TORCH_CHECK(mask.scalar_type() == ScalarType::Byte || mask.scalar_type() == ScalarType::Bool,
      "masked_scatter: expected BoolTensor or ByteTensor for mask");

  // Promote 0-dim tensors to 1-dim so the expansion/indexing below works.
  auto mask_temp = (mask.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(mask.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(mask);
  auto self_temp = (self.dim() == 0)
    ? c10::MaybeOwned<Tensor>::owned(self.unsqueeze(0))
    : c10::MaybeOwned<Tensor>::borrowed(self);

  // Cannot reassign to mask_temp and self_temp here! if they are
  // owning and expand_outplace returns a borrow, the returned borrow
  // would dangle.
  auto mask_self_expanded = expand_outplace(*mask_temp, *self_temp);
  // Convert the boolean mask into long index tensors (one per dimension).
  auto indices = at::native::expandTensors(
    *std::get<1>(mask_self_expanded),
    c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))})
  );
  // next broadcast all index tensors together
  try {
    indices = at::expand_outplace(indices);
  } catch (std::exception &e) {
    TORCH_CHECK_INDEX(false, "shape mismatch: indexing tensors could not be broadcast together");
  }

  // Mask selected no elements -> nothing to do.
  if (!indices[0].has_storage() || indices[0].numel() == 0) {
    return self;
  }

  // Each selected position consumes exactly one element of `source`.
  const auto num_selected = indices[0].numel();
  TORCH_CHECK(source.numel() >= num_selected,
      "masked_scatter: number of elements of source (", source.numel(),
      ") is fewer than the number of ones in mask (", num_selected, ")");

  // BUGFIX: do not resize_() the caller's `source` in place -- masked_scatter_
  // must mutate only `self`. Take a non-mutating flattened view of the first
  // `num_selected` elements instead. flatten() also reads elements in logical
  // (row-major) order, which is correct for non-contiguous `source`, whereas
  // resize_() reinterprets raw storage order.
  return at::index_put_out(
    self,
    *std::get<1>(mask_self_expanded),
    c10::List<c10::optional<at::Tensor>>({*std::move(std::get<0>(mask_self_expanded))}),
    source.flatten().narrow(/*dim=*/0, /*start=*/0, num_selected)
  );
}
994+
945995
REGISTER_DISPATCH(index_stub, &index_kernel_mps);
946996
REGISTER_DISPATCH(index_put_stub, &index_put_kernel_mps);
947997
} // namespace at::native

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7422,6 +7422,7 @@
74227422
dispatch:
74237423
CPU: masked_scatter__cpu
74247424
CUDA: masked_scatter__cuda
7425+
MPS: masked_scatter__mps
74257426
autogen: masked_scatter.out
74267427

74277428
- func: masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor

test/test_mps.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,27 @@ def helper(size, memory_format):
10031003

10041004
helper((2, 3, 6, 6), torch.contiguous_format)
10051005

1006+
def test_masked_scatter(self):
    """In-place masked_scatter_ on MPS must match the CPU reference result."""

    def compare(shape):
        # Random source values on MPS, mirrored onto CPU.
        source_mps = torch.randn(shape, device="mps")
        source_cpu = source_mps.detach().clone().cpu()

        # ~60% of positions selected by the mask.
        mask_mps = torch.rand(shape, device="mps") < 0.6
        mask_cpu = mask_mps.detach().clone().cpu()

        # Destination tensors start out identical on both devices.
        dest_mps = torch.randn(shape, device="mps")
        dest_cpu = dest_mps.detach().clone().cpu()

        dest_mps.masked_scatter_(mask_mps, source_mps)
        dest_cpu.masked_scatter_(mask_cpu, source_cpu)

        self.assertEqual(dest_mps, dest_cpu)

    for shape in ([2, 5], [10, 10], [5, 10, 3], [10, 5, 10, 3], [10, 5, 10, 3, 20]):
        compare(shape)
1026+
10061027
def test_masked_fill(self):
10071028
device = "mps"
10081029
dtype = torch.float32
@@ -9304,7 +9325,7 @@ class TestConsistency(TestCaseMPS):
93049325
'masked.std': ['f32', 'i16', 'i32', 'i64', 'u8'],
93059326
'masked.var': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
93069327
'masked_fill': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
9307-
'masked_scatter': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
9328+
'masked_scatter': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
93089329
'masked_select': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
93099330
'matmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
93109331
'matrix_exp': ['f32'],
@@ -10425,7 +10446,6 @@ class TestConsistency(TestCaseMPS):
1042510446
'lu_unpack': [torch.float32],
1042610447
'masked.cumprod': [torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
1042710448
'masked.median': [torch.float32],
10428-
'masked_scatter': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
1042910449
'matrix_exp': [torch.float32],
1043010450
'mode': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],
1043110451
'msort': [torch.bool, torch.float32, torch.float16, torch.int16, torch.int32, torch.int64, torch.uint8],

0 commit comments

Comments
 (0)