Skip to content

Commit 5016caf

Browse files
patnotz (Google-ML-Automation)
authored and committed
Add a Flax NNX layer and supporting code
PiperOrigin-RevId: 820445685
1 parent 2121dbf commit 5016caf

File tree

12 files changed

+1340
-36
lines changed

12 files changed

+1340
-36
lines changed

jax_tpu_embedding/sparsecore/examples/models/shakespeare/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,17 @@ pytype_strict_library(
5757
pypi_requirement("jax"),
5858
],
5959
)
60+
61+
# Flax NNX variant of the Shakespeare example model.
pytype_strict_library(
    name = "flax_nnx_model",
    srcs = [
        "flax_nnx_model.py",
    ],
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/flax/nnx:embed",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
        pypi_requirement("flax/nnx"),
        pypi_requirement("jax"),
    ],
)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Shakespeare model using embedding layer."""
15+
16+
from flax import nnx
17+
import jax
18+
import jax.numpy as jnp
19+
from jax_tpu_embedding.sparsecore.lib.flax.nnx import embed
20+
from jax_tpu_embedding.sparsecore.lib.nn import embedding
21+
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
22+
23+
# Alias for the nested-container type used by the embedding library for
# feature-spec structures.
Nested = embedding.Nested
24+
25+
26+
################################################################################
# Define the model.
################################################################################
class Model(nnx.Module):
  """Shakespeare model using a SparseCore embedding layer.

  Looks up token embeddings for a single input feature via
  `embed.SparseCoreEmbed`, flattens the per-example activations, and applies
  two dense layers to produce logits over the vocabulary.
  """

  def __init__(
      self,
      *,
      feature_specs: Nested[embedding_spec.FeatureSpec],
      global_batch_size: int,
      vocab_size: int,
      seq_len: int,
      embedding_size: int,
      enable_minibatching: bool,
      mesh: jax.sharding.Mesh,
      sharding_axis: str,
      rng_seed: int = 42,
  ):
    """Initializes the model.

    Args:
      feature_specs: Nested feature specs; must contain exactly one feature
        named 'shakespeare_feature'.
      global_batch_size: Global (across all replicas) batch size.
      vocab_size: Size of the output vocabulary.
      seq_len: Number of tokens per example.
      embedding_size: Dimension of each token embedding.
      enable_minibatching: Whether the embedding layer uses minibatching.
      mesh: Device mesh used for sharding.
      sharding_axis: Name of the mesh axis to shard activations over.
      rng_seed: Seed for parameter initialization. Defaults to the previously
        hard-coded value for backward compatibility.
    """
    self.feature_name = 'shakespeare_feature'
    assert len(feature_specs) == 1, 'Shakespeare model expects one feature.'
    assert self.feature_name in feature_specs, (
        'Shakespeare model expects feature named "%s".' % self.feature_name
    )

    self.feature_specs = feature_specs
    self.global_batch_size = global_batch_size
    self.vocab_size = vocab_size
    self.seq_len = seq_len
    self.embedding_size = embedding_size
    self.enable_minibatching = enable_minibatching
    self.mesh = mesh
    self.sharding_axis = sharding_axis
    rngs = nnx.Rngs(params=rng_seed)
    self.embedding_layer = embed.SparseCoreEmbed(
        feature_specs=self.feature_specs,
        mesh=self.mesh,
        sharding_axis=self.sharding_axis,
        rngs=rngs,
        enable_minibatching=enable_minibatching,
    )
    e = self.embedding_size
    v = self.vocab_size
    s = self.seq_len
    # The first dense layer consumes the flattened (seq_len * embedding_size)
    # activations; the second projects to vocabulary logits.
    self.dense_layer_1 = nnx.Linear(
        in_features=s * e,
        out_features=e,
        rngs=rngs,
    )
    self.dense_layer_2 = nnx.Linear(
        in_features=e,
        out_features=v,
        rngs=rngs,
    )

  def add_sharding_constraint(
      self, x: jax.Array, names: tuple[str | None, ...]
  ) -> jax.Array:
    """Adds a sharding constraint to the array.

    Adds a sharding constraint to ensure the sharding information is not lost
    during compilation. This may not be necessary but it helps SPMD and
    ensures that the sharding information is as expected.

    Args:
      x: The array to add the sharding constraint to.
      names: The mesh axes for the partition spec. Note: annotated as a
        variadic tuple (`tuple[str | None, ...]`); the original
        `tuple[str | None]` denoted a one-element tuple only.

    Returns:
      The array with the sharding constraint added.
    """
    return jax.lax.with_sharding_constraint(
        x,
        jax.sharding.NamedSharding(
            self.mesh, jax.sharding.PartitionSpec(*names)
        ),
    )

  def __call__(self, embedding_lookup_inputs: embedding.PreprocessedInput):
    """Runs the forward pass and returns logits of shape (batch, vocab)."""
    # Run the embedding layer.
    x = self.embedding_layer(embedding_lookup_inputs)

    # Unpack the activations for the single feature and flatten per example.
    x = x[self.feature_name]
    x = jnp.reshape(x, (self.global_batch_size, -1))
    x = self.add_sharding_constraint(x, (self.sharding_axis,))

    # Apply the dense portion of the model.
    x = self.dense_layer_1(x)
    x = self.add_sharding_constraint(x, (self.sharding_axis,))
    x = self.dense_layer_2(x)
    x = self.add_sharding_constraint(x, (self.sharding_axis,))

    return x

jax_tpu_embedding/sparsecore/lib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pytype_strict_library(
2727
"//jax_tpu_embedding/sparsecore/lib/core", # buildcleaner: keep
2828
"//jax_tpu_embedding/sparsecore/lib/fdo", # buildcleaner: keep
2929
"//jax_tpu_embedding/sparsecore/lib/flax", # buildcleaner: keep
30+
"//jax_tpu_embedding/sparsecore/lib/flax/nnx", # buildcleaner: keep
3031
"//jax_tpu_embedding/sparsecore/lib/nn", # buildcleaner: keep
3132
"//jax_tpu_embedding/sparsecore/lib/proto", # buildcleaner: keep
3233
],
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "EXTERNAL_USERS")
15+
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
16+
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_library")
17+
18+
package(
    default_applicable_licenses = ["//:license"],
    default_visibility = EXTERNAL_USERS,
)

# Flax NNX SparseCore embedding layer.
pytype_strict_library(
    name = "embed",
    srcs = [
        "embed.py",
    ],
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding",
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
        "//jax_tpu_embedding/sparsecore/utils",
        pypi_requirement("flax/nnx"),
        pypi_requirement("jax"),
        pypi_requirement("optax"),
    ],
)

# Library target.
pytype_strict_library(
    name = "nnx",
    srcs = ["__init__.py"],
    visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"],
    deps = [
        ":embed",  # buildcleaner: keep
    ],
)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2024 The JAX SC Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# Empty file needed by setuptools.find_packages to recognize this as a package.

0 commit comments

Comments
 (0)