Training Diffusion Large Language Models Made Simple
dLLM is an educational library offering unified implementations for training diffusion language models. It brings transparency to the entire training and deployment process, making reproduction and finetuning of open-weight diffusion language models much easier. Below are some of the key features that make dLLM special:
- dLLM provides reproduction and finetuning recipes for a variety of open-weight models (e.g., LLaDA, Dream, and RND1), as well as reference implementations of various training algorithms (e.g., Edit Flows).
- dLLM, built on top of 🤗 Transformers, scales seamlessly from edge devices with LoRA to multi-node clusters with DeepSpeed and beyond (a hedged LoRA sketch follows this list).
- dLLM provides unified, modular training pipelines (inspired by 🤗 Transformers Trainer) and well-documented examples, making customization simple and development highly user-friendly.
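As a rough illustration of the LoRA path, below is a minimal sketch of parameter-efficient finetuning with 🤗 PEFT. The checkpoint id and target_modules names are assumptions for illustration and may not match the recipes shipped under examples/; check your model's module names before adapting it.

# Minimal LoRA sketch (assumptions: checkpoint id and target module names)
from transformers import AutoModel
from peft import LoraConfig, get_peft_model

# Load an open-weight diffusion LLM (LLaDA uses a custom model class, hence trust_remote_code)
model = AutoModel.from_pretrained("GSAI-ML/LLaDA-8B-Base", trust_remote_code=True)

lora_config = LoraConfig(
    r=16,                                            # LoRA rank
    lora_alpha=32,                                   # LoRA scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],   # assumption: adjust to your model's attention projections
)
model = get_peft_model(model, lora_config)           # only the adapter weights remain trainable
model.print_trainable_parameters()

The LoRA-wrapped model can then be passed to the trainers shown below in place of the full model.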
Note: This repository is primarily for educational purposes and does not aim for 100% exact reproduction of official models (which is impossible). We hope it serves as a helpful reference for the community; contributions and improvements are always welcome!
- examples/llada: Finetuning the open-weight LLaDA / LLaDA-MoE models, as well as reproducing LLaDA by training from scratch on public data (pretraining & finetuning).
- examples/dream: Finetuning the open-weight Dream model, as well as reproducing Dream by training from scratch on public data (pretraining & finetuning).
- examples/rnd: (WIP) Finetuning the open-weight RND1-Base model.
- examples/editflow: Educational reference for training EditFlow models, demonstrating how to extend existing diffusion LLMs (e.g., LLaDA and Dream) with edit operations (insertion, deletion, and substitution) and how to pretrain or finetune EditFlow models from scratch on public data.
- More upcoming; see the Roadmap.
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
--index-url https://download.pytorch.org/whl/cu124
# install requirements
pip install -r requirements.txt
# install dllm package
pip install -e .

For Slurm users, update scripts/train.slurm.sh for your cluster:
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE

Next, create a directory for your job logs:
mkdir logs

This folder will store the log files generated by your sbatch jobs.
# modules for training / sampling
dllm
├── core                    # Core reusable modules shared across `dllm/pipelines`
│   ├── schedulers
│   └── trainers
├── data
├── pipelines               # Application-specific training & inference pipelines
│   ├── dream
│   ├── editflow
│   └── llada
│       ├── models          # Model architecture and configs
│       ├── generate.py     # Generation utilities
│       └── trainer.py      # Core training logic
├── tools
└── utils
# entry points for training / sampling
examples
├── dream
├── editflow
└── llada
    ├── generate.py         # Generation example
    ├── pt.py               # Pretraining example
    ├── README.md           # Example-level documentation
    └── sft.py              # SFT example
A typical training entry script (for example, examples/llada/sft.py) looks like this:
import transformers
import dllm
# `parser` is a transformers.HfArgumentParser over the script's model/data/training argument dataclasses
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model=model, model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."  # dataset loading and preprocessing elided; see examples/llada/sft.py for the full recipe
# ----- Training --------------------------------------------------------------
trainer = dllm.pipelines.llada.LLaDATrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
pad_to_multiple_of=8,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id,  # LLaDA is trained on padded <eos> tokens, so labels use pad_token_id instead of the default -100
),
)
trainer.train()

You can launch the training job locally with accelerate, or submit it to a Slurm cluster using sbatch.
# Run locally (DeepSpeed ZeRO-2 with 8 GPUs)
accelerate launch \
--config_file scripts/accelerate_configs/deepspeed_zero2.yaml \
examples/llada/sft.py \
--num_train_epochs 4

# Submit to a Slurm cluster (DeepSpeed ZeRO-2 with 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "deepspeed_zero2" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
# Submit to a Slurm cluster (DeepSpeed ZeRO-2 with 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "deepspeed_zero2" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4

See Features & Documentation for training/inference details and task-specific recipes.
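The dataset construction is elided in the sft.py snippet above. As a hedged illustration only (the dataset id, prompt formatting, and prompt masking below are assumptions, not the recipe shipped in examples/llada), a supervised finetuning dataset compatible with DataCollatorForSeq2Seq simply needs tokenized input_ids and labels per example:

# Hypothetical dataset preparation sketch; the dataset id and preprocessing are
# illustrative assumptions, not dLLM's shipped recipe.
import transformers
from datasets import load_dataset

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)
raw = load_dataset("tatsu-lab/alpaca", split="train[:1%]")  # hypothetical example dataset

def preprocess(example):
    prompt_ids = tokenizer(example["instruction"] + "\n" + example["input"])["input_ids"]
    response_ids = tokenizer(
        example["output"] + tokenizer.eos_token, add_special_tokens=False
    )["input_ids"]
    return {
        "input_ids": prompt_ids + response_ids,
        # assumption: mask prompt tokens from the loss with -100; the shipped recipe may differ
        "labels": [-100] * len(prompt_ids) + response_ids,
    }

dataset = raw.map(preprocess, remove_columns=raw.column_names).train_test_split(test_size=0.01)

The resulting DatasetDict exposes the "train" and "test" splits that the trainer above consumes.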
- Support for additional diffusion LLMs.
- Support for evaluation.
- Support for RL finetuning.
@misc{dllm,
  author       = {Zhanhui Zhou and Lingjie Chen},
  title        = {dLLM: Training Diffusion Large Language Models Made Simple},
  howpublished = {https://github.com/ZHZisZZ/dllm},
  note         = {Accessed: 2025-10-12},
  year         = {2025}
}

