Skip to content

Conversation

susanbao
Copy link
Collaborator

@susanbao susanbao commented Oct 9, 2025

The newer version of JAX 0.7.2 and Flax 0.12.0 now strictly requires a mesh to be defined whenever you initialize parameters with sharding rules, even in a single-device unit test environment. Our unit tests failed for this.

For Flax team, the issue is due to this change: https://github.com/google/flax/blob/main/docs_nnx/flip/4844-var-eager-sharding.md

Simplify the creation of sharded NNX models. When a sharding annotation is provided, all nnx.Variable creation will require a mesh context and automatically be sharded as annotated.

It can be disabled by using flax.config.update('flax_always_shard_variable', False)

Copy link

github-actions bot commented Oct 9, 2025

@susanbao susanbao changed the title Flax Fix Unit test failure for JAX/Flax version update Oct 9, 2025
@susanbao susanbao merged commit 972b4ff into main Oct 9, 2025
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants