Commit acdbe62

Authored by lewtun and SunMarc
Guard DeepSpeed imports (#37755)
* Guard DeepSpeed imports

* Fix import

* Import deepspeed consistently

Co-authored-by: Marc Sun <[email protected]>
1 parent af6d275 commit acdbe62
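
The pattern applied throughout the diff is uniform: the module-level `import deepspeed` (previously executed whenever DeepSpeed happened to be installed) is removed, and every code path that is already gated by `is_deepspeed_zero3_enabled()` imports deepspeed locally instead. Below is a minimal sketch of that guard, assuming only the public helpers `is_deepspeed_zero3_enabled` and `deepspeed_config` from `transformers.integrations`; the `guarded_zero3_init` helper is hypothetical and only illustrates the shape of the real call sites shown in the diff.

from transformers.integrations import deepspeed_config, is_deepspeed_zero3_enabled


def guarded_zero3_init(model_cls, config, **kwargs):
    # Hypothetical helper mirroring the guarded branches in modeling_utils.py.
    # DeepSpeed is imported only once we know a ZeRO-3 run is active, so merely
    # having deepspeed installed no longer pulls it in at module import time.
    if is_deepspeed_zero3_enabled():
        import deepspeed  # lazy, guarded import -- the pattern this commit applies

        # zero.Init partitions freshly created parameters across ranks at construction time.
        with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
            return model_cls(config, **kwargs)
    return model_cls(config, **kwargs)

As far as the diff shows, the practical effect is that importing transformers no longer imports DeepSpeed as a side effect of it being installed; the import now happens only when a ZeRO-3 code path actually runs.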

File tree: 1 file changed (+26, -4 lines)


src/transformers/modeling_utils.py

Lines changed: 26 additions & 4 deletions
@@ -57,7 +57,7 @@
 from .generation import CompileConfig, GenerationConfig
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.accelerate import find_tied_parameters, init_empty_weights
-from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward

@@ -154,9 +154,6 @@
     from safetensors.torch import save_file as safe_save_file


-if is_deepspeed_available():
-    import deepspeed
-
 if is_kernels_available():
     from kernels import get_kernel

@@ -2007,6 +2004,8 @@ def _from_config(cls, config, **kwargs):
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
+            import deepspeed
+
             init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
             with ContextManagers(init_contexts):
                 model = cls(config, **kwargs)

@@ -2702,6 +2701,8 @@ def resize_token_embeddings(
         # Since we are basically reusing the same old embeddings with new weight values, gathering is required
         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
                 vocab_size = model_embeds.weight.shape[0]
         else:

@@ -2732,6 +2733,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
         # Update new_num_tokens with the actual size of new_embeddings
         if pad_to_multiple_of is not None:
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                     new_num_tokens = new_embeddings.weight.shape[0]
             else:

@@ -2820,6 +2823,8 @@ def _get_resized_embeddings(

         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
                 old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
         else:

@@ -2864,6 +2869,8 @@ def _get_resized_embeddings(

             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                     self._init_added_embeddings_weights_with_mean(
                         old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens

@@ -2879,6 +2886,8 @@ def _get_resized_embeddings(
         n = min(old_num_tokens, new_num_tokens)

         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

@@ -2889,6 +2898,8 @@ def _get_resized_embeddings(
         # This ensures correct functionality when a Custom Embedding class is passed as input.
         # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 old_embeddings.weight = new_embeddings.weight

@@ -2941,11 +2952,14 @@ def _get_resized_lm_head(
             `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
             `None`
         """
+
         if new_num_tokens is None:
             return old_lm_head

         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
                 old_num_tokens, old_lm_head_dim = (
                     old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()

@@ -2996,6 +3010,8 @@ def _get_resized_lm_head(

             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 params = [old_lm_head.weight]
                 if has_new_lm_head_bias:
                     params += [old_lm_head.bias]

@@ -3016,6 +3032,8 @@ def _get_resized_lm_head(
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 self._copy_lm_head_original_to_resized(

@@ -3762,6 +3780,8 @@ def float(self, *args):
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
         if is_deepspeed_zero3_enabled():
+            import deepspeed
+
             init_contexts = [no_init_weights()]
             # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:

@@ -5349,6 +5369,8 @@ def _initialize_missing_keys(
         not_initialized_submodules = dict(self.named_modules())
         # This will only initialize submodules that are not marked as initialized by the line above.
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             not_initialized_parameters = list(
                 set(
                     itertools.chain.from_iterable(
