diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 9f72cd1bd..c820a1c48 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -283,6 +283,11 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
 
 void ProcessGroupXCCL::WorkXCCL::synchronize() {
   synchronizeStream();
+  if (c10d::allow_inflight_collective_as_graph_input()) {
+    c10d::unregister_work(
+        c10::intrusive_ptr<
+            ProcessGroupXCCL::WorkXCCL>::unsafe_reclaim_from_nonowning(this));
+  }
 }
 
 void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
diff --git a/test/xpu/distributed/test_c10d_xccl.py b/test/xpu/distributed/test_c10d_xccl.py
index 916524073..44a3ac148 100644
--- a/test/xpu/distributed/test_c10d_xccl.py
+++ b/test/xpu/distributed/test_c10d_xccl.py
@@ -12,6 +12,7 @@
 
 import torch
 import torch.distributed as c10d
+import torch.distributed._functional_collectives as _functional_collectives
 
 if not c10d.is_available() or not c10d.is_xccl_available():
     print("c10d XCCL not available, skipping tests", file=sys.stderr)
@@ -626,6 +627,86 @@ def test_all_gather_into_tensor(self):
             tensor.view(torch.float32),
         )
 
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_unwaited(self) -> None:
+        # Verify that the process can terminate gracefully
+        # even with unwaited tensors
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: Run collectives under the context manager, and don't call wait on them.
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            input = torch.full(
+                (10240, 10240), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            # Non-functional collectives run under the context manager are registered in the work registry.
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            # Running another collective on the same tensor should still work.
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+        # Case 2: Run collectives outside the context manager, and don't call wait on them.
+        # NOTE: Here we intentionally test the memory-stressed case.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+        for _ in range(50000):
+            input = torch.full(
+                (1024, 1024), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+        # The work registry size is unchanged, since non-functional collectives not run under
+        # the context manager are not registered in the work registry.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_wait_tensor(self) -> None:
+        # Verify that c10d_functional.wait_tensor() can be invoked on the
+        # output tensor of a non-functional collective
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: under the context manager (i.e. work is registered in the registry)
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            torch.ops.c10d_functional.wait_tensor(input1)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+            input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            work.wait()
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            self.assertEqual(input1, input2)
+
+        # Case 2: not under the context manager (i.e. work is not registered in the registry)
+        input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        # This has no effect, since the underlying wait_tensor() logic cannot
+        # find the corresponding work object (it is not registered in the registry).
+        torch.ops.c10d_functional.wait_tensor(input1)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+        input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work.wait()
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        self.assertEqual(input1, input2)
+
 
 instantiate_parametrized_tests(ProcessGroupXCCLTest)
 
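Usage sketch (not part of the patch): the snippet below mirrors what the new tests exercise, showing why WorkXCCL::synchronize() must unregister its Work object. It assumes a two-rank "xccl" process group has already been initialized via torch.distributed.init_process_group and that one XPU device is visible per rank; it is a minimal illustration, not a canonical API reference.

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as _functional_collectives

    rank = dist.get_rank()

    # Inside the context manager, async non-functional collectives register their
    # Work objects in the c10d work registry so that wait_tensor() can find them.
    with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
        t = torch.ones(8, 8, device=f"xpu:{rank}")
        dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
        assert torch._C._distributed_c10d._get_work_registry_size() == 1

        # Waiting on the tensor blocks until the collective completes; with the
        # WorkXCCL::synchronize() change above, the Work entry is also unregistered,
        # so the registry drains back to zero instead of accumulating entries.
        torch.ops.c10d_functional.wait_tensor(t)
        assert torch._C._distributed_c10d._get_work_registry_size() == 0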