Better guards for DeepSpeed imports #3351
@@ -167,6 +163,8 @@ def iter_params(module, recurse=False):


def add_hooks(model: "DeepSpeedEngine") -> None:
    """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
    import deepspeed
I realise doing this import every time we generate is sub-optimal, but DS generation is slow anyway and I don't think it's used much vs vllm
Actually it's not sub-optimal. Python's import is effectively conditional: when you import, Python first checks whether the module is already loaded (via `sys.modules`) and reuses it if so. So repeated imports won't reload the module or cause significant overhead.
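This caching behaviour is easy to verify directly. A small illustrative snippet (not part of the PR; it uses the stdlib `json` module as a stand-in for `deepspeed`):

```python
import sys
import time

# First import: Python executes the module and stores it in sys.modules.
start = time.perf_counter()
import json
first = time.perf_counter() - start

# Second import of the same name: no re-execution, just a sys.modules lookup.
start = time.perf_counter()
import json  # noqa: F811
second = time.perf_counter() - start

# Both names resolve to the exact same cached module object.
assert sys.modules["json"] is json
print(f"first import: {first:.6f}s, repeat import: {second:.6f}s")
```

So calling a function that does `import deepspeed` on every generation only pays the module-initialisation cost once per process.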
I'm not sure why the doc builder CI is failing during installation. Have you seen this before @qgallouedec?
It seems like we've been having it in all our PRs since yesterday.
cc @kashif if you can take a look at the GKD issue when you're back
Not caused by this PR since the error is also on
@qgallouedec I've tested all scripts related to the changes in this PR and apart from the ones that were already broken on
Perfect. My only concern is that future me, who will forget about this, might not check the git blame and might move it back to top-level in the course of an innocent PR 😇. Can you add a little comment like this?
def my_func():
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm)
What does this PR do?
Whenever we run `import deepspeed`, CUDA is initialised, which is (a) slow and (b) interferes with the vllm server when DP=8 and TP=1 (see vllm-project/vllm#17079). This PR fixes the issue on the TRL side by:
- importing `deepspeed` only when needed
- swapping `_prepare_deepspeed()` in the trainers for the single-use function in the utils
Tested with
Fully solving the vllm server issue will require an additional refactor on the transformers side, but this PR is kept self-contained. Update: transformers PR here huggingface/transformers#37755 (now merged).
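The "import only when needed" change is the standard lazy-import guard pattern. A minimal sketch of the idea (the `add_hooks` name and docstring come from the diff above; `is_deepspeed_available` and the function body are illustrative, not TRL's actual code):

```python
import importlib.util


def is_deepspeed_available() -> bool:
    # Hypothetical helper: probe for the package WITHOUT importing it,
    # so checking availability never triggers DeepSpeed's CUDA init.
    return importlib.util.find_spec("deepspeed") is not None


def add_hooks(model) -> None:
    """Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
    if not is_deepspeed_available():
        raise ImportError("add_hooks requires deepspeed to be installed")
    # local import (instead of top-level) to avoid DS init interfering
    # with other backends (like vllm)
    import deepspeed  # noqa: F401
    # ... attach the ZeRO-3 gather hooks to the model here ...
```

With this shape, merely importing the module that defines `add_hooks` stays cheap and CUDA-free; the DeepSpeed side effects only happen when the function is actually called.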
TODO
Check the following trainers with Z3:
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.