@@ -12,6 +12,7 @@
 
 import torch
 import torch.distributed as c10d
+import torch.distributed._functional_collectives as _functional_collectives
 
 if not c10d.is_available() or not c10d.is_xccl_available():
     print("c10d XCCL not available, skipping tests", file=sys.stderr)
@@ -626,6 +627,86 @@ def test_all_gather_into_tensor(self):
             tensor.view(torch.float32),
         )
 
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_unwaited(self) -> None:
+        # Verify that the process can terminate gracefully
+        # even with unwaited tensors
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: Run collectives under the context manager and don't call wait() on them.
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            input = torch.full(
+                (10240, 10240), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            # Non-functional collectives run under the context manager are
+            # registered in the work registry.
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            # Running another collective on the same tensor should still work.
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+        # Case 2: Run collectives outside the context manager and don't call wait() on them.
+        # NOTE: Here we intentionally test a memory-stressed case.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+        for _ in range(50000):
+            input = torch.full(
+                (1024, 1024), float(self.rank), device=f"xpu:{self.rank}"
+            )
+            dist.all_reduce(input, op=dist.ReduceOp.SUM, async_op=True)
+        # The work registry size is unchanged, since non-functional collectives
+        # not run under the context manager are not registered in the work registry.
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 2)
+
+    @requires_xccl()
+    @skip_if_lt_x_gpu(2)
+    def test_wait_tensor(self) -> None:
+        # Verify that c10d_functional.wait_tensor() can be invoked on the
+        # output tensor of a non-functional collective
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        # Case 1: under the context manager (i.e. work is registered in the registry)
+        with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
+            input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            torch.ops.c10d_functional.wait_tensor(input1)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+            input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 1)
+            work.wait()
+            self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+            self.assertEqual(input1, input2)
+
+        # Case 2: not under the context manager (i.e. work is not registered in the registry)
+        input1 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        dist.all_reduce(input1, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        # This call has no effect: the underlying wait_tensor() logic cannot find
+        # the corresponding work object, because it was never registered in the registry.
+        torch.ops.c10d_functional.wait_tensor(input1)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+
+        input2 = torch.full((10, 10), float(self.rank), device=f"xpu:{self.rank}")
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work = dist.all_reduce(input2, op=dist.ReduceOp.SUM, async_op=True)
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        work.wait()
+        self.assertEqual(torch._C._distributed_c10d._get_work_registry_size(), 0)
+        self.assertEqual(input1, input2)
+
 
 instantiate_parametrized_tests(ProcessGroupXCCLTest)
 
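For context, the pattern these new tests exercise looks roughly like the following in application code. This is a minimal sketch, not part of the change: it assumes a two-rank launch (e.g. via torchrun) on a machine with XPU devices and the xccl backend available, with rank and world size supplied by the launcher's environment variables.

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as _functional_collectives

# Minimal sketch; assumes e.g. `torchrun --nproc-per-node=2 example.py`
# on a host with 2 XPUs, so rank/world size come from the launcher's env.
dist.init_process_group(backend="xccl")
rank = dist.get_rank()
torch.xpu.set_device(rank)

with _functional_collectives.allow_inflight_collective_as_graph_input_ctx():
    t = torch.full((10, 10), float(rank), device=f"xpu:{rank}")
    # Under the context manager, the async collective's work handle is kept
    # in the work registry instead of requiring an explicit work.wait().
    dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    # wait_tensor() looks the work up by its output tensor and completes it;
    # this is what lets in-flight collectives be consumed as graph inputs.
    torch.ops.c10d_functional.wait_tensor(t)

dist.destroy_process_group()

Outside the context manager the work handle is never registered, so wait_tensor() becomes a no-op, which is exactly what Case 2 of test_wait_tensor checks.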