Better guards for DeepSpeed imports #3351

Merged
merged 7 commits into main from guard-deepspeed-imports on Apr 26, 2025

Conversation

@lewtun lewtun (Member) commented Apr 24, 2025

What does this PR do?

Whenever we run import deepspeed, CUDA is initialised, which is (a) slow and (b) interferes with the vllm server when DP=8 and TP=1 (see vllm-project/vllm#17079):

ERROR:    Traceback (most recent call last):
  File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/starlette/routing.py", line 692, in lifespan
    async with self.lifespan_context(app) as maybe_state:
  File "/admin/home/lewis/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/git/hf/trl/trl/scripts/vllm_serve.py", line 362, in lifespan
    msg = connection.recv()
          ^^^^^^^^^^^^^^^^^
  File "/admin/home/lewis/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/connection.py", line 250, in recv
    buf = self._recv_bytes()
          ^^^^^^^^^^^^^^^^^^
  File "/admin/home/lewis/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/admin/home/lewis/.local/share/uv/python/cpython-3.11.11-linux-x86_64-gnu/lib/python3.11/multiprocessing/connection.py", line 399, in _recv
    raise EOFError
EOFError

ERROR:    Application startup failed. Exiting.
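
For reference, the side effect described above can be observed with the small sketch below (it assumes torch and deepspeed are installed; the exact behaviour can vary with the deepspeed version):

import torch

# In a fresh process, CUDA should not be initialised yet.
print(torch.cuda.is_initialized())  # expected: False

import deepspeed  # noqa: F401  # top-level import triggers DeepSpeed's accelerator setup

# Depending on the deepspeed version, this may now report True.
print(torch.cuda.is_initialized())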

This PR fixes the issue on the TRL side by:

  • Importing deepspeed only when needed, i.e. locally inside the functions that use it (a minimal sketch of this pattern follows the list).
  • Replacing the many copies of _prepare_deepspeed() in the trainers with a single shared function in the utils.
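
The sketch below shows the shape of such a helper with a local import; the names and the deepspeed config handling are illustrative, not the exact TRL implementation:

import importlib.util


def prepare_deepspeed(model, accelerator):
    """Hypothetical shared helper illustrating the lazy-import pattern."""
    if importlib.util.find_spec("deepspeed") is None:
        raise ImportError("DeepSpeed is required for this code path: `pip install deepspeed`.")

    # Local import (instead of top-level) so that merely importing the trainer
    # module does not initialise CUDA or interfere with other backends (like vllm).
    import deepspeed

    config = accelerator.state.deepspeed_plugin.deepspeed_config
    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
    engine, *_ = deepspeed.initialize(model=model, config=config)
    engine.eval()
    return engine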

Tested with

# install transformers@main
pip install git+https://github.com/huggingface/transformers.git

# launch the server
trl vllm-serve --model Qwen/Qwen2.5-0.5B --revision main --tensor_parallel_size 1 --data_parallel_size 8

Fully solving the vllm server issue will require an additional refactor on the transformers side, but this PR is scoped to keep things self-contained. That transformers PR is now merged (see the update below).

Update: transformers PR here huggingface/transformers#37755

TODO

Check the following trainers with Z3:

  • kto
  • dpo
  • bco
  • gkd
  • grpo
  • orpo

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@lewtun lewtun requested a review from qgallouedec April 24, 2025 12:05
@@ -167,6 +163,8 @@ def iter_params(module, recurse=False):

def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
import deepspeed
lewtun (Member Author) commented:

I realise doing this import every time we generate is sub-optimal, but DS generation is slow anyway and I don't think it's used much compared to vllm.

A maintainer (Member) replied:

Actually it's not sub-optimal. Python's import is effectively conditional—when you import, Python checks if the module is already loaded (via sys.modules) and reuses it if so. So it won't reload the module or cause significant overhead.
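
For illustration, the caching can be observed with any module; in the sketch below json stands in for deepspeed (just an illustration, not part of the PR):

import sys
import timeit


def hook_like_function():
    # After the first call, this is just a sys.modules lookup plus a local
    # name binding, not a full reload of the module.
    import json  # stand-in for `import deepspeed`
    return json


hook_like_function()
print("json" in sys.modules)  # True: the module stays cached
print(timeit.timeit(hook_like_function, number=100_000))  # completes in a fraction of a second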

@lewtun lewtun (Member Author) commented Apr 24, 2025

I'm not sure why the doc builder CI is failing during installation with ERROR: Could not install packages due to an OSError: [Errno 28] No space left on device

Have you seen this before @qgallouedec?

@qgallouedec qgallouedec (Member) commented

> Have you seen this before @qgallouedec?

It seems like we've been having it in all our PRs since yesterday.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@lewtun lewtun (Member Author) commented Apr 25, 2025

The GKDTrainer is broken with Z3, but not because of this PR (same result on trl@main):

[rank5]: Traceback (most recent call last):
[rank5]:   File "/fsx/lewis/git/hf/trl/examples/scripts/gkd.py", line 137, in <module>
[rank5]:     trainer.train()
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/trainer.py", line 2229, in train
[rank5]:     return inner_training_loop(
[rank5]:            ^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/trainer.py", line 2553, in _inner_training_loop
[rank5]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank5]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl/trainer/gkd_trainer.py", line 304, in training_step
[rank5]:     new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
[rank5]:                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl/trainer/gkd_trainer.py", line 264, in generate_on_policy_outputs
[rank5]:     generated_outputs = model.generate(
[rank5]:                         ^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank5]:     return func(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/generation/utils.py", line 2490, in generate
[rank5]:     result = self._sample(
[rank5]:              ^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/generation/utils.py", line 3453, in _sample
[rank5]:     outputs = model_forward(**model_inputs, return_dict=True)
[rank5]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank5]:     return inner()
[rank5]:            ^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1793, in inner
[rank5]:     result = forward_call(*args, **kwargs)
[rank5]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/utils/generic.py", line 969, in wrapper
[rank5]:     output = func(self, *args, **kwargs)
[rank5]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 823, in forward
[rank5]:     outputs: BaseModelOutputWithPast = self.model(
[rank5]:                                        ^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/utils/generic.py", line 969, in wrapper
[rank5]:     output = func(self, *args, **kwargs)
[rank5]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 531, in forward
[rank5]:     causal_mask = self._update_causal_mask(
[rank5]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 640, in _update_causal_mask
[rank5]:     causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
[rank5]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 716, in _prepare_4d_causal_attention_mask_with_cache_position
[rank5]:     causal_mask *= diagonal_attend_mask
[rank5]: RuntimeError: The size of tensor a (441) must match the size of tensor b (442) at non-singleton dimension 0

Command to repro:

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/gkd.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --output_dir gkd-model \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --gradient_checkpointing \
    --bf16

cc @kashif if you can take a look at the GKD issue when you're back

@lewtun lewtun (Member Author) commented Apr 25, 2025

The BCOTrainer fails with:

[rank3]: Traceback (most recent call last):
[rank3]:   File "/fsx/lewis/git/hf/trl/examples/scripts/bco.py", line 111, in <module>
[rank3]:     accelerator = Accelerator()
[rank3]:                   ^^^^^^^^^^^^^
[rank3]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/accelerator.py", line 322, in __init__
[rank3]:     deepspeed_plugins = AcceleratorState().deepspeed_plugins
[rank3]:                         ^^^^^^^^^^^^^^^^^^
[rank3]:   File "/fsx/lewis/git/hf/trl/trl-env/lib/python3.11/site-packages/accelerate/state.py", line 935, in __init__
[rank3]:     raise ValueError(
[rank3]: ValueError: Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` before using any functionality from the `accelerate` library.

Command to repro:

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/bco.py \
    --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --trust_remote_code \
    --dataset_name trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 32 \
    --num_train_epochs 1 \
    --learning_rate 1e-6 \
    --gradient_checkpointing \
    --gradient_accumulation_steps 1 \
    --logging_steps 0.01 \
    --eval_steps 0.2 \
    --save_strategy no \
    --output_dir=bco-aligned-model \
    --logging_first_step \
    --max_length 2048 \
    --max_prompt_length 1536 \
    --max_completion_length 1024 \
    --no_remove_unused_columns \
    --warmup_ratio 0.1 \
    --bf16 \
    --report_to wandb

This is not caused by this PR, since the error also occurs on trl@main. cc @kashif if you can also take a look 🙏

@lewtun lewtun (Member Author) commented Apr 25, 2025

@qgallouedec I've tested all scripts related to the changes in this PR and apart from the ones that were already broken on main, the rest all look good. Would you mind taking another look to see if you're happy to merge?

@qgallouedec qgallouedec (Member) left a comment

Perfect. My only concern is that future me, who will forget about this, might not check the git blame and might move it back to top-level in the course of an innocent PR 😇. Can you add a little comment like this?

def my_func():
    import deepspeed  # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm)

@lewtun lewtun merged commit 1bca495 into main Apr 26, 2025
10 checks passed
@lewtun lewtun deleted the guard-deepspeed-imports branch April 26, 2025 08:18