A simple, unified multimodal model training engine. Lean, flexible, and built for hacking at scale.
Quick Start • Examples • Model Support • Optimizations • Codebase Architecture • Documentation
- [2025-10] Efficiency Report: We provide comprehensive Model FLOPs Utilization (MFU) metrics for various model architectures and training configurations. See MFU Reference for detailed benchmarks.
- [2025-10] LMMs-Engine v0.1 is here! A lean, efficient framework built to train unified multimodal models at scale.
```bash
# Clone the repository
git clone https://github.com/LMMs-Lab/lmms-engine.git
cd lmms-engine

# Install dependencies
uv sync

# Optional: Performance optimizations
uv pip install flash-attn --no-build-isolation
uv pip install liger-kernel
```

Recommended: torchrun (native PyTorch)
```bash
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
  --master_addr=127.0.0.1 --master_port=12355 \
  -m lmms_engine.launch.cli config_yaml=examples/qwen3_vl/example_config.yaml
```

Alternative: Accelerate
```bash
accelerate launch --use_fsdp \
  -m lmms_engine.launch.cli config_yaml=examples/qwen3_vl/example_config.yaml
```

Single GPU
```bash
python -m lmms_engine.launch.cli config_yaml=examples/qwen3_vl/example_config.yaml
```

| Model | Quick Start | FSDP2 | USP | Muon | Liger | Packing | NSA | Highlights |
|---|---|---|---|---|---|---|---|---|
| BAGEL | run.sh | ✅ | TBD | ✅ | ✅ | ✅ | ✅ | Unified visual understanding & generation |
| Qwen2.5 | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Large Language Model |
| Qwen2.5-VL | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Multimodal Model |
| Qwen2.5-Omni | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Unified multimodal (image, audio, text) |
| Qwen3-VL | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Native-resolution, long context (10K+ tokens) |
| WanVideo | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | T2V/I2V/V2V generation (1.3B/14B) |
| FLA models | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Efficient architecture, FineWeb-Edu pretraining |
| dLLM (Qwen3) | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Masked diffusion language model |
| RAE-SigLip | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Representation AutoEncoder, LPIPS, EMA |
| SiT | run.sh | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | Interpolant Transformer, CFG, ImageNet-1K |
Optimization Legend:
- FSDP2: Fully Sharded Data Parallel v2 for distributed training
- USP: Ulysses Sequence Parallel for long contexts
- Muon: Advanced optimizer with Newton-Schulz orthogonalization
- Liger: Triton fused kernels (CrossEntropy, RMSNorm, RoPE, SwiGLU) for 30% memory reduction
- Packing: First-fit bin packing, reaching 35-40% MFU vs. 20-25% without packing (measured on Qwen2.5-VL finetuning)
- NSA: Native Sparse Attention for efficient long-context processing
💡 Tip: Each `run.sh` file contains detailed setup instructions, prerequisites, and configuration options.
19+ architectures spanning vision-language, diffusion, and language models.
- Qwen2.5-VL - SOTA-level vision-language model
- Qwen3-VL - SOTA-level vision-language model
- Qwen2.5-Omni - Unified vision + audio + text modalities
- LLaVA-OneVision - Fully open-source vision-language model
- Bagel - Unified multimodal model for visual understanding and generation
- Aero - Lightweight audio-language model
- dLLM (Qwen3) - Diffusion Language Model with masked prediction
- WanVideo (1.3B/14B) - Text/Image-to-Video generation (T2V/I2V/V2V)
- SiT (XL/2) - Scalable Interpolant Transformers for class-conditional image generation
- RAE-SigLip - Representation AutoEncoder with adversarial discriminator
- Qwen2/2.5/3 series - Full Liger kernel support with fused operations
- Linear Attention Models - Recurrent architectures optimized for Muon; install FLA first.
- Custom architectures - Extensible via the `@register_model()` decorator (see the sketch below)
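For example, a minimal sketch of registering a custom architecture. The import path of `register_model` and the expected base class are assumptions (modeled on the dataset/processor registries shown later in this README), so check the actual model package before copying:

```python
# Hypothetical sketch: register a custom architecture with the factory registry.
# `register_model`'s import path and the base class are assumptions; see
# lmms_engine's model package for the real interfaces.
import torch.nn as nn

from lmms_engine.models import register_model  # assumed import path


@register_model("my_custom_vlm")
class MyCustomVLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Build vision tower, projector, and language model here.

    def forward(self, input_ids=None, pixel_values=None, **kwargs):
        # Return logits / loss in the format the chosen trainer expects.
        raise NotImplementedError
```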
Production-grade efficiency from distributed training to kernel fusion.
- FSDP2 - PyTorch 2.0+ DTensor-based sharding for parameters, gradients, and optimizer states. Improved composability over the original FSDP enables flexible parallelism composition.
- Ulysses Sequence Parallel - Splits the sequence dimension across GPUs for ultra-long contexts. Critical for vision-language models like Qwen3-VL with 10K+ visual tokens.
- Multi-dimensional Parallelism - Compose TP × PP × DP meshes for cluster-scale training.
- Flash Attention + Unpadding - Tiled attention with `use_rmpad` eliminates all padding computation. 2-3× speedup on variable-length sequences.
- Native Sparse Attention (NSA) - Hybrid attention mechanism combining compressed attention, top-k sparse attention, and sliding-window attention. Enables efficient long-context processing for the BAGEL model with a reduced memory footprint.
- Liger Kernel - Triton fused kernels (CrossEntropy, RMSNorm, RoPE, SwiGLU) achieve 30% memory reduction by avoiding intermediate materializations.
- Monkey Patching System - Runtime kernel injection via `lmms_engine/configs/monkey_patch/` for model-specific optimizations without code modification.
- Sequence Packing - First-fit bin packing achieves 35-40% MFU vs. 20-25% without packing. Combined with unpadding for zero padding waste.
- Muon Optimizer - Newton-Schulz orthogonalization with Triton kernels, distributed via DTensor. Selective application to 2D parameters yields better convergence than AdamW.
- Streaming Datasets - `IterableDataset` for trillion-token pretraining without loading the full dataset into memory (see the sketch below).
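A minimal sketch of the streaming pattern this refers to, using a plain PyTorch `IterableDataset` over JSONL shards. File layout and record fields here are illustrative, not the engine's actual dataset classes:

```python
# Illustrative only: stream JSONL shards without loading everything into memory.
# Paths and record fields are hypothetical; lmms-engine's own streaming dataset
# classes may expose a different interface.
import json
from pathlib import Path

from torch.utils.data import IterableDataset, get_worker_info


class JsonlStreamDataset(IterableDataset):
    def __init__(self, shard_dir: str):
        self.shards = sorted(Path(shard_dir).glob("*.jsonl"))

    def __iter__(self):
        worker = get_worker_info()
        # Give each dataloader worker a disjoint subset of shards.
        shards = self.shards if worker is None else self.shards[worker.id :: worker.num_workers]
        for shard in shards:
            with open(shard) as f:
                for line in f:
                    yield json.loads(line)
```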
Sequence Packing - with full unpadding

```yaml
dataset_config:
  packing: true
  packing_strategy: first_fit
  packing_length: 32000
trainer_args:
  use_rmpad: true # Requires flash-attn
  use_liger_kernel: true
```
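The `first_fit` strategy above is plain first-fit bin packing over sample lengths; a standalone sketch of the idea (not the engine's implementation, which operates on tokenized samples):

```python
# Toy illustration of first-fit packing: place each sample into the first bin
# that still has room, opening a new bin when none fits. Assumes every sample
# is no longer than packing_length.
def first_fit_pack(sample_lengths, packing_length):
    bins = []  # each bin: [remaining_capacity, [sample indices]]
    for idx, length in enumerate(sample_lengths):
        for b in bins:
            if b[0] >= length:
                b[0] -= length
                b[1].append(idx)
                break
        else:
            bins.append([packing_length - length, [idx]])
    return [b[1] for b in bins]


print(first_fit_pack([9000, 20000, 5000, 31000, 12000], packing_length=32000))
# [[0, 1], [2, 4], [3]]
```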
Liger Kernel - Enable LinkedIn's Triton kernels for 30% memory reduction

```yaml
trainer_args:
  use_liger_kernel: true
```

Fused operations:
- CrossEntropy (major memory savings)
- RMSNorm, RoPE, SwiGLU
- Automatically applied via monkey patching
Muon Optimizer - State-of-the-art optimizer for LLMs
```yaml
trainer_args:
  use_muon: true       # enable the MuonWithAdam optimizer
  adam_beta1: 0.9      # for the Adam part of MuonWithAdam
  adam_beta2: 0.999    # for the Adam part of MuonWithAdam
  adam_epsilon: 1.0e-8 # for the Adam part of MuonWithAdam
  learning_rate: 0.001
  weight_decay: 0.01
  # ns_steps: 5        # Newton-Schulz iterations (default)
  # See the Note below for excluding specific modules from Muon.
```

Features:
- Newton-Schulz orthogonalization with Triton kernels
- Distributed via DTensor (FSDP2)
- Selective 2D parameter application
Note
To control whether a module is optimized with Muon or Adam, adjust `lmms_engine.train.hf.trainer.create_optimizer`. By default, Muon excludes modules whose names contain any of the substrings `["emb", "norm", "lm_head", "bias", "wte", "wpe", "output", "a_proj", "b_proj", "conv1d", "rotary"]`, as well as any parameter whose dimension is not 2.
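A minimal sketch of the selection logic described in the note (the actual rule lives in `lmms_engine.train.hf.trainer.create_optimizer`; treat this as an illustration, not that function's code):

```python
# Illustration of the default Muon/Adam split described above: 2-D parameters
# go to Muon unless their name contains an excluded substring.
MUON_EXCLUDE_SUBSTRINGS = [
    "emb", "norm", "lm_head", "bias", "wte", "wpe",
    "output", "a_proj", "b_proj", "conv1d", "rotary",
]


def split_params_for_muon(model):
    muon_params, adam_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        excluded = any(s in name for s in MUON_EXCLUDE_SUBSTRINGS)
        if param.ndim == 2 and not excluded:
            muon_params.append(param)
        else:
            adam_params.append(param)
    return muon_params, adam_params
```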
FSDP2 Configuration
```yaml
trainer_args:
  fsdp2: true
  fsdp_config:
    transformer_layer_cls_to_wrap: ["Qwen2VLDecoderLayer"]
    reshard_after_forward: false
    activation_checkpointing: true
```

Ulysses Sequence Parallel - For long-sequence VLMs
```yaml
trainer_args:
  sp_ulysses_degree: 2 # Sequence parallel degree
```

Benefits:
- Splits sequence length across GPUs
- Reduces memory footprint for long contexts
- Works with Flash Attention
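Conceptually, Ulysses shards the sequence dimension so each rank holds `seq_len / sp_ulysses_degree` tokens between attention calls. A toy illustration of just the sharding step (not the engine's all-to-all attention path):

```python
# Toy illustration: each sequence-parallel rank keeps one contiguous slice of
# the sequence dimension. The real implementation also performs all-to-all
# exchanges inside attention so every head still sees the full sequence.
import torch


def shard_sequence(hidden_states: torch.Tensor, sp_degree: int, rank: int) -> torch.Tensor:
    assert hidden_states.size(1) % sp_degree == 0, "sequence length must be divisible by sp_degree"
    return hidden_states.chunk(sp_degree, dim=1)[rank]


x = torch.randn(1, 32_000, 4096)                     # (batch, seq_len, hidden)
print(shard_sequence(x, sp_degree=2, rank=0).shape)  # torch.Size([1, 16000, 4096])
```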
Native Sparse Attention (NSA) - Efficient long-context attention for BAGEL
```yaml
model_config:
  load_from_pretrained_path: "lmms-lab/BAGEL-7B-MoT-ver.LE"
  monkey_patch:
    - type: nsa
      model_type: bagel
      kwargs:
        block_size: 64
        compress_type: "weightedpool" # weightedpool, linear, avgpool
        kernel_size: 32
        kernel_stride: 16
        topk: 16
        init_blocks: 1
        local_blocks: 2
        window_size: 512
```

Features:
- Compressed attention with key-value compression
- TopK sparse attention for efficiency
- Sliding window attention for local context
- Hybrid mechanism combines all three attention types
- Requires: `pip install git+https://github.com/XunhaoLai/native-sparse-attention-triton.git`
Note: Currently only supported for the BAGEL model.
- Process the dataset into OpenAI chat format (JSONL/JSON/Arrow/CSV); a sample record is shown after this list.

  ```bash
  hf download kcz358/open-thoughts-debug --local-dir data/open_thoughts_debug --repo-type dataset
  ```

- Prepare a dataset YAML (optional for a single data source)

  ```yaml
  datasets:
    - path: data/open_thoughts_debug
      data_folder: ""
      data_type: arrow
  ```

- Configure training - See examples/qwen3_vl/example_config.yaml or any model-specific config in examples/
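For reference, a hypothetical record in OpenAI chat format; the exact field names accepted by the dataset loaders may differ, so treat this as a shape illustration only:

```json
{
  "messages": [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Describe the image in one sentence."},
    {"role": "assistant", "content": "A brown dog is running across a grassy field."}
  ]
}
```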
Getting Started:
- Dataset Preparation - How to prepare and structure your data
- Dataset & Packing Guide - Detailed dataset implementations and packing strategies
- Training Guide - Comprehensive training walkthrough
Advanced Topics:
- Design Principles - Architectural patterns and philosophy
- API Reference - Detailed API documentation
Factory Pattern enables easy extensibility:
```python
# Register a custom dataset
from lmms_engine.datasets import register_dataset, BaseDataset


@register_dataset("my_custom_dataset")
class MyCustomDataset(BaseDataset):
    def __init__(self, config):
        super().__init__(config)
        # Custom initialization

    def __getitem__(self, idx):
        # Custom data loading
        return item


# Register a custom processor
from lmms_engine.datasets.processor import register_processor


@register_processor("my_custom_processor")
class MyCustomProcessor:
    def __call__(self, raw_data):
        # Custom processing
        return processed_data
```

Builder Pattern for flexible composition:
```python
from lmms_engine.train import TrainRunner

# Configuration defines the pipeline
runner = TrainRunner(config)
runner.build()  # Lazy initialization of components
runner.run()    # Execute training
```

Pipeline stages:
- Model initialization - From pretrained or config
- Dataset creation - With processor and collator
- Monkey patching - Apply kernel optimizations
- Trainer setup - FSDP2, DeepSpeed, or custom
- Training execution - With checkpointing and logging
| Trainer Type | Use Case | Key Features |
|---|---|---|
| `hf_trainer` | General VLM/LM training | FSDP2, Muon, Liger, Flash Attn |
| `dllm_trainer` | Diffusion language models | Masked LM, custom loss, DLLM collator |
| `wan_trainer` | Video generation | Flow-matching, multi-modal inputs |
| `rae_trainer` | Visual autoencoders | Adversarial loss, EMA, LPIPS |
| `sit_trainer` | Diffusion transformers | Interpolant framework, CFG, EMA |
- Vision-Language Pretraining - Qwen-VL, LLaVA on large multimodal datasets
- Video Understanding - AERO on 3D video data
- Diffusion Models - DLLM, SiT, WanVideo for generation tasks
- Representation Learning - RAE for visual representations
- Language Model Pretraining - DGN, Qwen with Muon optimizer
- Multimodal Fine-tuning - Efficient SFT with sequence packing
We welcome contributions! Please see our Design Principles for coding guidelines:
- Simplicity: Write simple, straightforward code
- Readability: Prioritize clarity over cleverness
- Testability: Create testable components
- Minimal Changes: Only modify code related to the task
- Less Code = Less Debt: Minimize code footprint
Thanks to the following projects for their excellent work:
If you use LMMs Engine in your research, please cite:
```bibtex
@software{lmms_engine2024,
  title={LMMs Engine: A simple, unified multimodal framework for pretraining and finetuning.},
  author={LMMs-Lab},
  year={2024},
  url={https://github.com/LMMs-Lab/lmms-engine}
}
```

This project is licensed under the Apache 2.0 License - see the LICENSE file for details.
- GitHub: https://github.com/EvolvingLMMs-Lab/lmms-engine
- LMMs-Lab: https://lmms-lab.com
- Documentation: docs/
- Issues: https://github.com/EvolvingLMMs-Lab/lmms-engine/issues
Built with ❤️ by LMMs-Lab

⭐ Star us on GitHub to support the project! ⭐