Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 110 additions & 1 deletion torch/csrc/distributed/c10d/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,24 @@ c10::intrusive_ptr<Work> reduce_cpu_(
std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> reduce_mps_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
return process_group->getBackend(c10::DeviceType::MPS)
->reduce(
tensor_vec,
ReduceOptions{
*reduce_op.get(),
root_rank,
root_tensor,
std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> reduce_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
Expand Down Expand Up @@ -172,6 +190,24 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_mps_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t root_tensor,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work =
process_group->getBackend(c10::DeviceType::MPS)
->broadcast(
tensor_vec,
BroadcastOptions{
root_rank, root_tensor, std::chrono::milliseconds(timeout)});

return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
Expand Down Expand Up @@ -210,6 +246,26 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cpu_(
std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_mps_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
const c10::intrusive_ptr<ReduceOp>& reduce_op,
int64_t timeout) {
auto tensor_vec = tensors.vec();
auto work =
process_group->getBackend(c10::DeviceType::MPS)
->allreduce(
tensor_vec,
AllreduceOptions{
*reduce_op.get(), std::chrono::milliseconds(timeout)});

// Return input tensors as output tensors to make inplace allreduce look like
// a functional API, so that make_fx can correctly build the dependencies in
// the graph later.
return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(tensor_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
at::TensorList tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
Expand Down Expand Up @@ -286,7 +342,7 @@ allgather_mps_(
int64_t timeout) {
auto input_tensors_vec = input_tensors.vec();
auto work =
process_group->getBackend(c10::DeviceType::CPU)
process_group->getBackend(c10::DeviceType::MPS)
->allgather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
input_tensors_vec,
Expand Down Expand Up @@ -448,6 +504,21 @@ c10::intrusive_ptr<Work> gather_cpu_(
input_tensors_vec,
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> gather_mps_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const at::TensorList& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t timeout) {
auto input_tensors_vec = input_tensors.vec();
return process_group->getBackend(c10::DeviceType::MPS)
->gather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
input_tensors_vec,
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> gather_cuda_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const at::TensorList& input_tensors,
Expand Down Expand Up @@ -480,6 +551,24 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_mps_(
const at::TensorList& output_tensors,
const std::vector<std::vector<at::Tensor>>& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t timeout) {
auto output_tensors_vec = output_tensors.vec();
auto work =
process_group->getBackend(c10::DeviceType::MPS)
->scatter(
output_tensors_vec,
const_cast<std::vector<std::vector<at::Tensor>>&>(input_tensors),
ScatterOptions{root_rank, std::chrono::milliseconds(timeout)});

return std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>(
std::move(output_tensors_vec), work);
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cuda_(
const at::TensorList& output_tensors,
const std::vector<std::vector<at::Tensor>>& input_tensors,
Expand Down Expand Up @@ -622,6 +711,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("reduce_", reduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, MPS, m) {
m.impl("reduce_", reduce_mps_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("reduce_", reduce_cuda_);
}
Expand All @@ -630,6 +723,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("broadcast_", broadcast_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, MPS, m) {
m.impl("broadcast_", broadcast_mps_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("broadcast_", broadcast_cuda_);
}
Expand All @@ -638,6 +735,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("allreduce_", allreduce_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, MPS, m) {
m.impl("allreduce_", allreduce_mps_);
}

// TODO: The SparseCPU/SparseCUDA dispatched methods are only used to support
// sparse all_reduce in the Gloo backend
TORCH_LIBRARY_IMPL(c10d, SparseCPU, m) {
Expand Down Expand Up @@ -708,6 +809,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("gather_", gather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, MPS, m) {
m.impl("gather_", gather_mps_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("gather_", gather_cuda_);
}
Expand All @@ -716,6 +821,10 @@ TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("scatter_", scatter_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, MPS, m) {
m.impl("scatter_", scatter_mps_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("scatter_", scatter_cuda_);
}
Expand Down
Loading