Commit 1bca495

Better guards for DeepSpeed imports (#3351)
1 parent 39e9639 commit 1bca495

File tree

8 files changed: +33 -205 lines changed


trl/models/utils.py (+12 -6)

@@ -18,7 +18,6 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
-from accelerate.utils import is_deepspeed_available
 from packaging import version
 from transformers import PreTrainedModel, PreTrainedTokenizer
 
@@ -30,12 +29,10 @@
     AutoModelForSeq2SeqLMWithValueHead,
 )
 
-if is_deepspeed_available():
-    import deepspeed
-
 if TYPE_CHECKING:
     from accelerate import Accelerator
     from deepspeed.runtime.engine import DeepSpeedEngine
+    from torch.nn import Module
     from torch.nn.parallel.distributed import DistributedDataParallel
 
 
@@ -167,6 +164,8 @@ def iter_params(module, recurse=False):
 
 def add_hooks(model: "DeepSpeedEngine") -> None:
     """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
+    import deepspeed
+
     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
         return
     if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
@@ -214,6 +213,8 @@ def unwrap_model_for_generation(
         if not gather_deepspeed3_params:
             yield accelerator.unwrap_model(model)
         else:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(model.parameters()):
                 remove_hooks(model)
                 yield accelerator.unwrap_model(model)
@@ -222,8 +223,13 @@
         yield unwrapped_model
 
 
-def prepare_deepspeed(model, accelerator):
-    # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
+    """Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.
+
+    Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
+    """
+    import deepspeed  # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252
+
 deepspeed_plugin = accelerator.state.deepspeed_plugin
 config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
 stage = config_kwargs["zero_optimization"]["stage"]
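
This hunk sets the pattern the rest of the commit follows: the module-level `if is_deepspeed_available(): import deepspeed` guard goes away, `deepspeed` is imported inside the functions that actually use it, and names needed only for annotations (such as `DeepSpeedEngine`) stay under `TYPE_CHECKING` as string annotations. A minimal sketch of the idea, with a hypothetical function name, not taken verbatim from the commit:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Visible to type checkers only; never executed at runtime, which is
        # why the annotation below must stay a string.
        from deepspeed.runtime.engine import DeepSpeedEngine


    def gather_and_do_work(model: "DeepSpeedEngine") -> None:
        # Deferred import: DeepSpeed is loaded only when this code path runs,
        # so importing the enclosing module never initializes DeepSpeed.
        import deepspeed

        with deepspeed.zero.GatheredParameters(model.parameters()):
            ...  # operate on the fully gathered ZeRO-3 parameters

This keeps `import trl` cheap when DeepSpeed is installed but unused and, per the inline comment added to `prepare_deepspeed`, avoids DeepSpeed initialization side effects interfering with other backends such as vLLM.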

trl/trainer/bco_trainer.py (+3 -38)

@@ -19,7 +19,6 @@
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
-from copy import deepcopy
 from operator import itemgetter
 from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
 
@@ -32,7 +31,7 @@
 import transformers
 from accelerate import PartialState
 from accelerate.logging import get_logger
-from accelerate.utils import is_deepspeed_available, tqdm
+from accelerate.utils import tqdm
 from datasets import Dataset
 from packaging import version
 from torch.utils.data import DataLoader, SequentialSampler
@@ -56,7 +55,7 @@
 
 from ..data_utils import maybe_apply_chat_template
 from ..import_utils import is_joblib_available
-from ..models import PreTrainedModelWrapper, create_reference_model
+from ..models import create_reference_model, prepare_deepspeed
 from .bco_config import BCOConfig
 from .utils import (
     DPODataCollatorWithPadding,
@@ -83,9 +82,6 @@
 if is_joblib_available():
     import joblib
 
-if is_deepspeed_available():
-    import deepspeed
-
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer
 
@@ -712,7 +708,7 @@ def make_inputs_require_grad(module, input, output):
                 )
         else:
             if self.is_deepspeed_enabled:
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             else:
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
@@ -846,37 +842,6 @@ def _get_sample_prompt_embeddings(self, dataset: Dataset, sample_size: int = 512
 
         return all_embeddings
 
-    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
     def _save_optimizer_and_scheduler(self, output_dir):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         super()._save_optimizer_and_scheduler(output_dir)
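
The `_prepare_deepspeed` method deleted here is not simply dropped: its body is what the shared `prepare_deepspeed` helper in `trl/models/utils.py` now provides to every trainer, with the `deepspeed` import moved inside the function. The following sketch is assembled from the removed method and the signature shown in the first diff, not a verbatim copy of the new file:

    from copy import deepcopy


    def prepare_deepspeed(model, accelerator):
        # Local import so that importing trl never triggers DeepSpeed initialization.
        import deepspeed

        deepspeed_plugin = accelerator.state.deepspeed_plugin
        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
        stage = config_kwargs["zero_optimization"]["stage"]

        if model is not None and hasattr(model, "config"):
            hidden_size = (
                max(model.config.hidden_sizes)
                if getattr(model.config, "hidden_sizes", None)
                else getattr(model.config, "hidden_size", None)
            )
            if hidden_size is not None and stage == 3:
                # Size the ZeRO-3 buckets to the model width, as in the removed method.
                config_kwargs.update(
                    {
                        "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
                        "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
                        "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
                    }
                )

        # Only ZeRO-3 shards the evaluation copy; otherwise run it with ZeRO disabled (stage 0).
        if stage != 3:
            config_kwargs["zero_optimization"]["stage"] = 0
        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
        model.eval()
        return model

Call sites then shrink to `prepare_deepspeed(self.ref_model, self.accelerator)`, as in the hunk above.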

trl/trainer/callbacks.py (+3 -4)

@@ -19,7 +19,7 @@
 import torch
 from accelerate import Accelerator
 from accelerate.state import AcceleratorState
-from accelerate.utils import gather_object, is_comet_ml_available, is_deepspeed_available, is_wandb_available
+from accelerate.utils import gather_object, is_comet_ml_available, is_wandb_available
 from rich.console import Console, Group
 from rich.live import Live
 from rich.panel import Panel
@@ -44,9 +44,6 @@
 from .utils import log_table_to_comet_experiment
 
 
-if is_deepspeed_available():
-    import deepspeed
-
 if is_comet_ml_available():
     pass
 
@@ -115,6 +112,8 @@ def _sync_target_model(model, target_model, alpha):
     def sync_target_model(model, target_model, alpha):
         deepspeed_plugin = AcceleratorState().deepspeed_plugin
         if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(
                 list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
             ):

trl/trainer/dpo_trainer.py (+3 -38)

@@ -19,7 +19,6 @@
 import warnings
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext
-from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Callable, Literal, Optional, Union
 
@@ -30,7 +29,7 @@
 import torch.nn.functional as F
 import transformers
 from accelerate import PartialState
-from accelerate.utils import is_deepspeed_available, tqdm
+from accelerate.utils import tqdm
 from datasets import Dataset, IterableDataset
 from packaging import version
 from torch.utils.data import DataLoader
@@ -53,7 +52,7 @@
 from transformers.utils import is_peft_available, is_torch_xpu_available
 
 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
-from ..models import PreTrainedModelWrapper, create_reference_model
+from ..models import create_reference_model, prepare_deepspeed
 from ..models.utils import prepare_fsdp
 from .callbacks import SyncRefModelCallback
 from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@@ -80,9 +79,6 @@
 if is_wandb_available():
     import wandb
 
-if is_deepspeed_available():
-    import deepspeed
-
 
 @dataclass
 class DataCollatorForPreference(DataCollatorMixin):
@@ -510,7 +506,7 @@ def make_inputs_require_grad(module, input, output):
                 )
         else:
             if self.is_deepspeed_enabled:
-                self.ref_model = self._prepare_deepspeed(self.ref_model)
+                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
             elif self.is_fsdp_enabled:
                 self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
             else:
@@ -676,37 +672,6 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le
 
         return output
 
-    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
     def _set_signature_columns_if_needed(self):
         # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
         # By default, this method sets `self._signature_columns` to the model's expected inputs.
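
The DPO trainer gets the same consolidation, plus an FSDP branch. After this commit the reference-model preparation reduces to a three-way dispatch; the sketch below paraphrases the hunk above, and the final branch is an assumption that it mirrors the BCO trainer's:

    if self.is_deepspeed_enabled:
        self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
    elif self.is_fsdp_enabled:
        self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
    else:
        self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)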

trl/trainer/gkd_trainer.py (+2 -39)

@@ -15,13 +15,11 @@
 import os
 import random
 import textwrap
-from copy import deepcopy
 from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from accelerate.utils import is_deepspeed_available
 from datasets import Dataset
 from transformers import (
     AutoModelForCausalLM,
@@ -38,7 +36,7 @@
 from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_peft_available
 
-from ..models import PreTrainedModelWrapper
+from ..models import prepare_deepspeed
 from ..models.utils import unwrap_model_for_generation
 from .gkd_config import GKDConfig
 from .sft_trainer import SFTTrainer
@@ -51,10 +49,6 @@
 )
 
 
-if is_deepspeed_available():
-    import deepspeed
-
-
 if is_peft_available():
     from peft import PeftConfig
 
@@ -124,7 +118,7 @@ def __init__(
             disable_dropout_in_model(self.model)
 
         if self.is_deepspeed_enabled:
-            self.teacher_model = self._prepare_deepspeed(teacher_model)
+            self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
         else:
             self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
 
@@ -311,37 +305,6 @@ def training_step(
         loss = super().training_step(model, inputs, num_items_in_batch)
         return loss
 
-    def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
-        # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
-        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
-        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
-
-        if model is not None:
-            if hasattr(model, "config"):
-                hidden_size = (
-                    max(model.config.hidden_sizes)
-                    if getattr(model.config, "hidden_sizes", None)
-                    else getattr(model.config, "hidden_size", None)
-                )
-                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
-                    # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
-                    # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
-                    config_kwargs.update(
-                        {
-                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
-                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
-                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
-                        }
-                    )
-
-        # If ZeRO-3 is used, we shard both the active and reference model.
-        # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
-        if config_kwargs["zero_optimization"]["stage"] != 3:
-            config_kwargs["zero_optimization"]["stage"] = 0
-        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
-        model.eval()
-        return model
-
     def create_model_card(
         self,
         model_name: Optional[str] = None,

trl/trainer/grpo_trainer.py (+7 -5)

@@ -47,7 +47,7 @@
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
-from ..import_utils import is_deepspeed_available, is_liger_kernel_available, is_rich_available, is_vllm_available
+from ..import_utils import is_liger_kernel_available, is_rich_available, is_vllm_available
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
@@ -61,9 +61,6 @@
 )
 
 
-if is_deepspeed_available():
-    import deepspeed
-
 if is_peft_available():
     from peft import PeftConfig, get_peft_model
 
@@ -839,7 +836,12 @@ def _move_model_to_vllm(self):
         # For DeepSpeed ZeRO-3, we need to gather all parameters before operations
         deepspeed_plugin = self.accelerator.state.deepspeed_plugin
         zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
-        gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
+        if zero_stage_3:
+            import deepspeed
+
+            gather_if_zero3 = deepspeed.zero.GatheredParameters
+        else:
+            gather_if_zero3 = nullcontext
 
         if is_peft_model(self.model):
             # With PEFT and DeepSpeed ZeRO Stage 3, we must gather the full model at once before merging, as merging
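
The old one-liner referenced `deepspeed.zero.GatheredParameters` unconditionally at evaluation time, which is why the module-level import could not simply be removed; the replacement only touches DeepSpeed when ZeRO-3 is actually active. A small sketch of the same idea as a reusable helper (helper name and usage are hypothetical, not part of the commit):

    from contextlib import nullcontext


    def gather_ctx_factory(accelerator):
        # Mirror the hunk above: import DeepSpeed only under ZeRO stage 3.
        deepspeed_plugin = accelerator.state.deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            import deepspeed

            return deepspeed.zero.GatheredParameters
        return nullcontext


    # Both return values are callables that accept a parameter list and act as
    # context managers, so call sites need no further branching:
    #
    #     with gather_ctx_factory(accelerator)(list(module.parameters())):
    #         ...  # parameters are fully materialized here under ZeRO-3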
