|
19 | 19 | import warnings
|
20 | 20 | from collections import defaultdict
|
21 | 21 | from contextlib import contextmanager, nullcontext
|
22 |
| -from copy import deepcopy |
23 | 22 | from dataclasses import dataclass
|
24 | 23 | from typing import Any, Callable, Literal, Optional, Union
|
25 | 24 |
|
|
30 | 29 | import torch.nn.functional as F
|
31 | 30 | import transformers
|
32 | 31 | from accelerate import PartialState
|
33 |
| -from accelerate.utils import is_deepspeed_available, tqdm |
| 32 | +from accelerate.utils import tqdm |
34 | 33 | from datasets import Dataset, IterableDataset
|
35 | 34 | from packaging import version
|
36 | 35 | from torch.utils.data import DataLoader
|
|
53 | 52 | from transformers.utils import is_peft_available, is_torch_xpu_available
|
54 | 53 |
|
55 | 54 | from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
56 |
| -from ..models import PreTrainedModelWrapper, create_reference_model |
| 55 | +from ..models import create_reference_model, prepare_deepspeed |
57 | 56 | from ..models.utils import prepare_fsdp
|
58 | 57 | from .callbacks import SyncRefModelCallback
|
59 | 58 | from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
|
|
80 | 79 | if is_wandb_available():
|
81 | 80 | import wandb
|
82 | 81 |
|
83 |
| -if is_deepspeed_available(): |
84 |
| - import deepspeed |
85 |
| - |
86 | 82 |
|
87 | 83 | @dataclass
|
88 | 84 | class DataCollatorForPreference(DataCollatorMixin):
|
@@ -510,7 +506,7 @@ def make_inputs_require_grad(module, input, output):
|
510 | 506 | )
|
511 | 507 | else:
|
512 | 508 | if self.is_deepspeed_enabled:
|
513 |
| - self.ref_model = self._prepare_deepspeed(self.ref_model) |
| 509 | + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) |
514 | 510 | elif self.is_fsdp_enabled:
|
515 | 511 | self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
516 | 512 | else:
|
@@ -676,37 +672,6 @@ def process_row(features, processing_class, max_prompt_length, max_completion_le
|
676 | 672 |
|
677 | 673 | return output
|
678 | 674 |
|
679 |
| - def _prepare_deepspeed(self, model: PreTrainedModelWrapper): |
680 |
| - # Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 |
681 |
| - deepspeed_plugin = self.accelerator.state.deepspeed_plugin |
682 |
| - config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) |
683 |
| - |
684 |
| - if model is not None: |
685 |
| - if hasattr(model, "config"): |
686 |
| - hidden_size = ( |
687 |
| - max(model.config.hidden_sizes) |
688 |
| - if getattr(model.config, "hidden_sizes", None) |
689 |
| - else getattr(model.config, "hidden_size", None) |
690 |
| - ) |
691 |
| - if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: |
692 |
| - # Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0` |
693 |
| - # This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081 |
694 |
| - config_kwargs.update( |
695 |
| - { |
696 |
| - "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, |
697 |
| - "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, |
698 |
| - "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, |
699 |
| - } |
700 |
| - ) |
701 |
| - |
702 |
| - # If ZeRO-3 is used, we shard both the active and reference model. |
703 |
| - # Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0) |
704 |
| - if config_kwargs["zero_optimization"]["stage"] != 3: |
705 |
| - config_kwargs["zero_optimization"]["stage"] = 0 |
706 |
| - model, *_ = deepspeed.initialize(model=model, config=config_kwargs) |
707 |
| - model.eval() |
708 |
| - return model |
709 |
| - |
710 | 675 | def _set_signature_columns_if_needed(self):
|
711 | 676 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
712 | 677 | # By default, this method sets `self._signature_columns` to the model's expected inputs.
|
|
0 commit comments