Conversation

@Hoomaaan (Contributor) commented Sep 26, 2025

Related AWS Neuron ticket: https://t.corp.amazon.com/V1941917988/overview

broadcast was passing group-local ranks directly to xm.collective_broadcast(), which expects global ranks, causing data corruption in single-member process groups.
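
For context, here is a minimal sketch of the rank translation the fix needs. It is illustrative only, not the actual torch_xla backend code, and it omits replica-group handling; the point is that a group-local source rank has to be mapped back to its global rank (e.g. via torch.distributed.get_global_rank()) before it reaches xm.collective_broadcast().

import torch.distributed as dist
import torch_xla.core.xla_model as xm

def broadcast_from_group_rank(tensor, group_src, group):
    # Hypothetical helper, not the actual fix: xm.collective_broadcast()
    # treats its root as a global ordinal, so translate the group-local
    # source rank first.
    global_src = dist.get_global_rank(group, group_src)
    xm.collective_broadcast([tensor], root_ordinal=global_src)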

TEST:

import os
import torch
import torch.distributed as dist
import torch_xla as xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr

def main():
    dist.init_process_group(backend="xla")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank creates a single-member process group containing only itself.
    tp = dist.new_group(ranks=[rank])
    tp_rank = dist.get_rank(group=tp)
    tp_size = dist.get_world_size(group=tp)

    print(
        f">>>> pid={os.getpid()}, rank={rank}\n"
        f">>> world_size={world_size}\n"
        f">>> tp_rank={tp_rank}, tp_size={tp_size}, tp_members={dist.get_process_group_ranks(tp)}"
    )

    do_train, do_valid, do_test = 0.1, 0.2, 0.3
    # The flags tensor lives on the XLA device; broadcast it within the
    # single-member group, passing the global rank as src.
    flags = torch.tensor([do_train, do_valid, do_test], dtype=torch.float32, device='xla')
    dist.broadcast(flags, rank, group=tp)

    print(f">>>> pid={os.getpid()}, rank={rank}\n"
          f">>> do_train={flags[0].item()}, do_valid={flags[1].item()}, do_test={flags[2].item()}\n"
          f">>> global_ordinal={xr.global_ordinal()}")

if __name__ == "__main__":
    main()

Results after this fix:

torchrun --nproc-per-node=2 --nnodes=1 ./bug.py
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] 
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0926 18:50:41.903000 1081605 torch/distributed/run.py:766] *****************************************
>>>> pid=1081679, rank=0
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[0]
>>>> pid=1081680, rank=1
>>> world_size=2
>>> tp_rank=0, tp_size=1, tp_members=[1]
.
.
.
2.19.8089.0+8ab9f450/MODULE_10344927339446294134+e30acd3a/model.neff
>>>> pid=1081680, rank=1
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896
>>> global_ordinal=1
>>>> pid=1081679, rank=0
>>> do_train=0.10000000149011612, do_valid=0.20000000298023224, do_test=0.30000001192092896

Now both ranks have the correct values. Previously, rank 1's tensor was all zeros.
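
For reference, the mismatch the fix addresses is easy to see in the test itself. The two lines below are a hypothetical addition (e.g. dropped into main() right after tp is created): in a single-member group the group-local rank is always 0, while the global rank of its only member is unchanged, and the latter is what xm.collective_broadcast() needs.

    # On the process whose global rank is 1 (tp_members=[1] above):
    print(dist.get_rank(group=tp))       # 0  (group-local rank)
    print(dist.get_global_rank(tp, 0))   # 1  (global rank of the group's only member)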

@jeffhataws requested review from bhavya01 and pgmoka and removed the request for bhavya01 on September 29, 2025 at 16:29
@ysiraichi (Collaborator) commented:

@bhavya01 I'm not completely familiar with the distributed codebase. Could you take a look at it?
