@@ -57,7 +57,7 @@
 from .generation import CompileConfig, GenerationConfig
 from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
 from .integrations.accelerate import find_tied_parameters, init_empty_weights
-from .integrations.deepspeed import _load_state_dict_into_zero3_model, is_deepspeed_available
+from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.flash_attention import flash_attention_forward
 from .integrations.flex_attention import flex_attention_forward
 from .integrations.sdpa_attention import sdpa_attention_forward
@@ -154,9 +154,6 @@
 from safetensors.torch import save_file as safe_save_file


-if is_deepspeed_available():
-    import deepspeed
-
 if is_kernels_available():
     from kernels import get_kernel

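The net effect of these two hunks: the eager, guarded module-level import of `deepspeed` is removed, and each ZeRO-3 code path (see the hunks below) imports it lazily at the call site instead. A minimal sketch of that pattern, where `read_vocab_size`, `weight`, and `zero3_enabled` are illustrative placeholders rather than names from the patch:

```python
def read_vocab_size(weight, zero3_enabled: bool) -> int:
    """Return the first dimension of an embedding weight."""
    if zero3_enabled:
        # Deferred import: deepspeed is only loaded when ZeRO-3 is actually active,
        # so importing the library that contains this code no longer requires it.
        import deepspeed

        with deepspeed.zero.GatheredParameters(weight, modifier_rank=None):
            return weight.shape[0]
    return weight.shape[0]
```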
@@ -2007,6 +2004,8 @@ def _from_config(cls, config, **kwargs):
             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model across all gpus, to avoid the overhead in time
             # and memory copying it on CPU or each GPU first
+            import deepspeed
+
             init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
             with ContextManagers(init_contexts):
                 model = cls(config, **kwargs)
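For context, `deepspeed.zero.Init` is a context manager that shards parameters across ranks as they are created, which is what the comment above means by partitioning the model immediately. A hedged, stand-alone sketch; the `ds_config` dict is an illustrative stand-in for whatever `deepspeed_config()` returns:

```python
import deepspeed
import torch

# Assumption: a minimal ZeRO stage-3 configuration; real configs carry much more.
ds_config = {"zero_optimization": {"stage": 3}}

# Parameters built inside this context are partitioned across ranks right away,
# avoiding a full copy of the model on CPU or on every GPU first.
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = torch.nn.Linear(1024, 1024)
```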
@@ -2702,6 +2701,8 @@ def resize_token_embeddings(
         # Since we are basically reusing the same old embeddings with new weight values, gathering is required
         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
                 vocab_size = model_embeds.weight.shape[0]
         else:
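The gather is needed because, under ZeRO-3, each rank only holds a shard of `model_embeds.weight`, so its local shape does not reflect the full vocabulary size. A sketch of the read-only gather used here; `embedding` is an illustrative module, not taken from the patch:

```python
import deepspeed
import torch

embedding = torch.nn.Embedding(32000, 768)

# modifier_rank=None means no rank is allowed to modify the gathered data;
# the full tensor is only visible inside the block, for reading.
with deepspeed.zero.GatheredParameters(embedding.weight, modifier_rank=None):
    vocab_size = embedding.weight.shape[0]
```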
@@ -2732,6 +2733,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
         # Update new_num_tokens with the actual size of new_embeddings
         if pad_to_multiple_of is not None:
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
                     new_num_tokens = new_embeddings.weight.shape[0]
             else:
@@ -2820,6 +2823,8 @@ def _get_resized_embeddings(

         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
                 old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
         else:
@@ -2864,6 +2869,8 @@ def _get_resized_embeddings(

             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
                     self._init_added_embeddings_weights_with_mean(
                         old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
@@ -2879,6 +2886,8 @@ def _get_resized_embeddings(
         n = min(old_num_tokens, new_num_tokens)

         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
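Here `modifier_rank=0` is the writing variant: all ranks gather the full tensors, rank 0 performs the copy, and the updated values are re-partitioned when the context exits. A hedged sketch with illustrative `old_emb` / `new_emb` modules:

```python
import deepspeed
import torch

old_emb = torch.nn.Embedding(32000, 768)
new_emb = torch.nn.Embedding(32128, 768)

params = [old_emb.weight, new_emb.weight]
with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
    n = min(old_emb.weight.shape[0], new_emb.weight.shape[0])
    # Rank 0 copies the existing rows into the resized table; its changes are
    # written back to the partitioned parameters when the block exits.
    new_emb.weight.data[:n, :] = old_emb.weight.data[:n, :]
```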
@@ -2889,6 +2898,8 @@ def _get_resized_embeddings(
         # This ensures correct functionality when a Custom Embedding class is passed as input.
         # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_embeddings.weight, new_embeddings.weight]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 old_embeddings.weight = new_embeddings.weight
@@ -2941,11 +2952,14 @@ def _get_resized_lm_head(
             `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
             `None`
         """
+
         if new_num_tokens is None:
             return old_lm_head

         is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
                 old_num_tokens, old_lm_head_dim = (
                     old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
@@ -2996,6 +3010,8 @@ def _get_resized_lm_head(

             added_num_tokens = new_num_tokens - old_num_tokens
             if is_deepspeed_zero3_enabled() and not is_quantized:
+                import deepspeed
+
                 params = [old_lm_head.weight]
                 if has_new_lm_head_bias:
                     params += [old_lm_head.bias]
@@ -3016,6 +3032,8 @@ def _get_resized_lm_head(
         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
             with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                 self._copy_lm_head_original_to_resized(
@@ -3762,6 +3780,8 @@ def float(self, *args):
     @classmethod
     def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
         if is_deepspeed_zero3_enabled():
+            import deepspeed
+
             init_contexts = [no_init_weights()]
             # We cannot initialize the model on meta device with deepspeed when not quantized
             if not is_quantized and not _is_ds_init_called:
@@ -5349,6 +5369,8 @@ def _initialize_missing_keys(
         not_initialized_submodules = dict(self.named_modules())
         # This will only initialize submodules that are not marked as initialized by the line above.
         if is_deepspeed_zero3_enabled() and not is_quantized:
+            import deepspeed
+
             not_initialized_parameters = list(
                 set(
                     itertools.chain.from_iterable(