
Better guards for DeepSpeed imports #3351


Merged · 7 commits · Apr 26, 2025
18 changes: 12 additions & 6 deletions trl/models/utils.py
@@ -18,7 +18,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Union

from accelerate.utils import is_deepspeed_available
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizer

@@ -30,12 +29,10 @@
AutoModelForSeq2SeqLMWithValueHead,
)

if is_deepspeed_available():
import deepspeed

if TYPE_CHECKING:
from accelerate import Accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel


@@ -167,6 +164,8 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
import deepspeed

Member Author commented:
I realise doing this import every time we generate is sub-optimal, but DeepSpeed generation is slow anyway and I don't think it's used much compared to vLLM.

Member replied:
Actually it's not sub-optimal. Python's import is effectively conditional—when you import, Python checks if the module is already loaded (via sys.modules) and reuses it if so. So it won't reload the module or cause significant overhead.
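
As a quick illustration of the point above (a hedged sketch, illustrative only; `json` is just a stand-in module): repeated import statements resolve through sys.modules, so only the first one pays the loading cost.

import sys
import timeit

# The first import may actually load the module; every later `import` statement is a
# dict lookup in sys.modules that reuses the already-loaded module object.
first = timeit.timeit("import json", number=1)
cached = timeit.timeit("import json", number=100_000)

print("json" in sys.modules)  # True once imported anywhere in the process
print(f"first: {first:.6f}s, 100k cached imports: {cached:.6f}s")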


if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
@@ -214,6 +213,8 @@ def unwrap_model_for_generation(
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed

with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield accelerator.unwrap_model(model)
@@ -222,8 +223,13 @@
yield unwrapped_model


def prepare_deepspeed(model, accelerator):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
"""Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.

Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
"""
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252

deepspeed_plugin = accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
stage = config_kwargs["zero_optimization"]["stage"]
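The pattern this PR adopts, as a minimal hedged sketch (the function name and structure here are illustrative, not TRL's API): replace the module-level `if is_deepspeed_available(): import deepspeed` guard with a function-local import on the code paths that actually need DeepSpeed, so merely importing TRL never pulls in or initializes DeepSpeed.

from contextlib import contextmanager

@contextmanager
def zero3_gathered(model):
    # Local import: only runs when a DeepSpeed code path is taken, and after the first
    # call it is resolved from sys.modules, so the per-call overhead is negligible.
    import deepspeed

    with deepspeed.zero.GatheredParameters(model.parameters()):
        yield model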
41 changes: 3 additions & 38 deletions trl/trainer/bco_trainer.py
@@ -19,7 +19,6 @@
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

@@ -32,7 +31,7 @@
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
@@ -56,7 +55,7 @@

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_joblib_available
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from .bco_config import BCOConfig
from .utils import (
DPODataCollatorWithPadding,
@@ -83,9 +82,6 @@
if is_joblib_available():
import joblib

if is_deepspeed_available():
import deepspeed

if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer

@@ -712,7 +708,7 @@ def make_inputs_require_grad(module, input, output):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

@@ -846,37 +842,6 @@ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512

return all_embeddings

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)

# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

def _save_optimizer_and_scheduler(self, output_dir):
output_dir = output_dir if output_dir is not None else self.args.output_dir
super()._save_optimizer_and_scheduler(output_dir)
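For reference, a worked example of the ZeRO-3 overrides removed above (the same values are now computed in the shared `prepare_deepspeed`), assuming a hypothetical config with hidden_size = 4096:

hidden_size = 4096
reduce_bucket_size = hidden_size * hidden_size                 # 16_777_216
stage3_param_persistence_threshold = 10 * hidden_size          # 40_960
stage3_prefetch_bucket_size = 0.9 * hidden_size * hidden_size  # 15_099_494.4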
7 changes: 3 additions & 4 deletions trl/trainer/callbacks.py
@@ -19,7 +19,7 @@
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object, is_comet_ml_available, is_deepspeed_available, is_wandb_available
from accelerate.utils import gather_object, is_comet_ml_available, is_wandb_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
@@ -44,9 +44,6 @@
from .utils import log_table_to_comet_experiment


if is_deepspeed_available():
import deepspeed

if is_comet_ml_available():
pass

@@ -115,6 +112,8 @@ def _sync_target_model(model, target_model, alpha):
def sync_target_model(model, target_model, alpha):
deepspeed_plugin = AcceleratorState().deepspeed_plugin
if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
import deepspeed

with deepspeed.zero.GatheredParameters(
list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
):
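For context, `_sync_target_model` (called inside the gathered-parameters block above) performs an alpha-weighted update of the target model; a minimal sketch assuming the usual Polyak formulation, which may differ from TRL's exact implementation:

import torch

def _sync_target_model(model, target_model, alpha):
    # target <- alpha * online + (1 - alpha) * target, parameter by parameter
    with torch.no_grad():
        for target_param, param in zip(target_model.parameters(), model.parameters()):
            target_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)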
41 changes: 3 additions & 38 deletions trl/trainer/dpo_trainer.py
@@ -19,7 +19,6 @@
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union

@@ -30,7 +29,7 @@
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import DataLoader
@@ -53,7 +52,7 @@
from transformers.utils import is_peft_available, is_torch_xpu_available

from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from ..models.utils import prepare_fsdp
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@@ -80,9 +79,6 @@
if is_wandb_available():
import wandb

if is_deepspeed_available():
import deepspeed


@dataclass
class DataCollatorForPreference(DataCollatorMixin):
@@ -510,7 +506,7 @@ def make_inputs_require_grad(module, input, output):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
@@ -676,37 +672,6 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le

return output

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)

# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
41 changes: 2 additions & 39 deletions trl/trainer/gkd_trainer.py
@@ -15,13 +15,11 @@
import os
import random
import textwrap
from copy import deepcopy
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
@@ -38,7 +36,7 @@
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ..models import PreTrainedModelWrapper
from ..models import prepare_deepspeed
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
Expand All @@ -51,10 +49,6 @@
)


if is_deepspeed_available():
import deepspeed


if is_peft_available():
from peft import PeftConfig

@@ -124,7 +118,7 @@ def __init__(
disable_dropout_in_model(self.model)

if self.is_deepspeed_enabled:
self.teacher_model = self._prepare_deepspeed(teacher_model)
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
else:
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)

@@ -311,37 +305,6 @@ def training_step(
loss = super().training_step(model, inputs, num_items_in_batch)
return loss

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)

if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)

# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

def create_model_card(
self,
model_name: Optional[str] = None,
12 changes: 7 additions & 5 deletions trl/trainer/grpo_trainer.py
@@ -47,7 +47,7 @@
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_deepspeed_available, is_liger_kernel_available, is_rich_available, is_vllm_available
from ..import_utils import is_liger_kernel_available, is_rich_available, is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
@@ -61,9 +61,6 @@
)


if is_deepspeed_available():
import deepspeed

if is_peft_available():
from peft import PeftConfig, get_peft_model

@@ -839,7 +836,12 @@ def _move_model_to_vllm(self):
# For DeepSpeed ZeRO-3, we need to gather all parameters before operations
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
if zero_stage_3:
import deepspeed

gather_if_zero3 = deepspeed.zero.GatheredParameters
else:
gather_if_zero3 = nullcontext

if is_peft_model(self.model):
# With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
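The `gather_if_zero3` assignment above selects a gathering context manager once and uses it uniformly afterwards; a small hedged sketch of the same pattern (the function and variable names here are illustrative):

from contextlib import nullcontext

def gathered_params(model, zero_stage_3):
    if zero_stage_3:
        import deepspeed  # only touched when ZeRO-3 is actually active

        return deepspeed.zero.GatheredParameters(list(model.parameters()))
    return nullcontext()

# Usage:
# with gathered_params(model, zero_stage_3):
#     state_dict = model.state_dict()  # full weights under ZeRO-3, plain weights otherwise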