diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp
index 8e6d648059d2b..998bc8e335125 100644
--- a/torch/csrc/distributed/c10d/Ops.cpp
+++ b/torch/csrc/distributed/c10d/Ops.cpp
@@ -136,6 +136,24 @@ c10::intrusive_ptr<Work> reduce_cpu_(
           std::chrono::milliseconds(timeout)});
 }
 
+c10::intrusive_ptr<Work> reduce_mps_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t root_rank,
+    int64_t root_tensor,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  return process_group->getBackend(c10::DeviceType::MPS)
+      ->reduce(
+          tensor_vec,
+          ReduceOptions{
+              *reduce_op.get(),
+              root_rank,
+              root_tensor,
+              std::chrono::milliseconds(timeout)});
+}
+
 c10::intrusive_ptr<Work> reduce_cuda_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -172,6 +190,24 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
       std::move(tensor_vec), work);
 }
 
+std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_mps_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    int64_t root_rank,
+    int64_t root_tensor,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::MPS)
+          ->broadcast(
+              tensor_vec,
+              BroadcastOptions{
+                  root_rank, root_tensor, std::chrono::milliseconds(timeout)});
+
+  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
+      std::move(tensor_vec), work);
+}
+
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cuda_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -210,6 +246,26 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cpu_(
       std::move(tensor_vec), work);
 }
 
+std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_mps_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::MPS)
+          ->allreduce(
+              tensor_vec,
+              AllreduceOptions{
+                  *reduce_op.get(), std::chrono::milliseconds(timeout)});
+
+  // Return input tensors as output tensors to make inplace allreduce look like
+  // a functional API, so that make_fx can correctly build the dependencies in
+  // the graph later.
+  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
+      std::move(tensor_vec), work);
+}
+
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -286,7 +342,7 @@ allgather_mps_(
     int64_t timeout) {
   auto input_tensors_vec = input_tensors.vec();
   auto work =
-      process_group->getBackend(c10::DeviceType::CPU)
+      process_group->getBackend(c10::DeviceType::MPS)
           ->allgather(
               const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
               input_tensors_vec,
@@ -448,6 +504,21 @@ c10::intrusive_ptr<Work> gather_cpu_(
           input_tensors_vec,
           GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
 }
+
+c10::intrusive_ptr<Work> gather_mps_(
+    const std::vector<std::vector<at::Tensor>>& output_tensors,
+    const at::TensorList& input_tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    int64_t root_rank,
+    int64_t timeout) {
+  auto input_tensors_vec = input_tensors.vec();
+  return process_group->getBackend(c10::DeviceType::MPS)
+      ->gather(
+          const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
+          input_tensors_vec,
+          GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
+}
+
 c10::intrusive_ptr<Work> gather_cuda_(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
     const at::TensorList& input_tensors,
@@ -480,6 +551,24 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
       std::move(output_tensors_vec), work);
 }
 
+std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_mps_(
+    const at::TensorList& output_tensors,
+    const std::vector<std::vector<at::Tensor>>& input_tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    int64_t root_rank,
+    int64_t timeout) {
+  auto output_tensors_vec = output_tensors.vec();
+  auto work =
+      process_group->getBackend(c10::DeviceType::MPS)
+          ->scatter(
+              output_tensors_vec,
+              const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
+              ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});
+
+  return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
+      std::move(output_tensors_vec), work);
+}
+
 std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cuda_(
     const at::TensorList& output_tensors,
     const std::vector<std::vector<at::Tensor>>& input_tensors,
@@ -622,6 +711,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("reduce_", reduce_cpu_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, MPS, m) {
+  m.impl("reduce_", reduce_mps_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("reduce_", reduce_cuda_);
 }
@@ -630,6 +723,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("broadcast_", broadcast_cpu_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, MPS, m) {
+  m.impl("broadcast_", broadcast_mps_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("broadcast_", broadcast_cuda_);
 }
@@ -638,6 +735,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("allreduce_", allreduce_cpu_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, MPS, m) {
+  m.impl("allreduce_", allreduce_mps_);
+}
+
 // TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
 // sparse all_reduce in the Gloo backend
 TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
@@ -708,6 +809,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("gather_", gather_cpu_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, MPS, m) {
+  m.impl("gather_", gather_mps_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("gather_", gather_cuda_);
 }
@@ -716,6 +821,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("scatter_", scatter_cpu_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, MPS, m) {
+  m.impl("scatter_", scatter_mps_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("scatter_", scatter_cuda_);
 }
diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
index 5e85f96906149..73ec48c6450f5 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp
@@ -886,6 +886,30 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
   }
 };
 
+class AsyncBroadcastMPSWork : public AsyncBroadcastWork {
+ public:
+  AsyncBroadcastMPSWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<at::Tensor>& inputs,
+      int rootRank,
+      int rootTensor,
+      uint32_t tag)
+      : AsyncBroadcastWork(context, inputs, rootRank, rootTensor, tag) {}
+
+  void run() override {
+    std::vector<at::Tensor> inputs_cpu;
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs_cpu.push_back( inputs[i].to("cpu") );
+    }
+
+    broadcast(inputs_cpu[rootTensor]);
+
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs[i].copy_( inputs_cpu[rootTensor] );
+    }
+  }
+};
+
 class AsyncBroadcastCUDAWork : public AsyncBroadcastWork {
  public:
   AsyncBroadcastCUDAWork(
@@ -956,6 +980,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast(
   switch (device.type()) {
     case at::kCPU:
       break;
+    case at::kMPS:
+      break;
     case at::kCUDA:
       // If the user gave us a CUDA tensor then CUDA must be loaded.
       TORCH_INTERNAL_ASSERT(at::hasCUDA());
@@ -970,6 +996,9 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::broadcast(
   if (device.type() == at::kCPU) {
     work = c10::make_intrusive<AsyncBroadcastWork>(
         std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
+  } else if (device.type() == at::kMPS) {
+    work = c10::make_intrusive<AsyncBroadcastMPSWork>(
+        std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
   } else if (device.type() == at::kCUDA) {
     work = c10::make_intrusive<AsyncBroadcastCUDAWork>(
         std::move(context), inputs, opts.rootRank, opts.rootTensor, tag);
@@ -1028,6 +1057,28 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
   }
 };
 
+class AsyncAllreduceMPSWork : public AsyncAllreduceWork {
+ public:
+  AsyncAllreduceMPSWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<at::Tensor>& inputs,
+      ReduceOp reduceOp,
+      uint32_t tag)
+      : AsyncAllreduceWork(context, inputs, reduceOp, tag) {}
+
+  void run() override {
+    std::vector<at::Tensor> inputs_cpu;
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs_cpu.push_back( inputs[i].to("cpu") );
+    }
+    allreduce(inputs_cpu);
+
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs[i].copy_( inputs_cpu[i] );
+    }
+  }
+};
+
 class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork {
  public:
   AsyncAllreduceCoalescedWork(
@@ -1447,6 +1498,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce(
   switch (device.type()) {
     case at::kCPU:
       break;
+    case at::kMPS:
+      break;
     case at::kCUDA:
       // If the user gave us a CUDA tensor then CUDA must be loaded.
       TORCH_INTERNAL_ASSERT(at::hasCUDA());
@@ -1475,6 +1528,16 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::allreduce(
     } else {
       invalidArgument("unsupported layout");
     }
+  } else if (device.type() == at::kMPS) {
+    if (layout == c10::kStrided) {
+      work = c10::make_intrusive<AsyncAllreduceMPSWork>(
+          std::move(context), inputs, opts.reduceOp, tag);
+      // } else if (layout == c10::kSparse) {
+      //   work = c10::make_intrusive<AsyncSparseAllreduceWork>(
+      //       std::move(context), inputs, tag);
+    } else {
+      invalidArgument("unsupported layout");
+    }
   } else if (device.type() == at::kCUDA) {
     if (layout == c10::kStrided) {
       work = c10::make_intrusive<AsyncAllreduceCUDAWork>(
@@ -1607,6 +1670,31 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork {
   }
 };
 
+class AsyncReduceMPSWork : public AsyncReduceWork {
+ public:
+  AsyncReduceMPSWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<at::Tensor>& inputs,
+      int rootRank,
+      int rootTensor,
+      ReduceOp reduceOp,
+      uint32_t tag)
+      : AsyncReduceWork(context, inputs, rootRank, rootTensor, reduceOp, tag) {}
+
+  void run() override {
+    std::vector<at::Tensor> inputs_cpu;
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs_cpu.push_back( inputs[i].to("cpu") );
+    }
+
+    reduce(inputs_cpu);
+
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs[i].copy_( inputs_cpu[i] );
+    }
+  }
+};
+
 class AsyncReduceCUDAWork : public AsyncReduceWork {
  public:
   AsyncReduceCUDAWork(
@@ -1678,6 +1766,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
   switch (device.type()) {
     case at::kCPU:
       break;
+    case at::kMPS:
+      break;
     case at::kCUDA:
       // If the user gave us a CUDA tensor then CUDA must be loaded.
       TORCH_INTERNAL_ASSERT(at::hasCUDA());
@@ -1697,6 +1787,14 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::reduce(
         opts.rootTensor,
         opts.reduceOp,
         tag);
+  } else if (device.type() == at::kMPS) {
+    work = c10::make_intrusive<AsyncReduceMPSWork>(
+        std::move(context),
+        inputs,
+        opts.rootRank,
+        opts.rootTensor,
+        opts.reduceOp,
+        tag);
   } else if (device.type() == at::kCUDA) {
     work = c10::make_intrusive<AsyncReduceCUDAWork>(
         std::move(context),
@@ -1791,9 +1889,6 @@ class AsyncAllgatherMPSWork : public AsyncAllgatherWork {
     allgather(outputs_cpu, inputs_cpu);
 
-    for( int i = 0; i < inputs.size(); i++) {
-      inputs[i].copy_( inputs_cpu[i] );
-    }
     for( int i = 0; i < outputs.size(); i++) {
       for( int j = 0; j < outputs[i].size(); j++) {
         outputs[i][j].copy_( outputs_cpu[i][j] );
@@ -2135,6 +2230,42 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
   }
 };
 
+class AsyncGatherMPSWork : public AsyncGatherWork {
+ public:
+  AsyncGatherMPSWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<std::vector<at::Tensor>>& outputs,
+      std::vector<at::Tensor>& inputs,
+      int root,
+      uint32_t tag)
+      : AsyncGatherWork(context, outputs, inputs, root, tag) {}
+
+  void run() override {
+    std::vector<at::Tensor> inputs_cpu;
+    std::vector<std::vector<at::Tensor>> outputs_cpu;
+    std::vector<at::Tensor> temp;
+    for( int i = 0; i < inputs.size(); i++) {
+      inputs_cpu.push_back( inputs[i].to("cpu") );
+    }
+
+    for( int i = 0; i < outputs.size(); i++) {
+      temp.clear();
+      for( int j = 0; j < outputs[i].size(); j++) {
+        temp.push_back( outputs[i][j].to("cpu"));
+      }
+      outputs_cpu.push_back(temp);
+    }
+
+    gather(outputs_cpu, inputs_cpu);
+
+    for( int i = 0; i < outputs.size(); i++) {
+      for( int j = 0; j < outputs[i].size(); j++) {
+        outputs[i][j].copy_( outputs_cpu[i][j] );
+      }
+    }
+  }
+};
+
 // Note: current CUDA implementation holds the assumptions:
 //     - inputs.size() is 1
 //     - outputs.size() is 1
@@ -2252,6 +2383,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::gather(
   switch (device.type()) {
     case at::kCPU:
       break;
+    case at::kMPS:
+      break;
     case at::kCUDA:
       // If the user gave us a CUDA tensor then CUDA must be loaded.
       TORCH_INTERNAL_ASSERT(at::hasCUDA());
@@ -2266,6 +2399,9 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::gather(
   if (device.type() == at::kCPU) {
     work = c10::make_intrusive<AsyncGatherWork>(
         std::move(context), outputs, inputs, opts.rootRank, tag);
+  } else if (device.type() == at::kMPS) {
+    work = c10::make_intrusive<AsyncGatherMPSWork>(
+        std::move(context), outputs, inputs, opts.rootRank, tag);
   } else if (device.type() == at::kCUDA) {
     work = c10::make_intrusive<AsyncGatherCUDAWork>(
         std::move(context), outputs, inputs, opts.rootRank, tag);
@@ -2326,6 +2462,45 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
   }
 };
 
+class AsyncScatterMPSWork : public AsyncScatterWork {
+ public:
+  AsyncScatterMPSWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<at::Tensor>& outputs,
+      std::vector<std::vector<at::Tensor>>& inputs,
+      int root,
+      uint32_t tag)
+      : AsyncScatterWork(context, outputs, inputs, root, tag) {}
+
+  void run() override {
+    std::vector<at::Tensor> outputs_cpu;
+    std::vector<std::vector<at::Tensor>> inputs_cpu;
+    std::vector<at::Tensor> temp;
+    for( int i = 0; i < outputs.size(); i++) {
+      outputs_cpu.push_back( outputs[i].to("cpu") );
+    }
+
+    for( int i = 0; i < inputs.size(); i++) {
+      temp.clear();
+      for( int j = 0; j < inputs[i].size(); j++) {
+        temp.push_back( inputs[i][j].to("cpu"));
+      }
+      inputs_cpu.push_back(temp);
+    }
+
+    scatter(outputs_cpu, inputs_cpu);
+
+    for( int i = 0; i < outputs.size(); i++) {
+      outputs[i].copy_( outputs_cpu[i] );
+    }
+    for( int i = 0; i < inputs.size(); i++) {
+      for( int j = 0; j < inputs[i].size(); j++) {
+        inputs[i][j].copy_( inputs_cpu[i][j] );
+      }
+    }
+  }
+};
+
 class AsyncScatterCUDAWork : public AsyncScatterWork {
  public:
   AsyncScatterCUDAWork(
@@ -2435,6 +2610,8 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::scatter(
   switch (device.type()) {
     case at::kCPU:
       break;
+    case at::kMPS:
+      break;
     case at::kCUDA:
       // If the user gave us a CUDA tensor then CUDA must be loaded.
       TORCH_INTERNAL_ASSERT(at::hasCUDA());
@@ -2449,6 +2626,9 @@ c10::intrusive_ptr<Work> ProcessGroupGloo::scatter(
   if (device.type() == at::kCPU) {
     work = c10::make_intrusive<AsyncScatterWork>(
         std::move(context), outputs, inputs, opts.rootRank, tag);
+  } else if (device.type() == at::kMPS) {
+    work = c10::make_intrusive<AsyncScatterMPSWork>(
+        std::move(context), outputs, inputs, opts.rootRank, tag);
   } else if (device.type() == at::kCUDA) {
     work = c10::make_intrusive<AsyncScatterCUDAWork>(
         std::move(context), outputs, inputs, opts.rootRank, tag);
diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py
index 5848c0ecab0ef..6523829456ae5 100644
--- a/torch/distributed/utils.py
+++ b/torch/distributed/utils.py
@@ -59,6 +59,8 @@ def to_map(obj):
         device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
         if device == torch.device("cuda", target_gpu):
             return (obj,)
+        if device == torch.device("mps", target_gpu):
+            return (obj,)
         if not use_side_stream_for_tensor_copies:
             return (obj.to(target_gpu),)
         else:
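
Test plan (illustrative sketch, not part of the patch): a minimal two-rank smoke test of the new MPS paths. The file name `mps_gloo_smoke.py` is hypothetical; it assumes this patch is applied and the script is launched on an Apple-silicon machine with `torchrun --nproc_per_node=2 mps_gloo_smoke.py`, which sets the env:// rendezvous variables that `init_process_group` reads.

```python
# mps_gloo_smoke.py -- illustrative smoke test, assumes this patch is applied.
# Launch with: torchrun --nproc_per_node=2 mps_gloo_smoke.py
import torch
import torch.distributed as dist


def main():
    # Gloo is the backend these changes target; the new Async*MPSWork classes
    # stage MPS tensors through CPU copies around each collective.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Exercises allreduce_mps_ -> AsyncAllreduceMPSWork: result is the sum over ranks.
    x = torch.full((4,), float(rank + 1), device="mps")
    dist.all_reduce(x, op=dist.ReduceOp.SUM)
    expected = float(sum(range(1, world_size + 1)))
    assert torch.allclose(x.cpu(), torch.full((4,), expected))

    # Exercises broadcast_mps_ -> AsyncBroadcastMPSWork: every rank gets rank 0's data.
    y = torch.full((4,), float(rank), device="mps")
    dist.broadcast(y, src=0)
    assert torch.allclose(y.cpu(), torch.zeros(4))

    print(f"rank {rank}/{world_size}: MPS all_reduce and broadcast OK")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```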