Commit 70c8846

add unregister wait_tensor (#2019)
Refers to #2019: support unregistering the work object when allow_inflight_collective_as_graph_input is enabled.
1 parent a5671d2 commit 70c8846
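
For orientation, here is a minimal sketch (not part of this commit) of the flow the change supports, assuming an already-initialized two-rank "xccl" process group with one XPU device per rank: a non-functional collective launched with async_op=True under allow_inflight_collective_as_graph_input_ctx() registers its Work object in the c10d work registry, and waiting on the output tensor via torch.ops.c10d_functional.wait_tensor() removes it again. Function and variable names below are illustrative only.

    # Sketch only; assumes dist.init_process_group(backend="xccl", ...) has already
    # run on each rank and `rank` is this process's rank.
    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as _functional_collectives

    def waited_collective_sketch(rank: int) -> None:
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            t = torch.ones(8, 8, device=f"xpu:{rank}")
            # Launching the async collective registers its Work object in the registry.
            dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
            assert torch._C._distributed_c10d._get_work_registry_size() == 1
            # Waiting on the output tensor unregisters the Work again.
            torch.ops.c10d_functional.wait_tensor(t)
            assert torch._C._distributed_c10d._get_work_registry_size() == 0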

File tree: 2 files changed (+86, -0 lines)


src/xccl/ProcessGroupXCCL.cpp

Lines changed: 5 additions & 0 deletions
@@ -283,6 +283,11 @@ bool ProcessGroupXCCL::WorkXCCL::isCompleted() {
 
 void ProcessGroupXCCL::WorkXCCL::synchronize() {
   synchronizeStream();
+  if (c10d::allow_inflight_collective_as_graph_input()) {
+    c10d::unregister_work(
+        c10::intrusive_ptr<
+            ProcessGroupXCCL::WorkXCCL>::unsafe_reclaim_from_nonowning(this));
+  }
 }
 
 void ProcessGroupXCCL::WorkXCCL::synchronizeStream() {
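
The hunk above is the core of the change: when c10d::allow_inflight_collective_as_graph_input() returns true, WorkXCCL::synchronize() now also unregisters this Work from the c10d work registry (reclaiming a non-owning intrusive_ptr to itself), so an entry created when the collective was launched is released whether the caller waits through the Work handle or through wait_tensor(). A minimal Python-side sketch of the Work-handle path, again assuming an initialized two-rank "xccl" group; names here are illustrative only:

    # Sketch only; assumes dist.init_process_group(backend="xccl", ...) already ran
    # on each rank and `rank` is this process's rank.
    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as _functional_collectives

    def wait_via_work_handle(rank: int) -> None:
        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
            t = torch.full((4, 4), float(rank), device=f"xpu:{rank}")
            work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
            # Launching the async collective under the context manager registers its Work.
            assert torch._C._distributed_c10d._get_work_registry_size() == 1
            # work.wait() reaches WorkXCCL::synchronize(), which now unregisters the Work.
            work.wait()
            assert torch._C._distributed_c10d._get_work_registry_size() == 0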

test/xpu/distributed/test_c10d_xccl.py

Lines changed: 81 additions & 0 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.distributed as c10d
+import torch.distributed._functional_collectives as _functional_collectives
 
 if not c10d.is_available() or not c10d.is_xccl_available():
     print("c10d XCCL not available, skipping tests", file=sys.stderr)
@@ -626,6 +627,86 @@ def test_all_gather_into_tensor(self):
                 tensor.view(torch.float32),
             )
 
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_unwaited(self) -> None:
+        # Verify that the process can terminate gracefully
+        # even with unwaited tensors
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: Run collectives under context manager, and don't call wait on them.
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            input = torch.full(
+                (10240, 10240), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            # Non-functional collectives run under the context manager is registered in the work registry.
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            # Running another collective on the same tensor should still work
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+        # Case 2: Run collectives not under context manager, and don't call wait on them.
+        # NOTE: Here we intentionally test memory-stressed case.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+        for _ in range(50000):
+            input = torch.full(
+                (1024, 1024), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+        # Work registry size is unchanged, since non-functional collectives not run under
+        # the context manager is not registered in the work registry.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_wait_tensor(self) -> None:
+        # Verify that c10d_functional.wait_tensor() can be invoked on
+        # output tensor of non-functional collective
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: under context manager (i.e. work is registered in registry)
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            torch.ops.c10d_functional.wait_tensor(input1)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+            input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            work.wait()
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            self.assertEqual(input1, input2)
+
+        # Case 2: not under context manager (i.e. work is not registered in registry)
+        input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        # this does not take effect, since the underlying wait_tensor() logic would not
+        # be able to find the corresponding work object (because it's not registered in registry)
+        torch.ops.c10d_functional.wait_tensor(input1)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+        input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work.wait()
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        self.assertEqual(input1, input2)
+
 
 instantiate_parametrized_tests(ProcessGroupXCCLTest)

0 commit comments
