
Commit ca4d407

[megatron] support export lora to_mcore (#5445)
1 parent 247733b commit ca4d407

File tree

6 files changed: +69 -26 lines changed

docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 10 additions & 0 deletions

@@ -205,6 +205,16 @@ swift export \
 ```
 - Note: The `mcore_adapters` folder contains an `args.json` file. During conversion, the `mcore_model` and LoRA-related parameter information is read from this file, `mcore_model` and `mcore_adapters` are merged (merge-lora) into complete weights, and the result is converted to HF-format weights.
 
+If you only want to merge-lora without converting to HF-format weights (for example, for subsequent DPO training), you can use the following script:
+```shell
+CUDA_VISIBLE_DEVICES=0 \
+swift export \
+    --mcore_adapters megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
+    --to_mcore true \
+    --torch_dtype bfloat16 \
+    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-mcore \
+    --test_convert_precision true
+```
 
 ## Benchmark
 
docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 12 additions & 0 deletions

@@ -213,6 +213,18 @@ swift export \
 
 - Note: The `mcore_adapters` folder contains an `args.json` file. During the conversion process, parameters related to `mcore_model` and LoRA will be loaded from this file. The system will then perform a merge-lora operation between the `mcore_model` and `mcore_adapters` to obtain the complete model weights, and finally convert them into HuggingFace (HF) format.
 
+If you only want to merge the LoRA weights without converting them to Hugging Face format (for example, for subsequent DPO training), you can use the following script:
+
+```shell
+CUDA_VISIBLE_DEVICES=0 \
+swift export \
+    --mcore_adapters megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
+    --to_mcore true \
+    --torch_dtype bfloat16 \
+    --output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-mcore \
+    --test_convert_precision true
+```
+
 ## Benchmark
 The speed comparison of full-parameter training for Dense/MoE models using `megatron sft` and `swift sft` on a single machine with eight A800 GPUs is shown below. The corresponding scripts can be found [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark).
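For readers who drive ms-swift from Python instead of the shell, a minimal sketch of the same merge-lora-only export is shown below. It assumes that `ExportArguments` and `export_main` are importable from `swift.llm` and accept the same fields as the CLI flags above; treat it as an illustration of the workflow rather than documented API.

```python
# Hedged sketch: programmatic counterpart of the shell example above.
# Assumption: swift.llm exposes ExportArguments/export_main and the field
# names below mirror the CLI flags (--mcore_adapters, --to_mcore, ...).
from swift.llm import ExportArguments, export_main

export_main(
    ExportArguments(
        mcore_adapters=['megatron_output/Qwen2.5-7B-Instruct/vx-xxx'],
        to_mcore=True,
        torch_dtype='bfloat16',
        output_dir='megatron_output/Qwen2.5-7B-Instruct/vx-xxx-mcore',
        test_convert_precision=True,
    ))
```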

swift/llm/export/export.py

Lines changed: 3 additions & 3 deletions

@@ -32,12 +32,12 @@ def run(self):
             export_to_ollama(args)
         elif args.to_cached_dataset:
             export_cached_dataset(args)
+        elif args.to_hf or args.mcore_adapters and args.to_mcore:
+            from swift.megatron import convert_mcore2hf
+            convert_mcore2hf(args)
         elif args.to_mcore:
             from swift.megatron import convert_hf2mcore
             convert_hf2mcore(args)
-        elif args.to_hf:
-            from swift.megatron import convert_mcore2hf
-            convert_mcore2hf(args)
         elif args.push_to_hub:
             model_dir = args.adapters and args.adapters[0] or args.model_dir
             assert model_dir, f'model_dir: {model_dir}'
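A note on the new branch: Python's `and` binds tighter than `or`, so the condition groups as `args.to_hf or (args.mcore_adapters and args.to_mcore)`. That grouping is what routes a LoRA-adapter export with `--to_mcore true` through `convert_mcore2hf` (which can now merge and re-save in mcore format) instead of `convert_hf2mcore`. A minimal, self-contained sketch of the grouping; the helper name and values are illustrative only:

```python
def routes_to_convert_mcore2hf(to_hf, mcore_adapters, to_mcore):
    # `and` binds tighter than `or`: to_hf or (mcore_adapters and to_mcore)
    return bool(to_hf or mcore_adapters and to_mcore)

assert routes_to_convert_mcore2hf(True, [], False)                           # mcore -> HF export
assert routes_to_convert_mcore2hf(False, ['megatron_output/adapter'], True)  # merge-lora, stay in mcore
assert not routes_to_convert_mcore2hf(False, [], True)                       # plain HF -> mcore conversion
```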

swift/llm/infer/utils.py

Lines changed: 3 additions & 2 deletions

@@ -143,7 +143,8 @@ def prepare_adapter(args, model, adapters=None):
 
 def prepare_model_template(args, **kwargs):
     model, processor = args.get_model_processor(**kwargs)
-    model = prepare_adapter(args, model)
     template = args.get_template(processor)
-    update_generation_config_eos_token(model.generation_config, template)
+    if model is not None:
+        model = prepare_adapter(args, model)
+        update_generation_config_eos_token(model.generation_config, template)
     return model, template

swift/megatron/init.py

Lines changed: 11 additions & 0 deletions

@@ -790,9 +790,20 @@ def _worker(plan_shard):
     FileSystemReader.read_data = read_data
 
 
+def _patch_TELinear():
+    from megatron.core.extensions.transformer_engine import TELinear
+
+    def __repr__(self):
+        return (f'{type(self).__name__}(in_features={self.in_features}, '
+                f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')
+
+    TELinear.__repr__ = __repr__
+
+
 def _patch_megatron():
     _patch_flash_attn()
     _patch_transformer_engine()
+    _patch_TELinear()
     _patch__batched_p2p_ops()
     _patch_mla_attention()
     _patch_TEGroupedLinear()
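For context, this patch only changes how `TELinear` modules render when the model structure is printed or logged. A standalone sketch with a stand-in class (hypothetical, not the TransformerEngine implementation) shows the string the patched `__repr__` produces; the attribute names mirror those referenced in the patch:

```python
class FakeTELinear:
    """Stand-in for megatron.core.extensions.transformer_engine.TELinear,
    only to illustrate the output of the patched __repr__ above."""

    def __init__(self, in_features, out_features, use_bias, tp_size):
        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.tp_size = tp_size

    def __repr__(self):
        return (f'{type(self).__name__}(in_features={self.in_features}, '
                f'out_features={self.out_features}, bias={self.use_bias}, TP={self.tp_size})')

print(FakeTELinear(4096, 11008, False, 2))
# FakeTELinear(in_features=4096, out_features=11008, bias=False, TP=2)
```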

swift/megatron/utils/convert.py

Lines changed: 30 additions & 21 deletions

@@ -169,7 +169,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
         current_convert_kwargs['moe_grouped_gemm'] = True
     megatron_args = MegatronArguments(
         **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
@@ -183,6 +183,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
     if args.test_convert_precision:
         test_convert_precision(hf_model, mg_model, template)
+    del hf_model
     logger.info('Successfully transferred HF model weights to MG model.')
     args.save_args()
     mg_save_checkpoint(1, [mg_model], None, None, 0)
@@ -191,25 +192,22 @@ def convert_hf2mcore(args: ExportArguments) -> None:
 
 def convert_mcore2hf(args: ExportArguments) -> None:
     from swift.megatron import prepare_mcore_model, adapter_state_dict_context
-    hf_model, template = prepare_model_template(args)
+    hf_model, template = prepare_model_template(args, load_model=args.to_hf)
     processor = template.processor
-    if args.thread_count is None:
-        checkpoint_size = sum(get_n_params_grads(hf_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
-        args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
-    patch_torch_dist_shard(args.thread_count)
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
     kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
-    if hf_model.model_info.is_moe_model:
+    if args.model_info.is_moe_model:
         current_convert_kwargs['moe_grouped_gemm'] = True
     megatron_args = MegatronArguments(
         **kwargs,
         **current_convert_kwargs,
         load=args.mcore_model,
+        save=args.output_dir if args.to_mcore else None,
         adapter_load=args.mcore_adapters[0] if args.mcore_adapters else None,
         torch_dtype=args.torch_dtype)
     patch_megatron_tokenizer(processor)
@@ -228,17 +226,28 @@ def convert_mcore2hf(args: ExportArguments) -> None:
         logger.info('Merge LoRA...')
         mg_model = peft_model.merge_and_unload()
     logger.info('Megatron model created successfully.')
-    megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
-    if args.test_convert_precision:
-        test_convert_precision(hf_model, mg_model, template)
-    logger.info('Successfully transferred MG model weights to HF model.')
-    ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
-    save_checkpoint(
-        hf_model,
-        processor,
-        args.output_dir,
-        safe_serialization=args.safe_serialization,
-        model_dirs=[ckpt_dir, args.model_dir],
-        max_shard_size=args.max_shard_size,
-        additional_saved_files=hf_model.model_meta.additional_saved_files)
-    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    if args.to_hf:
+        megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
+        if args.test_convert_precision:
+            test_convert_precision(hf_model, mg_model, template)
+        del mg_model
+        logger.info('Successfully transferred MG model weights to HF model.')
+        ckpt_dir = megatron_args.load if megatron_args.adapter_load is None else megatron_args.adapter_load
+        save_checkpoint(
+            hf_model,
+            processor,
+            args.output_dir,
+            safe_serialization=args.safe_serialization,
+            model_dirs=[ckpt_dir, args.model_dir],
+            max_shard_size=args.max_shard_size,
+            additional_saved_files=hf_model.model_meta.additional_saved_files)
+        logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
+    elif args.to_mcore:
+        if args.thread_count is None:
+            checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
+            args.thread_count = max(math.ceil(checkpoint_size / 10), 2)  # 10GB
+        patch_torch_dist_shard(args.thread_count)
+
+        args.save_args()
+        mg_save_checkpoint(1, [mg_model], None, None, 0)
+        logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
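The `thread_count` heuristic, now computed from the Megatron model inside the `to_mcore` branch, sizes the patched torch.distributed checkpoint writer at roughly one thread per 10 GB of checkpoint, with a floor of two. A worked sketch of the same arithmetic for a hypothetical 7B-parameter bfloat16 model:

```python
import math

# Hypothetical numbers: 7e9 parameters stored as bfloat16 (16 bits each).
n_params = 7e9
bits_per_param = 16
checkpoint_size_gb = n_params * bits_per_param // 8e9      # bits -> GB, as in convert.py: ~14 GB
thread_count = max(math.ceil(checkpoint_size_gb / 10), 2)  # one thread per ~10 GB, minimum 2
print(checkpoint_size_gb, thread_count)                    # 14.0 2
```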
