Is it normal that different CPUs output different deterministic results for the same code? #29008
-
Hi, I have a program that uses JAX, Flax and Optax. I ran this code on CPU on three different machines: two macOS systems (one on Sequoia with an M1 Pro, the other on Sonoma with an M2) and one Linux system. All three systems produce different results for the same computation, but each system produces its own result deterministically. Is there any solution to this? Is this a bug in JAX, or is it related to XLA? I would be grateful for any solutions.
-
It's hard to give a concrete answer without a minimal reproducible example, but it's not unexpected for different platforms to exhibit deterministically different numerics. These differences can even add up to large errors in absolute terms when they accumulate. Without more details it's not obvious that this is a bug, but it could be! If you can narrow it down and produce a minimal pure-JAX example of these errors, I'd be happy to help diagnose.
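A minimal pure-Python sketch of how "deterministic but different" numerics can arise. Floating-point addition is not associative, so two libraries that evaluate the same sum in a different order each get a repeatable answer, just not the same one. This uses no JAX, and the orderings below are illustrative assumptions, not how any particular BLAS/LAPACK or XLA backend actually schedules its operations; `sequential_sum` is a hypothetical helper written for this example.

```python
def sequential_sum(xs):
    """Add values strictly left to right: ((x0 + x1) + x2) + ...

    Written as an explicit loop on purpose: newer Python versions use
    compensated summation inside the built-in sum().
    """
    total = 0.0
    for x in xs:
        total += x
    return total

# Same three numbers, two summation orders.
values = [0.1, 0.2, 0.3]
forward = sequential_sum(values)         # (0.1 + 0.2) + 0.3
backward = sequential_sum(values[::-1])  # (0.3 + 0.2) + 0.1
print(forward, backward)                 # 0.6000000000000001 0.6

# Each order is deterministic: rerunning it gives a bit-identical result.
assert sequential_sum(values) == forward

# The gap can grow with the amount of work. One large value plus a million
# tiny ones: adding each tiny value to 1.0 individually loses it to rounding,
# while summing the tiny values first preserves their total.
n = 10**6
big_first = sequential_sum([1.0] + [1e-16] * n)
small_first = sequential_sum([1e-16] * n + [1.0])
print(big_first, small_first)
print(f"difference: {small_first - big_first:.1e}")  # roughly 1e-10
```

Neither ordering is "wrong"; they are simply different, equally valid roundings of the same mathematical expression, which is why a bitwise match across platforms is not guaranteed.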
-
Hello, thanks for your reply! Here is as narrow as I could make it.
This outputs -1.9979573829398634 on a macOS system and -1.9979573808129485 on a Linux system, differing in the last 8 digits. I suppose that in a much larger, more complicated system this difference could become quite large. Also note that the difference between some pairs of devices can be larger than between others.
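A small sketch of what the two reported values mean quantitatively: they disagree at a relative level of about 1e-9, so they fail an exact equality check but pass a tolerance-based one. The `rel_tol` thresholds here are illustrative assumptions; an appropriate tolerance depends on the computation and its dtype.

```python
import math

# The two values reported above for the same program.
macos_result = -1.9979573829398634
linux_result = -1.9979573808129485

abs_diff = abs(macos_result - linux_result)
rel_diff = abs_diff / max(abs(macos_result), abs(linux_result))

print(f"absolute difference: {abs_diff:.3e}")
print(f"relative difference: {rel_diff:.3e}")

# Bitwise equality fails, but a relative-tolerance comparison passes.
print(macos_result == linux_result)                             # False
print(math.isclose(macos_result, linux_result, rel_tol=1e-8))   # True
print(math.isclose(macos_result, linux_result, rel_tol=1e-10))  # False
```

Comparing cross-platform outputs this way separates expected rounding-level drift from differences large enough to suggest an actual bug.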
This is unrelated to `jax.jit`: you're getting different results on different platforms (both with and without JIT) because the different platforms use different LAPACK libraries.