56 changes: 56 additions & 0 deletions requirements_with_jax_stable_stack_0.6.0.txt
@@ -0,0 +1,56 @@
absl_py==2.2.2
aqtp==0.9.0
#benchmark_db_writer==1.0.0.dev20250610
#benchmark_db_writer.egg-info
cloud_accelerator_diagnostics==0.1.1
cloud_tpu_diagnostics==0.1.5
chex==0.1.90
datasets==3.6.0
etils==1.12.2
evaluate==0.4.4
flax==0.10.7
grain==0.2.10
grpcio==1.72.0rc1
huggingface_hub==0.33.0
#jax==0.6.0
#jaxlib==0.6.0 # Manually adding to ensure consistency in future
jaxtyping==0.3.2
#jetstream==0.1.0
jsonlines==4.0.0
#libtpu==0.0.15 # Manually adding to ensure consistency in future
matplotlib==3.10.3
ml_collections==1.1.0
ml_dtypes==0.5.1
ml_goodput_measurement==0.0.11
nltk==3.9.1
numpy
omegaconf==2.3.0
optax==0.2.5
orbax-checkpoint
pandas==2.3.0
pathwaysutils@git+https://github.com/AI-Hypercomputer/pathways-utils.git@d6fffd3f9bd5e06c323f5bade04b40fa741c728f
Pillow==11.2.1
protobuf
psutil==7.0.0
pytest==8.4.1
PyYAML==6.0.2
Requests==2.32.4
qwix@git+https://github.com/google/qwix.git
safetensors==0.5.3
sentencepiece==0.2.0
setuptools
tabulate==0.9.0
tensorboard_plugin_profile==2.13.0
tensorboardX==2.6.4
tensorflow==2.19.0
tensorflow_datasets==4.9.9
tensorflow_text==2.19.0
tensorstore==0.1.75
tiktoken==0.9.0
#torch==2.7.1
tqdm==4.67.1
#transformer_engine==2.4.0
transformers==4.52.4
trl==0.19.0
urllib3==2.5.0
2 changes: 2 additions & 0 deletions src/MaxText/checkpointing.py
@@ -103,6 +103,7 @@ def create_orbax_checkpoint_manager(
    orbax_logger: Any = None,  # pytype: disable=attribute-error
    use_ocdbt: bool = True,
    use_zarr3: bool = True,
    max_to_keep: int = 5,
):
  """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
  if not enable_checkpointing:
@@ -130,6 +131,7 @@ def create_orbax_checkpoint_manager(
          create=True,
          save_interval_steps=save_interval_steps,
          enable_async_checkpointing=use_async,
          max_to_keep=max_to_keep,
      ),
      logger=orbax_logger,
  )
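A minimal, self-contained sketch (not MaxText's actual call site) of what the new max_to_keep argument controls: it is forwarded to Orbax's CheckpointManagerOptions, which garbage-collects all but the most recent N checkpoints after each save. The directory, save interval, and count below are illustrative values.

import orbax.checkpoint as ocp

# Keep only the 5 most recent checkpoints; older ones are deleted after each save.
options = ocp.CheckpointManagerOptions(
    create=True,
    save_interval_steps=100,
    enable_async_checkpointing=True,
    max_to_keep=5,
)
manager = ocp.CheckpointManager("/tmp/my_run/checkpoints", options=options)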
12 changes: 8 additions & 4 deletions src/MaxText/configs/base.yml
@@ -233,7 +233,7 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward
# and you must set the number of microbatches to at least 2 * num_stages (the minimum 2 * num_stages is set by default with this delay).

model_fsdp_ag_once: False # This controls whether the Zero-1 optimization is active.
# This is a memory/time tradeoff - True: This is Zero-1 Sharding. Use ZeroOneTransformer to gather weights once per gradient step.
# False: This is Zero-3 Sharding. Use the standard Transformer, which gathers for each microbatch's fwd/bwd pass.
pipeline_fsdp_ag_once: False # If set to true then all gather all of the weights over FSDP before the first pipeline iteration.
# This is a memory/time tradeoff - we now have to store the FSDP gathered weights and gradients (typically in bf16), as opposed
@@ -283,7 +283,7 @@ param_scan_axis: 1
# The attention_type parameter determines the variants of attention, e.g. global or local_sliding
attention: 'autoselected' # Supported attention: autoselected, dot_product, flash, cudnn_flash_te
attention_type: 'global' # Supported attention_type: global, local_sliding, chunk, mla
attention_bias: False # If True, adds a learnable bias to the query, key, and value projections
attention_sink: False
sliding_window_size: 0
chunk_attn_window_size: 0
@@ -400,7 +400,7 @@ logical_axis_rules: [
['embed_no_exp', ['fsdp', 'sequence', 'tensor_transpose', 'context']],
['embed_no_exp', ['fsdp', 'fsdp_transpose', 'sequence', 'context']],
['embed_no_exp', ['fsdp', 'sequence', 'context']],
['embed_tensor_transpose', ['tensor_transpose']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'sequence', 'context', 'tensor_transpose', 'expert']],
['q_lora', ['fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert']],
@@ -500,6 +500,7 @@ train_image_column: 'image'
eval_data_columns: ['text'] # for DPO dataset containing "chosen" and "rejected"
eval_image_column: 'image'
packing: True
max_segments: 32 # maximum number of packed examples (segments) per sequence when packing is enabled
num_epoch: 1 # only grain and tfds pipeline supports num_epoch > 1
generate_padding_batch_train: False
generate_padding_batch_eval: False
@@ -853,9 +854,12 @@ vision_output_dim_for_vit: 4096
pixel_shuffle_ratio_for_vit: 0.5
projector_dropout_for_vit: 0.0

# Subslice shape in the form of "x,y,z" when using pathways (single controller).
# Example: "8,8" to use an 8x8 subgrid (64 chips) of a full pod (16x16) of Trillium.
subslice_shape: ""

# NNX
enable_nnx: false

expert_balance: False
max_to_keep: 5 # maximum number of recent checkpoints retained by the Orbax checkpoint manager
225 changes: 225 additions & 0 deletions src/MaxText/input_pipeline/CustomPackAndBatchOperation.py
@@ -0,0 +1,225 @@
"""
Forked from https://github.com/google/grain/blob/7841100258c90c77fcebdd668232aea9c0314fc2/grain/_src/python/experimental/example_packing/packing.py

Customized packing based on MaxText's default
Modified to support max segments per sequence

"""

import dataclasses
from typing import Any, Generic, Iterator, TypeVar, Union, cast

from absl import logging
from grain._src.core import tree_lib
from grain._src.python import record
import numpy as np

_T = TypeVar("_T")


class _PackedBatch:
"""Class to represent a batch of packed examples."""

def __init__(
self,
element_for_shapes: Any, # PyTree[np.ndarray]
batch_size: int,
length_struct: Any, # PyTree[int]
max_segments: int,
):
self._batch_size = batch_size
self._length_struct = length_struct
self._max_segments = max_segments

# Define the main buffers we will pack the data into.
def make_packed_buffer(length: int, input_arr: np.ndarray):
return np.zeros(
shape=(batch_size, length, *input_arr.shape[1:]), # (B, T, ...)
dtype=input_arr.dtype,
)

self._batch = tree_lib.map_structure(
make_packed_buffer, length_struct, element_for_shapes
)

def make_packed_aux_info(length: int):
return np.zeros(shape=(batch_size, length), dtype=np.int32)

self._segmentations = tree_lib.map_structure(
make_packed_aux_info, length_struct
)
self._positions = tree_lib.map_structure(
make_packed_aux_info, length_struct
)

# Tracks the next empty position to insert an example for each row
# in the batch, for each feature in features_to_pack.
self._first_free_cell_per_row = tree_lib.map_structure(
lambda _: np.zeros(batch_size, dtype=np.int32), length_struct
)

# Tracks the number of examples already packed into row of the batch. Used
# to fill the segmentation values for each feature.
self._num_examples_per_row = [0 for _ in range(batch_size)]

# For determinism, the metadata.index for the packed batch must match
# metadata.index of the _last_ included input example.
self._last_record_metadata = None

def get_packed_batch(self) -> record.Record[tuple[_T, _T, _T]]:
assert self._last_record_metadata is not None
return record.Record(
metadata=cast(record.RecordMetadata, self._last_record_metadata),
data=(self._batch, self._segmentations, self._positions),
)

def _can_add_at_row(
self,
element: Any, # PyTree[np.ndarray]
) -> int:
"""Returns the index of the first row which fits element, or -1 if none."""
element_feature_lengths = tree_lib.map_structure(len, element)

# Check no feature exceeds max length
length_exceeded = tree_lib.map_structure(
lambda feature_length, max_length: feature_length > max_length,
element_feature_lengths,
self._length_struct,
)
if any(tree_lib.flatten(length_exceeded)):
raise ValueError(
"Inputs to PackAndBatchOperation must be truncated to max length."
)

# For each row, check whether the total length after adding the current
# element would exceed max feature lengths.
def _feature_will_fit(feature_length, first_free_cell, max_length):
return feature_length + first_free_cell <= max_length

is_row_free_struct = tree_lib.map_structure(
_feature_will_fit,
element_feature_lengths,
self._first_free_cell_per_row,
self._length_struct
)

## Pick first row (if exists) where element can be added.
for i in range(self._batch_size):
if self._num_examples_per_row[i] < self._max_segments:
row_is_free_per_feature = [
free[i] for free in tree_lib.flatten(is_row_free_struct)
]
if all(row_is_free_per_feature):
return i
return -1

def add_element_to_batch(
self,
element: Any, # PyTree[np.ndarray]
row: int,
) -> None:
"""Adds element to current batch at the specified row."""
# Apply updates to each feature.
for per_feature_data in zip(
tree_lib.flatten(element),
tree_lib.flatten(self._batch),
tree_lib.flatten(self._segmentations),
tree_lib.flatten(self._positions),
tree_lib.flatten(self._first_free_cell_per_row),
):
value, batch_value, segmentations, positions, first_free_cell_per_row = (
per_feature_data
)
# Update batch value, segmentations, and positions.
start = first_free_cell_per_row[row]
end = first_free_cell_per_row[row] + len(value)
batch_value[row][start:end] = value
segmentations[row][start:end] = self._num_examples_per_row[row] + 1
positions[row][start:end] = np.arange(end - start)
# Update first_free_cell_per_row.
first_free_cell_per_row[row] += len(value)

self._num_examples_per_row[row] += 1

def try_add_to_batch(self, element: record.Record) -> bool:
"""Finds a row in the batch at which element can be added."""
if (row_idx := self._can_add_at_row(element.data)) == -1:
return False
self.add_element_to_batch(element.data, row_idx)
self._last_record_metadata = element.metadata.remove_record_key()
return True


@dataclasses.dataclass
class CustomPackAndBatchOperation(Generic[_T]):
"""PyGrain pack-and-batch operation - see module docstring.

WARNING: This class is deprecated. Please use
lazy_dataset.FirstFitPackIterDataset instead.

Attributes:
batch_size: int, the batch size.
length_struct: A pytree, with the same structure as `input_iterator`
elements, but where leaves are ints, representing the packed length of the
corresponding feature.
max_segments: int, max segments per sequence

__call__() takes an input iterator, where elements are `Record`s containing:

input_data: Pytrees of arrays. For more info about PyTrees, please refer to:
https://jax.readthedocs.io/en/latest/pytrees.html. Packed leaves should be
n-dimensional arrays, with sequence length as the leading dimension, i.e.
shape (T_in, ...), where T_in < T_packed. Note that leaves can and will
often have ragged length dimensions across different elements of the input
iterator.

The output of __call__() will be an iterator over `Record`s containing a
3-tuple of Pytrees. These are:

data: The batched and packed data. This is a Pytree with parallel structure
to elements of `input_iterator`. Leaves have shape (B, T_packed, ...).
segmentations: Pytree with the same structure as `data`, and leaves of shape
(B, T). Represents which example each entry comes from. This may be used
for Transformer attention masks, for example.
positions: Pytree with the same structure as `data`, and leaves of shape
(B, T). Represents the position of each entry within their original
example. This may be used e.g. in Transformer absolute position
embeddings.
"""

length_struct: Any # PyTree[int]
batch_size: int
max_segments: int
# We don't know input shapes and corresponding buffer shapes until __call__.
_cur_batch: Union[_PackedBatch, None] = None

def __post_init__(self):
logging.error(
"PackAndBatchOperation is deprecated. Please use"
" lazy_dataset.FirstFitPackIterDataset instead."
)

def __call__(
self, input_iterator: Iterator[record.Record[_T]]
) -> Iterator[record.Record[tuple[_T, _T, _T]]]:
for element in input_iterator:
# Use `element` to set dtypes + trailing dimensions.
if self._cur_batch is None: # pytype: disable=attribute-error
self._cur_batch = _PackedBatch(
element.data, self.batch_size, self.length_struct, self.max_segments
)

# Try adding element to the current packed batch.
element_added_to_batch = self._cur_batch.try_add_to_batch(element)

# When we have a full batch, yield the current packed data,
# and then start a new batch with this element.
if not element_added_to_batch:
yield self._cur_batch.get_packed_batch() # Main yield
self._cur_batch = _PackedBatch(
element.data, self.batch_size, self.length_struct, self.max_segments
)
self._cur_batch.try_add_to_batch(element)

# Final batch
yield self._cur_batch.get_packed_batch()
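To make the per-row segment cap concrete, here is a standalone NumPy sketch (not the grain operation itself, and deliberately simplified to a single flat feature with no overflow batch) that mirrors the first-fit logic above. The function name pack_with_segment_cap and its arguments are illustrative, not part of MaxText or grain.

import numpy as np


def pack_with_segment_cap(sequences, batch_size, max_length, max_segments):
  """First-fit packing of 1-D token sequences with a per-row segment cap.

  Returns (data, segmentations, positions), each shaped (batch_size, max_length),
  mirroring the 3-tuple that CustomPackAndBatchOperation emits for one feature.
  Sequences that fit nowhere are simply dropped in this sketch (the real
  operation instead yields the full batch and starts a new one).
  """
  data = np.zeros((batch_size, max_length), dtype=np.int32)
  segmentations = np.zeros((batch_size, max_length), dtype=np.int32)
  positions = np.zeros((batch_size, max_length), dtype=np.int32)
  first_free = np.zeros(batch_size, dtype=np.int32)  # next empty cell per row
  num_segments = np.zeros(batch_size, dtype=np.int32)  # examples packed per row

  for seq in sequences:
    seq = np.asarray(seq, dtype=np.int32)
    for row in range(batch_size):
      # A row is eligible only if it has a free segment slot AND enough space.
      if num_segments[row] < max_segments and first_free[row] + len(seq) <= max_length:
        start, end = first_free[row], first_free[row] + len(seq)
        data[row, start:end] = seq
        segmentations[row, start:end] = num_segments[row] + 1  # 1-based segment id
        positions[row, start:end] = np.arange(len(seq))
        first_free[row] = end
        num_segments[row] += 1
        break
  return data, segmentations, positions


# With max_segments=2, the third sequence cannot join row 0 even though it
# would fit by length alone, so it opens a new segment in row 1.
data, seg, pos = pack_with_segment_cap(
    sequences=[[7, 8, 9], [4, 5], [6]], batch_size=2, max_length=8, max_segments=2
)
print(seg)  # [[1 1 1 2 2 0 0 0]
            #  [1 0 0 0 0 0 0 0]]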
14 changes: 10 additions & 4 deletions src/MaxText/input_pipeline/_hf_data_processing.py
@@ -26,6 +26,7 @@

import numpy as np

from .CustomPackAndBatchOperation import CustomPackAndBatchOperation
from MaxText.input_pipeline import _input_pipeline_utils
from MaxText import multihost_dataloading

@@ -176,6 +177,7 @@ def preprocessing_pipeline(
    use_sft=None,
    sft_train_on_completion_only=True,
    grain_worker_count=1,  # only support 0 or 1
    max_segments=1,  # max segments per sequence
):
  """pipeline for preprocessing HF dataset"""

@@ -274,12 +276,14 @@ def lists2array(x):
  data_column_names = ("inputs", "targets")

  if packing and not use_dpo:
    # Use the custom packing operation to respect TE's maximum-segments-per-sequence limitation
    length_struct = {col: max_target_length for col in data_column_names}
    operations.append(
        grain.experimental.PackAndBatchOperation(
            batch_size=global_batch_size // jax.process_count(),
            length_struct=length_struct,
        )
        CustomPackAndBatchOperation(
            batch_size=global_batch_size // jax.process_count(),
            length_struct=length_struct,
            max_segments=max_segments,
        )
    )
    operations.append(_input_pipeline_utils.ReformatPacking(data_column_names))
  else:
@@ -363,6 +367,7 @@ def make_hf_train_iterator(
      use_dpo=config.use_dpo,
      use_sft=config.use_sft,
      sft_train_on_completion_only=config.sft_train_on_completion_only,
      max_segments=config.max_segments,
  )
  return train_iter

@@ -413,5 +418,6 @@ def make_hf_eval_iterator(
      use_dpo=config.use_dpo,
      use_sft=config.use_sft,
      sft_train_on_completion_only=config.sft_train_on_completion_only,
      max_segments=config.max_segments,
  )
  return eval_iter
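Distilled from the hunk above into a self-contained form: the custom packer is simply appended to the list of grain operations in place of grain.experimental.PackAndBatchOperation, with the extra max_segments knob. The concrete batch size, target length, and max_segments values below are illustrative (in MaxText they come from the config object), and the rest of the pipeline (tokenization, ReformatPacking, batching of the unpacked branch) is omitted.

import jax

from MaxText.input_pipeline.CustomPackAndBatchOperation import CustomPackAndBatchOperation

# Illustrative values; in MaxText these come from config.
global_batch_size = 32
max_target_length = 2048
data_column_names = ("inputs", "targets")

length_struct = {col: max_target_length for col in data_column_names}
operations = []
operations.append(
    CustomPackAndBatchOperation(
        batch_size=global_batch_size // jax.process_count(),
        length_struct=length_struct,
        max_segments=32,  # e.g. the new base.yml default
    )
)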