87 changes: 87 additions & 0 deletions test/test_torch_distributed_xla_backend.py
@@ -44,6 +44,18 @@ def patch_world(rank, size):
    yield


@contextlib.contextmanager
def patch_world_with_xla_runtime(rank, size):
  assert isinstance(dist.group.WORLD,
                    torch_xla.distributed.xla_backend.ProcessGroupXla)
  with mock.patch.object(
      dist.group.WORLD, 'rank', return_value=rank), mock.patch.object(
          dist.group.WORLD, 'size', return_value=size), mock.patch.object(
              xr, 'global_ordinal', return_value=rank), mock.patch.object(
                  xr, 'world_size', return_value=size):
    yield
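
A note on the helper above (not part of the diff): functions wrapped in @contextlib.contextmanager return ContextDecorator instances, which is why the new tests can apply patch_world_with_xla_runtime as a decorator instead of a with-block. A minimal standalone sketch with hypothetical names:

import contextlib

@contextlib.contextmanager
def fake_world(rank, size):
  # Install patches here; the decorated function body runs at the yield.
  print(f'patched rank={rank} size={size}')
  yield

@fake_world(rank=0, size=2)  # same shape as the decorators used below
def run_test():
  pass  # executes with the patches active

run_test()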


class XlaBackendTest(parameterized.TestCase):

  @classmethod

@@ -328,6 +340,81 @@ def test_unimplemented_op(self, op):
    with self.assertRaises(NotImplementedError):
      getattr(pg_xla, op)(tensor)

  @patch_world_with_xla_runtime(rank=0, size=2)
  def test_broadcast_single_rank_group_rank0(self):
    """Test broadcast in single-member process group for rank 0"""
    device = torch_xla.device()

    with new_group_barrier_disabled():
      tp = dist.new_group(ranks=[0])

    # Create flags tensor with initial values (simulating rank 0's values)
    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)

    # Broadcast within the single-member group (should be a no-op but
    # shouldn't crash)
    dist.broadcast(flags, src=0, group=tp)

    # Values should remain unchanged since it's a single-member group
    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)

    # Verify the process group properties
    self.assertEqual(dist.get_rank(group=tp), 0)
    self.assertEqual(dist.get_world_size(group=tp), 1)

  @patch_world_with_xla_runtime(rank=1, size=2)
  def test_broadcast_single_rank_group_rank1(self):
    """Test broadcast in single-member process group for rank 1"""
    device = torch_xla.device()

    with new_group_barrier_disabled():
      tp = dist.new_group(ranks=[1])

    # Create flags tensor with initial values (simulating rank 1's values)
    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)

    # Broadcast within the single-member group (should be a no-op but
    # shouldn't crash)
    dist.broadcast(flags, src=1, group=tp)

    # Values should remain unchanged since it's a single-member group
    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)

    # Verify the process group properties
    self.assertEqual(dist.get_rank(group=tp),
                     0)  # Local rank in single-member group is 0
    self.assertEqual(dist.get_world_size(group=tp), 1)
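
The assertions above lean on the local/global rank distinction this PR fixes: dist.get_rank(group=...) reports the group-local rank, and c10d translates the global src=1 passed to dist.broadcast into the group-local rootRank=0 before the backend sees it. A minimal sketch of the mapping, assuming an already-initialized two-process job in which this process is global rank 1:

import torch.distributed as dist

tp = dist.new_group(ranks=[1])           # single-member group
assert dist.get_rank(group=tp) == 0      # group-local rank is always 0 here
assert dist.get_global_rank(tp, 0) == 1  # maps back to the global rank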

  @patch_world_with_xla_runtime(rank=0, size=2)
  def test_broadcast_global_rank_conversion_single_member(self):
    """Test that global rank conversion works correctly for single-member groups"""
    device = torch_xla.device()

    # Create single-member group for rank 0
    with new_group_barrier_disabled():
      tp = dist.new_group(ranks=[0])

    flags = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32, device=device)

    # The group should be a ProcessGroupXla instance we can call directly
    self.assertIsInstance(tp, torch_xla.distributed.xla_backend.ProcessGroupXla)

    # Build broadcast options - local rank 0 should map to global rank 0
    opts = dist.BroadcastOptions()
    opts.rootRank = 0
    opts.rootTensor = 0

    # This should succeed now that broadcast resolves the global source rank
    work = tp.broadcast([flags], opts)
    self.assertIsNotNone(work)

    # Values should be preserved
    self.assertAlmostEqual(flags[0].item(), 0.1, places=5)
    self.assertAlmostEqual(flags[1].item(), 0.2, places=5)
    self.assertAlmostEqual(flags[2].item(), 0.3, places=5)


if __name__ == '__main__':
  if xr.device_type() != 'CPU':
7 changes: 6 additions & 1 deletion torch_xla/distributed/xla_backend.py
@@ -131,9 +131,14 @@ def allgather_coalesced(self, output_tensors_list, input_tensors, opts=None):
   # Call site:
   # https://github.com/pytorch/pytorch/blob/release/1.10/torch/distributed/distributed_c10d.py#L1129
   def broadcast(self, tensors, opts):
+    import torch.distributed as dist
+
     root_tensor = tensors[opts.rootTensor]
+    # Convert group local rank to global rank for xla collectives
+    group_source = opts.rootRank
+    global_src = dist.get_global_rank(self, group_source)
     xm.collective_broadcast([root_tensor],
-                            opts.rootRank,
+                            global_src,
                             groups=self._mesh,
                             pin_layout=False)
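
To see why the conversion matters (an illustrative sketch, not from the PR): c10d hands the backend a group-local rootRank, while xm.collective_broadcast addresses devices by global ordinal. The hypothetical helper below isolates the mapping the new code performs:

import torch.distributed as dist

def resolve_broadcast_source(group, root_rank):
  # root_rank arrives group-local in BroadcastOptions.rootRank.
  return dist.get_global_rank(group, root_rank)

# With tp = dist.new_group(ranks=[1]) on a two-process job:
#   resolve_broadcast_source(tp, 0) -> 1  (the correct global source)
# The pre-fix code passed the raw rootRank straight through, so the
# broadcast used the wrong source whenever local and global ranks differ.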
