2,207 changes: 2,207 additions & 0 deletions diff_output.txt

Large diffs are not rendered by default.

917 changes: 902 additions & 15 deletions lzero/entry/train_unizero_multitask_segment_ddp.py

Large diffs are not rendered by default.

1,003 changes: 1,003 additions & 0 deletions lzero/entry/utils.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -46,7 +46,7 @@ def default_config(cls: type) -> EasyDict:
cfg.cfg_type = cls.__name__ + 'Dict'
return cfg

def __init__(self, cfg: EasyDict = None) -> None:
def __init__(self, cfg: EasyDict = None, eval: bool = False) -> None:
"""
Overview:
Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
@@ -56,9 +56,13 @@ def __init__(self, cfg: EasyDict = None) -> None:
default_config = self.default_config()
default_config.update(cfg)
self._cfg = default_config
# During evaluation, override the search budget with the eval-specific value.
if eval:
self._cfg.num_simulations = self._cfg.eval_num_simulations

self.inverse_scalar_transform_handle = InverseScalarTransform(
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)


@classmethod
def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree":
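With this change, an evaluator can run MCTS with a smaller search budget than the collector. A minimal sketch of the intended wiring, assuming the config carries both num_simulations and eval_num_simulations (the only names visible in this diff; the rest of the MCTSCtree config, model, and device plumbing is elided):

from easydict import EasyDict

# Hypothetical budgets for illustration; eval_num_simulations must be present
# whenever eval=True, because the override reads it unconditionally.
cfg = EasyDict(
    num_simulations=50,        # search budget during data collection
    eval_num_simulations=20,   # cheaper budget during evaluation
)

# collector_mcts = MCTSCtree(cfg)             # keeps num_simulations = 50
# evaluator_mcts = MCTSCtree(cfg, eval=True)  # num_simulations becomes 20

Note that a config omitting eval_num_simulations will fail with an AttributeError when eval=True, so a guarded fallback to num_simulations may be worth adding.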
62 changes: 62 additions & 0 deletions lzero/model/common.py
@@ -248,6 +248,68 @@ def remove_hooks(self):
self.forward_handler.remove()
self.backward_handler.remove()

# # modified by tangjia
# class ModelGradientHook:
Collaborator: Delete the unused commented-out code.

# def __init__(self):
# """
# Overview:
# Class to capture gradients at model output.
# """
# self.output_grads = []

# def setup_hook(self, model):
# # Hook to capture gradients at model output
# self.backward_handler = model.register_full_backward_hook(self.backward_hook)

# def backward_hook(self, module, grad_input, grad_output):
# with torch.no_grad():
# # Save the output gradients
# if grad_output[0] is not None:
# self.output_grads.append(grad_output[0].clone())

# def analyze(self):
# if not self.output_grads:
# return None

# # Calculate norms of output gradients
# grad_norms = [torch.norm(g, p=2, dim=1).mean() for g in self.output_grads]
# avg_grad_norm = torch.mean(torch.stack(grad_norms))
# max_grad_norm = torch.max(torch.stack(grad_norms))
# min_grad_norm = torch.min(torch.stack(grad_norms))

# # Clear stored data and delete tensors to free memory
# self.clear_data()

# # Optionally clear CUDA cache
# if torch.cuda.is_available():
# torch.cuda.empty_cache()

# return avg_grad_norm, max_grad_norm, min_grad_norm

# def clear_data(self):
# del self.output_grads[:]

# def remove_hooks(self):
# self.backward_handler.remove()

# Usage example:
# monitor = ModelGradientHook()
# monitor.setup_hook(model)
#
# # During training...
# loss.backward()
#
# # Retrieve gradient statistics
# avg_grad_norm, max_grad_norm, min_grad_norm = monitor.analyze()
#
# # Clear stored data
# monitor.clear_data()
#
# # Remove the hook after training
# monitor.remove_hooks()

class DownSample(nn.Module):

44 changes: 39 additions & 5 deletions lzero/model/unizero_model_multitask.py
@@ -6,7 +6,7 @@
from easydict import EasyDict

from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook  # , ModelGradientHook
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model_multitask import WorldModelMT

@@ -124,7 +124,7 @@ def __init__(
))
elif world_model_cfg.encoder_type == "vit":
for task_id in range(1): # TODO: one share encoder
if world_model_cfg.task_num <= 8:
if world_model_cfg.task_num == 1:
Collaborator: Why add a separate task_num == 1 branch?

# # vit base
# self.representation_network.append(ViT(
# image_size =observation_shape[1],
@@ -144,12 +144,41 @@ def __init__(
patch_size = 8,
num_classes = obs_act_embed_dim,
dim = 768,
depth = 6,
heads = 6,
mlp_dim = 2048,
depth = 12,
heads = 12,
mlp_dim = 3072,
dropout = 0.1,
emb_dropout = 0.1,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,

))
elif world_model_cfg.task_num <= 8:
# # vit base
# self.representation_network.append(ViT(
# image_size =observation_shape[1],
# patch_size = 8,
# num_classes = obs_act_embed_dim,
# dim = 768,
# depth = 12,
# heads = 12,
# mlp_dim = 3072,
# dropout = 0.1,
# emb_dropout = 0.1,
# final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,
# ))
# vit small (note: the values below are actually ViT-Base-sized: dim 768, depth 12, heads 12, mlp_dim 3072)
self.representation_network.append(ViT(
image_size =observation_shape[1],
patch_size = 8,
num_classes = obs_act_embed_dim,
dim = 768,
depth = 12,
heads = 12,
mlp_dim = 3072,
dropout = 0.1,
emb_dropout = 0.1,
final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder,

))
elif world_model_cfg.task_num > 8:
# vit base
@@ -189,6 +218,11 @@ def __init__(
self.encoder_hook = FeatureAndGradientHook()
self.encoder_hook.setup_hooks(self.representation_network)

# if True: # FIXME: for debug
# # Add a hook on the encoder to monitor the gradients propagated back to it.
# self.encoder_output_hook = ModelGradientHook()
Collaborator: Either delete this, or keep it and add standard English comments.

# self.encoder_output_hook.setup_hook(self.representation_network)

self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type)
self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
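For reference, the active branch above labeled "vit small" instantiates dim=768, depth=12, heads=12, mlp_dim=3072, which are the standard ViT-Base dimensions, so the task_num == 1 and task_num <= 8 branches are currently identical. The conventional presets, written with the same keyword names as the ViT(...) calls in this diff, are:

# Standard ViT sizes (ViT-Small per DeiT; ViT-Base/-Large per the original ViT paper).
VIT_PRESETS = {
    'vit_small': dict(dim=384,  depth=12, heads=6,  mlp_dim=1536),
    'vit_base':  dict(dim=768,  depth=12, heads=12, mlp_dim=3072),
    'vit_large': dict(dim=1024, depth=24, heads=16, mlp_dim=4096),
}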
1 change: 0 additions & 1 deletion lzero/model/unizero_world_models/__init__.py
@@ -1 +0,0 @@
from .transformer import Transformer, TransformerConfig
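Because this deletion removes the package-level re-export, downstream code can no longer import these names from lzero.model.unizero_world_models and must target the submodule directly:

# Import path after the re-export is removed:
from lzero.model.unizero_world_models.transformer import Transformer, TransformerConfig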
143 changes: 117 additions & 26 deletions lzero/model/unizero_world_models/moe.py
@@ -5,7 +5,7 @@
import torch.nn.functional as F
from simple_parsing.helpers import Serializable
from torch import nn

import torch.distributed as dist
from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear

# _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward")
@@ -59,7 +59,8 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert
self.num_experts_per_tok = num_experts_per_tok
self.gate = gate
self.experts = nn.ModuleList(experts)

self.config = config

# If the config specifies a number of shared experts, build the shared-expert branch
if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0:
self.shared_expert = nn.Sequential(
@@ -69,34 +70,54 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert
)
else:
self.shared_expert = None

# GPU-resident collector for expert-selection statistics: multi-granularity sliding windows
self.device = next(iter(experts)).w1.weight.device if experts else torch.device('cuda')

# Sliding-window configuration
self.window_sizes = {
'immediate': 100, # immediate statistics (last 100 steps)
'short': 1000, # short-term statistics (last 1,000 steps)
'medium': 10000, # medium-term statistics (last 10,000 steps)
'long': 100000 # long-term statistics (last 100,000 steps)
}

# GPU statistics buffers: task ID -> {window type -> [expert-selection history]}
self.expert_stats_gpu = {}
self.step_count = 0

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, task_id: int = None) -> torch.Tensor:
# Save the original shape, then reshape x into a 2-D tensor: [batch_size * seq_len, dim]
original_shape = x.size()
x = x.view(-1, self.dim)

# Compute gating logits of shape [N, num_experts], where N is the number of tokens
gate_logits = self.gate(x)
# Select the top-k scoring experts for each token
weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1)
# Softmax over the selected logits to obtain normalized weights
weights = F.softmax(weights, dim=1).to(x.dtype)

# Initialize a tensor to hold the experts' outputs
expert_output = torch.zeros_like(x)

# Iterate over all experts and process the tokens routed to each one
for expert_id in range(self.num_experts):
# Use torch.where to find the token indices assigned to the current expert_id
batch_idx, expert_tok_idx = torch.where(indices == expert_id)
if batch_idx.numel() == 0:
continue
token_subset = x[batch_idx] # selected tokens, shape [num_tokens, dim]
# Run the current expert on the token subset
output_expert = self.experts[expert_id](token_subset)
# Gather the weights for these tokens; weights has shape [N, num_experts_per_tok]
token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1)
expert_output[batch_idx] += output_expert * token_weights
# Default to identity: with no routed experts, x passes through unchanged
expert_output = x
if self.num_experts != 0:
Collaborator: Why add the self.num_experts != 0 check? Isn't self.num_experts always greater than 0?

# Compute gating logits of shape [N, num_experts], where N is the number of tokens
gate_logits = self.gate(x)
# Select the top-k scoring experts for each token
weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1)
# Softmax over the selected logits to obtain normalized weights
weights = F.softmax(weights, dim=1).to(x.dtype)

# Collect expert-selection statistics (only in training mode and when a task_id is given)
if self.training and task_id is not None:
self._collect_expert_selection_stats(task_id, indices)

# Initialize a tensor to hold the experts' outputs
expert_output = torch.zeros_like(x)

# Iterate over all experts and process the tokens routed to each one
for expert_id in range(self.num_experts):
# Use torch.where to find the token indices assigned to the current expert_id
batch_idx, expert_tok_idx = torch.where(indices == expert_id)
if batch_idx.numel() == 0:
continue
token_subset = x[batch_idx] # selected tokens, shape [num_tokens, dim]
# Run the current expert on the token subset
output_expert = self.experts[expert_id](token_subset)
# Gather the weights for these tokens; weights has shape [N, num_experts_per_tok]
token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1)
expert_output[batch_idx] += output_expert * token_weights

# If the shared-expert branch is enabled, add its output
if self.shared_expert is not None:
@@ -107,6 +128,76 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

# Restore the original shape and return the result
return output.view(original_shape)
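# A tiny worked example of the topk-then-softmax routing above (illustrative
# values, not from the PR):
#   gate_logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])     # one token, 4 experts
#   weights, indices = torch.topk(gate_logits, k=2, dim=1)  # indices [[0, 2]], weights [[2.0, 1.0]]
#   weights = F.softmax(weights, dim=1)                     # [[0.7311, 0.2689]]
# The token's output is then 0.7311 * expert_0(x) + 0.2689 * expert_2(x),
# plus the shared-expert output if that branch is enabled.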

def _collect_expert_selection_stats(self, task_id: int, indices: torch.Tensor):
"""Collect expert-selection statistics in GPU memory using multi-granularity sliding windows."""
self.step_count += 1

if task_id not in self.expert_stats_gpu:
self.expert_stats_gpu[task_id] = {}
for window_type in self.window_sizes.keys():
self.expert_stats_gpu[task_id][window_type] = torch.zeros(
self.window_sizes[window_type],
self.num_experts,
dtype=torch.float32,
device=self.device
)

# Count how many times each expert was selected in the current batch
indices_flat = indices.flatten() # [N*k]
expert_counts = torch.zeros(self.num_experts, device=self.device, dtype=torch.float32)
for expert_id in range(self.num_experts):
expert_counts[expert_id] = (indices_flat == expert_id).sum().float()

# Update the sliding windows at every granularity
for window_type, window_size in self.window_sizes.items():
buffer = self.expert_stats_gpu[task_id][window_type]
# Sliding window: shift old data forward and append the new counts at the end
buffer[:-1] = buffer[1:].clone()
buffer[-1] = expert_counts

def get_expert_selection_stats(self, task_id: int = None):
"""Get multi-granularity expert-selection frequency statistics. Simplified version: returns the current data directly."""
if task_id is None:
# Return statistics for all tasks
all_stats = {}
for tid in self.expert_stats_gpu.keys():
all_stats[tid] = self._compute_task_stats(tid)
return all_stats
else:
# Return statistics for the specified task
return self._compute_task_stats(task_id)

def _compute_task_stats(self, task_id: int):
"""Compute multi-granularity statistics for the given task."""
if task_id not in self.expert_stats_gpu:
return {}

stats = {}
for window_type, buffer in self.expert_stats_gpu[task_id].items():
# Simplified version: average over all stored data, whether or not the window is full
# buffer shape: [window_size, num_experts]
total_counts = buffer.sum(dim=0) # [num_experts]
total_selections = total_counts.sum()

if total_selections > 0:
frequencies = total_counts / total_selections
else:
frequencies = torch.zeros(self.num_experts, device=self.device)

stats[window_type] = {
'frequencies': frequencies, # kept as a tensor
'total_counts': total_counts, # kept as a tensor
'total_selections': total_selections.item(),
'data_points': min(self.step_count, self.window_sizes[window_type])
}

return stats

def reset_expert_selection_stats(self):
"""Reset the expert-selection statistics."""
self.expert_stats_gpu.clear()
self.step_count = 0
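A minimal sketch of driving this layer end to end, assuming the elided parts of __init__ set self.dim and self.num_experts, and that each expert is a Mixtral-style feed-forward block exposing w1/w2/w3 (the device probe in __init__ reads the first expert's w1.weight.device). Everything below except MoELayer itself is hypothetical:

import torch
import torch.nn.functional as F
from torch import nn

class FeedForward(nn.Module):
    """Hypothetical expert; only .w1 is required by MoELayer's device probe."""
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

# dim, num_experts, k = 64, 4, 2
# experts = [FeedForward(dim, 4 * dim) for _ in range(num_experts)]
# gate = nn.Linear(dim, num_experts, bias=False)
# moe = MoELayer(config, experts, gate, num_experts_per_tok=k)
# y = moe(torch.randn(2, 8, dim), task_id=0)        # stats recorded in training mode
# stats = moe.get_expert_selection_stats(task_id=0)
# print(stats['immediate']['frequencies'])          # per-expert routing frequencies

Two implementation notes: the shift buffer[:-1] = buffer[1:].clone() copies the whole window on every step, O(window_size × num_experts) work for the 100,000-step window, whereas a circular write index (buffer[self.step_count % window_size] = expert_counts) would make each update O(num_experts); and the per-expert counting loop could be replaced by torch.bincount(indices_flat, minlength=self.num_experts).float().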

class MoELayerOptimized(nn.Module):
r"""
Expand Down