Skip to content

Commit 48586d3

Browse files
RecML authors (recml authors)
authored and committed
[Efficient LMs] Add compressor and decompressor for the research work of go/context_compression. This commit adds functions for compressing and decompressing input and output tensors.
PiperOrigin-RevId: 781578682
1 parent 847628b commit 48586d3

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -125,9 +125,9 @@ def _apply_mask(
125125
masks = []
126126
if mask_ref is not None:
127127
if k_in_lanes:
128-
mask = pl.load(mask_ref, (slice(None), k_slice))
128+
mask = mask_ref[:, k_slice]
129129
else:
130-
mask = pl.load(mask_ref, (k_slice, slice(None)))
130+
mask = mask_ref[k_slice, :]
131131

132132
snm = jnp.where(should_not_mask, 1, 0)
133133
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
156156
k_sequence = k_offset + jax.lax.broadcasted_iota(
157157
jnp.int32, (k_slice.size, bq), 0
158158
)
159-
q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq]
159+
q_sequence = q_sequence_ref[:1, :] # [1, bq]
160160
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
161161

162162
assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(
170170

171171
if q_segment_ids_ref is not None:
172172
if k_in_lanes:
173-
kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice]
173+
kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice]
174174
repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
175175
if rem:
176176
raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
181181
if rem:
182182
raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
183183
kv_ids = pltpu.repeat(
184-
pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
184+
kv_segment_ids_ref[k_slice, :], repeats, axis=1
185185
) # [k_slice, bq]
186-
q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq]
186+
q_ids = q_segment_ids_ref[:1, :] # [1, bq]
187187
masks.append(q_ids == kv_ids)
188188

189189
if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
228228
slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
229229

230230
q = q_ref[...]
231-
k = pl.load(k_ref, (slice_k, slice(None)))
231+
k = k_ref[slice_k, :]
232232
qk = jax.lax.dot_general(
233233
q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
234234
)
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
256256
)
257257

258258
sv_dims = NN_DIM_NUMBERS
259-
v = pl.load(v_ref, (slice_k, slice(None)))
259+
v = v_ref[slice_k, :]
260260

261261
to_float32 = lambda x: x.astype(jnp.float32)
262262
v = to_float32(v)

recml/core/training/jax_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -699,14 +699,16 @@ def train(self, task: JaxTask) -> core.Logs:
699699
)
700700
metrics[core.TRAIN_LOG_DIRNAME] = train_metrics
701701

702+
if jax.process_index() == 0:
703+
task.export_model(state, self._model_dir)
704+
702705
self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
703706
step = curr_step + 1
704707

705708
self.checkpoint_manager.wait_until_finished()
706709

707710
if jax.process_index() == 0:
708711
self._write_marker_file()
709-
task.export_model(state, self._model_dir)
710712

711713
self.checkpoint_manager.close()
712714
del self.checkpoint_manager

recml/core/training/partitioning.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
107107
def partition_init(
108108
self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
109109
) -> CreateStateFn:
110-
with jax.sharding.use_mesh(self.mesh):
110+
with jax.set_mesh(self.mesh):
111111
if abstract_batch is not None:
112112
abstract_state = jax.eval_shape(init_fn, abstract_batch)
113113
specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
117117
init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
118118

119119
def _wrapped_init(batch: PyTree) -> State:
120-
with jax.sharding.use_mesh(self.mesh):
120+
with jax.set_mesh(self.mesh):
121121
state = init_fn(batch)
122122
state = _maybe_unbox_state(state)
123123
return state
@@ -130,15 +130,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
130130
jit_kws["out_shardings"] = (self.state_sharding, None)
131131
jit_kws["donate_argnums"] = (1,)
132132

133-
with jax.sharding.use_mesh(self.mesh):
133+
with jax.set_mesh(self.mesh):
134134
step_fn = jax.jit(
135135
fn,
136136
in_shardings=(self.data_sharding, self.state_sharding),
137137
**jit_kws,
138138
)
139139

140140
def _wrapped_step(batch: PyTree, state: State) -> Any:
141-
with jax.sharding.use_mesh(self.mesh):
141+
with jax.set_mesh(self.mesh):
142142
return step_fn(batch, state)
143143

144144
return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
217217
def mesh_context_manager(
218218
self,
219219
) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
220-
return jax.sharding.use_mesh
220+
return jax.set_mesh
221221

222222
def shard_inputs(self, inputs: PyTree) -> PyTree:
223223
def _shard(x: np.ndarray) -> jax.Array:

recml/layers/linen/sparsecore.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -362,7 +362,7 @@ class SparsecoreEmbed(nn.Module):
362362
Attributes:
363363
sparsecore_config: A sparsecore config specifying how to create the tables.
364364
mesh: The mesh to use for the embedding layer. If not provided, the global
365-
mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
365+
mesh set by `jax.set_mesh` will be used. If neither is set, an
366366
error will be raised.
367367
"""
368368

@@ -375,7 +375,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
375375
abstract_mesh = jax.sharding.get_abstract_mesh()
376376
if not abstract_mesh.shape_tuple:
377377
raise ValueError(
378-
'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
378+
'No abstract mesh shape was set with `jax.set_mesh`. Make'
379379
' sure to set the mesh when calling the sparsecore module.'
380380
)
381381
return abstract_mesh

0 commit comments

Comments
 (0)