
Commit b520970

Dev/skotapati/copy broadcasting (#350)
* Handle broadcasting by expanding src tensor in Copy.mm
* Unblock linalg_matrix_power
* Improved formatting
1 parent: c30946a

File tree

2 files changed: 8 additions & 4 deletions


aten/src/ATen/native/mps/operations/Copy.mm

Lines changed: 8 additions & 3 deletions
@@ -300,22 +300,27 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
   TORCH_CHECK(dst.defined(), "dst is undefined");
   TORCH_CHECK(src.defined(), "src is undefined");
 
+  bool needs_broadcasting = false;
+
   if (src.numel() == 0 || dst.is_same(src)) {
     return dst;
   }
   if (dst.numel() == 0) {
     dst.resize_as_(src);
   }
+  if (dst.dim() > src.dim()) {
+    needs_broadcasting = true;
+  }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kCPU) {
-    return copy_from_mps_(dst, src, non_blocking);
+    return copy_from_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   if (src.device().type() == at::kCPU && dst.device().type() == at::kMPS) {
-    return copy_to_mps_(dst, src, non_blocking);
+    return copy_to_mps_(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
 
   if (src.device().type() == at::kMPS && dst.device().type() == at::kMPS) {
-    return copy_kernel_mps(dst, src, non_blocking);
+    return copy_kernel_mps(dst, needs_broadcasting ? src.expand_as(dst) : src, non_blocking);
   }
   TORCH_INTERNAL_ASSERT(
       src.device().type() == DeviceType::MPS,
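
For context, a minimal sketch of the behavior this change enables (a hypothetical repro, assuming a macOS build where MPS is available): copying a lower-rank source into a higher-rank MPS destination now expands the source to the destination's shape before dispatching the copy.

    import torch

    # Hypothetical repro sketch, assuming MPS is available.
    # Before this commit, copy_ between tensors of different ranks could
    # fail on the MPS paths because src was not expanded to dst's shape.
    if torch.backends.mps.is_available():
        dst = torch.zeros(2, 3, device="mps")
        src = torch.ones(3)       # rank 1: dst.dim() > src.dim()
        dst.copy_(src)            # now takes the src.expand_as(dst) path
        print(dst.cpu())          # tensor([[1., 1., 1.], [1., 1., 1.]])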

test/test_mps.py

Lines changed: 0 additions & 1 deletion
@@ -10200,7 +10200,6 @@ class TestConsistency(TestCaseMPS):
     # All the entries in this list should be removed
     BLOCKLIST = {
         # Functions that hard crash
-        'linalg.matrix_power': [torch.float32],
         'resize_': [torch.bool, torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
         'resize_as_': [torch.float16, torch.float32],
         'topk': [torch.int16, torch.int32, torch.int64, torch.uint8],
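
Removing this blocklist entry means linalg.matrix_power now runs in the MPS consistency suite; per the commit message, it was previously blocked because it hard-crashed on the copy path fixed above. A minimal sketch of the unblocked call (again assuming MPS is available):

    import torch

    # Sketch of the op this commit unblocks, assuming MPS is available.
    if torch.backends.mps.is_available():
        a = torch.randn(3, 3, device="mps")
        p = torch.linalg.matrix_power(a, 3)   # a @ a @ a
        expected = a @ a @ a
        print(torch.allclose(p.cpu(), expected.cpu(), atol=1e-4))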
