Guard DeepSpeed imports #37755

Merged: 7 commits, Apr 24, 2025
30 changes: 26 additions & 4 deletions src/transformers/modeling_utils.py
@@ -57,7 +57,7 @@
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.flash_attention import flash_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.sdpa_attention import sdpa_attention_forward
@@ -154,9 +154,6 @@
from safetensors.torch import save_file as safe_save_file


if is_deepspeed_available():
import deepspeed

if is_kernels_available():
from kernels import get_kernel

@@ -2007,6 +2004,8 @@ def _from_config(cls, config, **kwargs):
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
import deepspeed

init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)
@@ -2702,6 +2701,8 @@ def resize_token_embeddings(
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
Member Author: Here I am assuming that we do not need an additional is_deepspeed_available() check, but let me know if we do!

Member: Yeah, we don't need one! If is_deepspeed_zero3_enabled returns True, deepspeed will be installed.
import deepspeed

with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
vocab_size = model_embeds.weight.shape[0]
else:
@@ -2732,6 +2733,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
# Update new_num_tokens with the actual size of new_embeddings
if pad_to_multiple_of is not None:
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
new_num_tokens = new_embeddings.weight.shape[0]
else:
@@ -2820,6 +2823,8 @@ def _get_resized_embeddings(

is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
else:
@@ -2864,6 +2869,8 @@ def _get_resized_embeddings(

added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
self._init_added_embeddings_weights_with_mean(
old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
@@ -2879,6 +2886,8 @@
n = min(old_num_tokens, new_num_tokens)

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
@@ -2889,6 +2898,8 @@
# This ensures correct functionality when a Custom Embedding class is passed as input.
# The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_embeddings.weight, new_embeddings.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
old_embeddings.weight = new_embeddings.weight
@@ -2941,11 +2952,14 @@ def _get_resized_lm_head(
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
`None`
"""

if new_num_tokens is None:
return old_lm_head

is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
old_num_tokens, old_lm_head_dim = (
old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
@@ -2996,6 +3010,8 @@ def _get_resized_lm_head(

added_num_tokens = new_num_tokens - old_num_tokens
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_lm_head.weight]
if has_new_lm_head_bias:
params += [old_lm_head.bias]
@@ -3016,6 +3032,8 @@
num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
self._copy_lm_head_original_to_resized(
@@ -3762,6 +3780,8 @@ def float(self, *args):
@classmethod
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
if is_deepspeed_zero3_enabled():
import deepspeed

init_contexts = [no_init_weights()]
# We cannot initialize the model on meta device with deepspeed when not quantized
if not is_quantized and not _is_ds_init_called:
@@ -5349,6 +5369,8 @@ def _initialize_missing_keys(
not_initialized_submodules = dict(self.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

not_initialized_parameters = list(
set(
itertools.chain.from_iterable(