5 changes: 5 additions & 0 deletions src/xccl/ProcessGroupXCCL.cpp
@@ -283,6 +283,11 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() {

void ProcessGroupXCCL::WorkXCCL::synchronize() {
  synchronizeStream();
  if (c10d::allow_inflight_collective_as_graph_input()) {
    c10d::unregister_work(
        c10::intrusive_ptr<
            ProcessGroupXCCL::WorkXCCL>::unsafe_reclaim_from_nonowning(this));
  }
}

void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
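For reference, a minimal sketch (not part of this PR) of the work-registry lifecycle that the synchronize() change above supports, reusing only the torch.distributed APIs exercised in the tests below; the demo function name and the assumption of an already-initialized 2-rank "xccl" process group are mine:

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives

def registry_lifecycle_demo(rank: int) -> None:
    # Hypothetical helper; assumes init_process_group(backend="xccl", ...) was
    # already called on each rank (see the tests below for the full setup).
    with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
        t = torch.full((10, 10), float(rank), device=f"xpu:{rank}")
        work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
        # The in-flight collective is tracked while it is unwaited.
        assert torch._C._distributed_c10d._get_work_registry_size() == 1
        work.wait()  # wait() reaches WorkXCCL::synchronize(), which now unregisters the work
        assert torch._C._distributed_c10d._get_work_registry_size() == 0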
81 changes: 81 additions & 0 deletions test/xpu/distributed/test_c10d_xccl.py
@@ -12,6 +12,7 @@

import torch
import torch.distributed as c10d
import torch.distributed._functional_collectives as _functional_collectives

if not c10d.is_available() or not c10d.is_xccl_available():
    print("c10d XCCL not available, skipping tests", file=sys.stderr)
@@ -626,6 +627,86 @@ def test_all_gather_into_tensor(self):
tensor.view(torch.float32),
)

    @requires_xccl()
    @skip_if_lt_x_gpu(2)
    def test_unwaited(self) -> None:
        # Verify that the process can terminate gracefully
        # even with unwaited tensors
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
        )

        # Case 1: Run collectives under the context manager, and don't call wait on them.
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            input = torch.full(
                (10240, 10240), float(self.rank), device=f"xpu:{self.rank}"
            )
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
            # Non-functional collectives run under the context manager are registered in the work registry.
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            # Running another collective on the same tensor should still work
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)

        # Case 2: Run collectives outside the context manager, and don't call wait on them.
        # NOTE: Here we intentionally test a memory-stressed case.
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
        for _ in range(50000):
            input = torch.full(
                (1024, 1024), float(self.rank), device=f"xpu:{self.rank}"
            )
            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
        # The work registry size is unchanged, since non-functional collectives not run under
        # the context manager are not registered in the work registry.
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)

    @requires_xccl()
    @skip_if_lt_x_gpu(2)
    def test_wait_tensor(self) -> None:
        # Verify that c10d_functional.wait_tensor() can be invoked on the
        # output tensor of a non-functional collective
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
        )

        # Case 1: under the context manager (i.e. the work is registered in the registry)
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            torch.ops.c10d_functional.wait_tensor(input1)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

            input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
            work.wait()
            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
            self.assertEqual(input1, input2)

        # Case 2: not under the context manager (i.e. the work is not registered in the registry)
        input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        # This wait_tensor() call has no effect, since the underlying logic cannot
        # find the corresponding work object (it was never registered in the registry).
        torch.ops.c10d_functional.wait_tensor(input1)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)

        input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        work.wait()
        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
        self.assertEqual(input1, input2)


instantiate_parametrized_tests(ProcessGroupXCCLTest)
