43 changes: 32 additions & 11 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
@@ -2,6 +2,7 @@
import os
from functools import partial
from typing import Tuple, Optional, List
import concurrent.futures

import torch
import numpy as np
@@ -13,17 +14,24 @@
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler

from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from ding.utils import EasyTimer
import torch.nn.functional as F

import sys
import os
import torch.distributed as dist
# Import MOE statistics functions from utils
from lzero.entry.utils import (
collect_and_log_moe_statistics,
TemperatureScheduler,
log_buffer_memory_usage
)

# ------------------------------------------------------------
# 1. Add an extra process group dedicated to the learner
#    (call this once during main / learner initialization)
# ------------------------------------------------------------
def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup:
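The function body is collapsed in this view. As a minimal sketch (not the PR's actual implementation), a learner-only group can be built with torch.distributed.new_group, assuming learner_ranks lists the global ranks that run the learner; the backend choice below is an assumption:

import torch.distributed as dist

def build_learner_group_sketch(learner_ranks: list[int]) -> dist.ProcessGroup:
    # new_group must be called collectively by every process in the default
    # group, even by ranks that will not be members of the new group.
    return dist.new_group(ranks=learner_ranks, backend='nccl')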
@@ -367,7 +375,9 @@ def train_unizero_multitask_segment_ddp(
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
benchmark_name: str = "atari"
benchmark_name: str = "atari",
finetune_components=[],
cal_moe_profile: bool = True  # New: switch controlling MoE performance/statistics profiling
) -> 'Policy':
"""
Overview:
@@ -520,20 +530,25 @@ def train_unizero_multitask_segment_ddp(

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create the shared policy
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}')
tb_logger = SummaryWriter(log_dir)
cfg.policy.logger=tb_logger

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # MOE
policy.logger=tb_logger


# Load the pretrained model (if provided)
if model_path is not None:
logging.info(f'Start loading model: {model_path}')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device), finetune_components=finetune_components)
logging.info(f'Finished loading model: {model_path}')

# Create the TensorBoard logger
log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}')
tb_logger = SummaryWriter(log_dir)

# Create the shared learner
# Create the shared learner  # TODO: cfg.policy.learn.learner.hook.log_show_after_iter
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

policy_config = cfg.policy
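The finetune_components argument passed to load_state_dict above is handled inside the policy's learn_mode and is not shown in this diff. Purely as an illustration of one possible interpretation (load all weights, then keep only the listed top-level sub-modules trainable), with a hypothetical helper name:

from typing import Iterable
import torch.nn as nn

def freeze_all_but(model: nn.Module, finetune_components: Iterable[str]) -> None:
    # Freeze every parameter whose top-level sub-module is not in the fine-tune list.
    keep = set(finetune_components)
    for name, param in model.named_parameters():
        param.requires_grad = name.split('.')[0] in keep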
@@ -645,6 +660,7 @@ def train_unizero_multitask_segment_ddp(
# if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug
# if evaluator.should_eval(learner.train_iter):
print('=' * 20)

print(f'Rank {rank} evaluating task_id: {cfg.policy.task_id}...')

# =========TODO=========
@@ -720,7 +736,7 @@ def train_unizero_multitask_segment_ddp(
print(f"not_enough_data:{not_enough_data}")
# Get the current temperature
current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter)
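TemperatureScheduler is imported from lzero.entry.utils and its internals are not part of this diff. A minimal sketch of a scheduler exposing the same get_temperature(train_iter) interface, assuming a simple linear anneal (illustrative only; the real class may differ):

class LinearTemperatureScheduler:
    """Linearly anneal from initial_temp to final_temp over threshold_steps iterations, then hold."""

    def __init__(self, initial_temp: float, final_temp: float, threshold_steps: int):
        self.initial_temp = initial_temp
        self.final_temp = final_temp
        self.threshold_steps = threshold_steps

    def get_temperature(self, train_iter: int) -> float:
        progress = min(train_iter / self.threshold_steps, 1.0)
        return self.initial_temp + progress * (self.final_temp - self.initial_temp)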

# if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0 :
if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0 :

@@ -811,7 +827,12 @@ def train_unizero_multitask_segment_ddp(
# During training, DDP automatically synchronizes gradients and parameters
log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs)

# logging.error(f'Rank {rank}: one learn step done')
# +++++++++++++++++++++++++++++++++ MOE expert selection statistics logging +++++++++++++++++++++++++++++++++
if cal_moe_profile and cfg.policy.model.world_model_cfg.multiplication_moe_in_transformer and cfg.policy.model.world_model_cfg.num_experts_of_moe_in_transformer:
# Control MoE statistics logging frequency
moe_log_interval = getattr(cfg.policy, 'moe_log_interval', 1)  # Default: log every iteration (override via cfg.policy.moe_log_interval)
if learner.train_iter % moe_log_interval == 0:
collect_and_log_moe_statistics(policy, tb_logger, learner.train_iter, world_size, rank)
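collect_and_log_moe_statistics also lives in lzero.entry.utils and its body is not shown here. A rough sketch of what gathering and logging expert-selection frequencies across DDP ranks could look like; the expert_counts tensor and the scalar tags are assumptions, not the utility's actual interface:

import torch
import torch.distributed as dist

def log_moe_statistics_sketch(expert_counts: torch.Tensor, tb_logger, train_iter: int, rank: int) -> None:
    # expert_counts: 1-D tensor with one selection count per expert on this rank.
    counts = expert_counts.detach().float().clone()
    if dist.is_available() and dist.is_initialized():
        # Sum the per-rank counts so the logged statistics are global.
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    freqs = counts / counts.sum().clamp(min=1.0)
    if rank == 0:
        for expert_id, freq in enumerate(freqs.tolist()):
            tb_logger.add_scalar(f'moe/expert_{expert_id}_selection_freq', freq, train_iter)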

# Decide whether task_exploitation_weight needs to be computed
if i == 0: