Hi all,

I am fairly new to JAX and am trying to use a version of `pmap` to perform simple parallelised training, splitting a batch across GPUs, with xarray datasets. However, when monitoring performance with the JAX profiler on multiple GPUs, I see that one GPU is significantly slower than the others, which causes the other GPUs to block and wait before the gradients are combined for the weight update (see the profiler screenshot below, where the purple `ncclDevKernel_AllReduce_Sum` blocks waiting for GPU0 to finish its computation).
I am passing dummy 'zero' data to the model, so each GPU should be receiving the same data and computing the same gradients (which I've confirmed by inspecting the average gradient computed on each GPU).
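For context, my training step has roughly the shape below. This is a minimal sketch, not my actual code: the model, `loss_fn`, and the learning rate are placeholder stand-ins. The part that matters is the `pmap` over the device axis and the gradient all-reduce, which is where the `ncclDevKernel_AllReduce_Sum` kernel shows up and where the fast GPUs end up waiting on the slow one.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Placeholder linear model + MSE loss; the real model is more involved.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Combine gradients across GPUs; this all-reduce is the purple
    # ncclDevKernel_AllReduce_Sum in the trace, and it is where the
    # fast GPUs sit waiting for the slow one.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD update with a placeholder learning rate.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params, loss


n_dev = jax.local_device_count()
params = jax.device_put_replicated(
    {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}, jax.local_devices()
)
# Dummy 'zero' batch, with the leading axis mapping one shard per GPU.
x = jnp.zeros((n_dev, 32, 8))
y = jnp.zeros((n_dev, 32, 1))
params, loss = train_step(params, x, y)
```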
Some things I've observed:
- I've profiled the GPUs individually (running on a single GPU without `pmap`), and single-GPU speed is always at least as fast as the slowest GPU I see in multi-GPU training (some GPUs appear faster than others when used individually, but the 'slow' GPU in different multi-GPU runs always takes roughly the same amount of time).
- I've tried varying the number of GPUs (2, 3, 4) and there always seems to be a 'slow' GPU. Before selecting GPUs I set `os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"`, so the ordering of the GPUs should be fixed (setup sketched after this list). Even so, a GPU that is 'fast' in one run can become the 'slow' GPU in another (e.g. with GPUs=[1, 2, 3], GPU 1 is slow while 2 and 3 are fast; with GPUs=[2, 3], GPU 2 is slow and 3 is fast).
- The slow GPU always shows what appears to be a JAX compilation sequence before model training starts, visible in the XLA ops row screenshotted below. This does not happen on the other GPUs (this may be expected behaviour, but I'm still learning how to read the JAX profiler's trace viewer).
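For completeness, the device setup and trace capture look roughly like this. It's a sketch rather than my real script: the `CUDA_VISIBLE_DEVICES` selection, the trace directory, and the pmapped computation under the trace are illustrative stand-ins for however the GPUs are actually chosen and for the real training loop.

```python
import os

# Must be set before JAX initialises the GPU backend for the ordering to take effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Restrict the run to a subset of GPUs, e.g. GPUs 2 and 3
# (illustrative; the exact selection mechanism isn't the point here).
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import jax
import jax.numpy as jnp

print(jax.local_devices())  # should list only the selected GPUs, in PCI bus order

# Capture a trace for TensorBoard's profiler plugin; the directory and the
# pmapped computation here are just stand-ins for the real training loop.
with jax.profiler.trace("/tmp/jax-trace"):
    out = jax.pmap(lambda a: (a * 2.0).sum())(
        jnp.zeros((jax.local_device_count(), 1024))
    )
    out.block_until_ready()  # make sure the device work lands inside the trace
```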
I am quite confused as to why this speed difference exists and can't find anything in the JAX documentation suggesting it is normal behaviour. Any help would be much appreciated.