diff --git a/lzero/entry/train_unizero_multitask_segment_ddp.py b/lzero/entry/train_unizero_multitask_segment_ddp.py index 3fdcfa099..0095b8ce4 100644 --- a/lzero/entry/train_unizero_multitask_segment_ddp.py +++ b/lzero/entry/train_unizero_multitask_segment_ddp.py @@ -2,6 +2,7 @@ import os from functools import partial from typing import Tuple, Optional, List +import concurrent.futures import torch import numpy as np @@ -13,17 +14,24 @@ from ding.worker import BaseLearner from tensorboardX import SummaryWriter -from lzero.entry.utils import log_buffer_memory_usage, TemperatureScheduler + from lzero.policy import visit_count_temperature from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector from ding.utils import EasyTimer import torch.nn.functional as F - +import sys +import os import torch.distributed as dist +# Import MOE statistics functions from utils +from lzero.entry.utils import ( + collect_and_log_moe_statistics, + TemperatureScheduler, + log_buffer_memory_usage +) # ------------------------------------------------------------ -# 1. 额外增加 learner 专用 process-group +# 1. 额外增加 learner 专用 process-group # (在 main / learner 初始化时调用一次) # ------------------------------------------------------------ def build_learner_group(learner_ranks: list[int]) -> dist.ProcessGroup: @@ -367,7 +375,9 @@ def train_unizero_multitask_segment_ddp( model_path: Optional[str] = None, max_train_iter: Optional[int] = int(1e10), max_env_step: Optional[int] = int(1e10), - benchmark_name: str = "atari" + benchmark_name: str = "atari", + finetune_components=[], + cal_moe_profile: bool = False # 新增:控制MOE性能监控的开关 ) -> 'Policy': """ Overview: @@ -520,20 +530,23 @@ def train_unizero_multitask_segment_ddp( # 编译配置 cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) - # 创建共享的policy - policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') + tb_logger = SummaryWriter(log_dir) + cfg.policy.logger=tb_logger + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) # MOE # 加载预训练模型(如果提供) if model_path is not None: logging.info(f'开始加载模型: {model_path}') - policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device),finetune_components=finetune_components) logging.info(f'完成加载模型: {model_path}') # 创建TensorBoard日志记录器 log_dir = os.path.join('./{}/log'.format(cfg.exp_name), f'serial_rank_{rank}') tb_logger = SummaryWriter(log_dir) - # 创建共享的learner + # 创建共享的learner #todo: cfg.policy.learn.learner.hook.log_show_after_iter learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) policy_config = cfg.policy @@ -645,6 +658,7 @@ def train_unizero_multitask_segment_ddp( # if learner.train_iter > 10 and evaluator.should_eval(learner.train_iter): # only for debug # if evaluator.should_eval(learner.train_iter): print('=' * 20) + print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...') # =========TODO========= @@ -720,7 +734,7 @@ def train_unizero_multitask_segment_ddp( print(f"not_enough_data:{not_enough_data}") # 获取当前温度 current_temperature_task_weight = temperature_scheduler.get_temperature(learner.train_iter) - + # if learner.train_iter == 0 or learner.train_iter % cfg.policy.eval_freq == 0 : if learner.train_iter > 10 and learner.train_iter % cfg.policy.eval_freq == 0 : @@ -811,7 +825,12 @@ def train_unizero_multitask_segment_ddp( # 在训练时,DDP会自动同步梯度和参数 log_vars = learner.train(train_data_multi_task, envstep_multi_task, policy_kwargs=learn_kwargs) - # logging.error(f'Rank {rank}: one learn step done') + # +++++++++++++++++++++++++++++++++ MOE expert selection statistics logging +++++++++++++++++++++++++++++++++ + if cal_moe_profile and cfg.policy.model.world_model_cfg.multiplication_moe_in_transformer and cfg.policy.model.world_model_cfg.num_experts_of_moe_in_transformer: + # Control MoE statistics logging frequency + moe_log_interval = getattr(cfg.policy, 'moe_log_interval', 500) # Default: log once every 500 iterations + if learner.train_iter % moe_log_interval == 0: + collect_and_log_moe_statistics(policy, tb_logger, learner.train_iter, world_size, rank) # 判断是否需要计算task_exploitation_weight if i == 0: diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index b51eb7f11..60c0d7631 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -1,5 +1,8 @@ import os -from typing import Optional, Callable, Union, List, Tuple +import time +from typing import Optional, Callable, Union, List, Tuple, Dict +from io import BytesIO +import concurrent.futures import psutil import torch @@ -7,12 +10,11 @@ from pympler.asizeof import asizeof from tensorboardX import SummaryWriter - -import torch import numpy as np -import torch import torch.nn.functional as F import matplotlib.pyplot as plt +import seaborn as sns +from PIL import Image # ============================================================ # freeze_non_lora.py @@ -362,3 +364,962 @@ def log_buffer_run_time(train_iter: int, buffer: "GameBuffer", writer: SummaryWr # Reset the time records in the buffer. buffer.reset_runtime_metrics() + + +# ============================================================ +# MOE Expert Selection Statistics Functions +# ============================================================ + +# Global heatmap figure cache to avoid repeated creation +_GLOBAL_HEATMAP_FIG = None +_GLOBAL_HEATMAP_AX = None + + +def merge_expert_stats_across_ranks(all_expert_stats): + """ + Overview: + Merge expert selection statistics data from all distributed training ranks. + Combines statistics from different GPU processes for comprehensive analysis. + Arguments: + - all_expert_stats (:obj:`list`): List of expert statistics from all ranks. + Returns: + - merged_stats (:obj:`dict`): Merged statistics dictionary with structure + {task_id: {window_type: stats}}. + Examples: + >>> stats_list = [rank0_stats, rank1_stats, rank2_stats] + >>> merged = merge_expert_stats_across_ranks(stats_list) + >>> print(f"Merged {len(merged)} tasks") + """ + merged_stats = {} # {task_id: {window_type: stats}} + + for rank_expert_stats in all_expert_stats: + if rank_expert_stats: + for task_id, task_stats in rank_expert_stats.items(): + if task_id not in merged_stats: + merged_stats[task_id] = {} + + for window_type, stats in task_stats.items(): + # Only process statistics with actual data (tasks handled by current GPU) + if stats and stats.get('total_selections', 0) > 0: + merged_stats[task_id][window_type] = { + 'frequencies': np.array(stats['frequencies']), + 'total_selections': stats['total_selections'], + 'data_points': stats['data_points'] + } + return merged_stats + + +def _get_or_create_heatmap_figure(figsize): + """ + Overview: + Get or create a reusable heatmap figure for memory efficiency. + Maintains global figure cache to reduce memory allocation overhead. + Arguments: + - figsize (:obj:`tuple`): Figure size as (width, height). + Returns: + - fig (:obj:`matplotlib.figure.Figure`): Matplotlib figure object. + - ax (:obj:`matplotlib.axes.Axes`): Matplotlib axes object. + Examples: + >>> fig, ax = _get_or_create_heatmap_figure((10, 8)) + >>> ax.plot([1, 2, 3], [4, 5, 6]) + """ + global _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX + if _GLOBAL_HEATMAP_FIG is None: + _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX = plt.subplots(figsize=figsize) + else: + # Clear previous content + _GLOBAL_HEATMAP_AX.clear() + # Adjust image size + _GLOBAL_HEATMAP_FIG.set_size_inches(figsize) + return _GLOBAL_HEATMAP_FIG, _GLOBAL_HEATMAP_AX + + +def create_heatmap_with_values_fast(matrix, task_ids, title="Task-Expert Selection Frequencies"): + """ + Overview: + Efficiently create annotated blue-themed heatmap with performance optimizations. + Optimizations include matplotlib figure reuse, selective value annotations, + optimized image conversion pipeline, and reduced DPI for faster computation. + Arguments: + - matrix (:obj:`numpy.ndarray`): Input matrix for heatmap visualization. + - task_ids (:obj:`list`): List of task identifiers for y-axis labels. + - title (:obj:`str`, optional): Heatmap title. Default is "Task-Expert Selection Frequencies". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard logging. + Shapes: + - matrix: :math:`(N_{tasks}, N_{experts})` where N_tasks and N_experts are dimensions. + - img_array: :math:`(3, H, W)` where H and W are image height and width. + Examples: + >>> import numpy as np + >>> matrix = np.random.rand(5, 8) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_heatmap_with_values_fast(matrix, task_ids) + >>> print(f"Heatmap shape: {heatmap.shape}") # (3, height, width) + """ + try: + figsize = (max(6, matrix.shape[1]), max(4, matrix.shape[0])) + fig, ax = _get_or_create_heatmap_figure(figsize) + + # Intelligently choose whether to display value annotations + show_annot = matrix.size <= 64 # Only display values for 8x8 or smaller matrices + + # Use matplotlib directly to avoid seaborn overhead + im = ax.imshow(matrix, cmap='Blues', aspect='auto') + + # Selectively add value annotations + if show_annot: + for i in range(matrix.shape[0]): + for j in range(matrix.shape[1]): + value = matrix[i, j] + color = 'white' if value > 0.5 else 'black' + ax.text(j, i, f'{value:.3f}', ha='center', va='center', + color=color, fontsize=8) + + # Set labels and title + ax.set_xticks(range(matrix.shape[1])) + ax.set_yticks(range(matrix.shape[0])) + ax.set_xticklabels([f'E{i}' for i in range(matrix.shape[1])], fontsize=10) + ax.set_yticklabels([f'T{tid}' for tid in task_ids], fontsize=10) + ax.set_title(title, fontsize=12, pad=15) + ax.set_xlabel('Experts', fontsize=10) + ax.set_ylabel('Tasks', fontsize=10) + + # Simplified colorbar + if not hasattr(fig, '_colorbar_created'): + plt.colorbar(im, ax=ax, label='Frequency') + fig._colorbar_created = True + + # Optimized image conversion: using lower DPI and simplified pipeline + fig.canvas.draw() + try: + # Get RGB data directly from canvas + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + img_array = buf[:, :, :3] # Remove alpha channel + else: + buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img_array = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + + # Convert to CHW format + img_array = img_array.transpose(2, 0, 1) + + except Exception: + # Fallback: create simple blue gradient matrix + h, w = matrix.shape + img_array = np.zeros((3, h*20, w*20), dtype=np.uint8) + # Simple matrix upscaling and mapping to blue channel + matrix_resized = np.repeat(np.repeat(matrix, 20, axis=0), 20, axis=1) + img_array[2] = (matrix_resized * 255).astype(np.uint8) + + return img_array + + except Exception as e: + print(f"Warning: Heatmap generation failed: {e}, using fallback") + # Ultimate fallback: return blank image + return np.zeros((3, 100, 100), dtype=np.uint8) + + +def create_heatmap_with_values(matrix, task_ids, title="Task-Expert Selection Frequencies"): + """ + Overview: + Create annotated blue-themed heatmap using seaborn - original version for fallback. + This function serves as a backup when the optimized version encounters issues. + Arguments: + - matrix (:obj:`numpy.ndarray`): Input matrix for heatmap visualization. + - task_ids (:obj:`list`): List of task identifiers for y-axis labels. + - title (:obj:`str`, optional): Heatmap title. Default is "Task-Expert Selection Frequencies". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard logging. + Shapes: + - matrix: :math:`(N_{tasks}, N_{experts})` where N_tasks and N_experts are dimensions. + - img_array: :math:`(3, H, W)` where H and W are image height and width. + Examples: + >>> import numpy as np + >>> matrix = np.random.rand(5, 8) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_heatmap_with_values(matrix, task_ids) + >>> print(f"Heatmap shape: {heatmap.shape}") # (3, height, width) + """ + fig, ax = plt.subplots(figsize=(max(8, matrix.shape[1]), max(6, matrix.shape[0]))) + + # Use blue color scheme + sns.heatmap(matrix, + annot=True, # Display values + fmt='.3f', # Value format + cmap='Blues', # Blue theme + ax=ax, + cbar_kws={'label': 'Selection Frequency'}, + xticklabels=[f'Expert{i}' for i in range(matrix.shape[1])], + yticklabels=[f'Task{tid}' for tid in task_ids]) + + ax.set_title(title, fontsize=14, pad=20) + ax.set_xlabel('Experts', fontsize=12) + ax.set_ylabel('Tasks', fontsize=12) + + plt.tight_layout() + + # Save to BytesIO + buf = BytesIO() + plt.savefig(buf, format='png', dpi=100, bbox_inches='tight') + buf.seek(0) + + # Convert to numpy array for tensorboard + img = Image.open(buf) + img_array = np.array(img) + buf.close() + plt.close(fig) + + # Convert to CHW format (Channel, Height, Width) + if len(img_array.shape) == 3: + img_array = img_array.transpose(2, 0, 1) + + return img_array + + +def log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter): + """ + Overview: + Log detailed expert selection statistics for each task. + Records frequency entropy, variance, and total selections for analysis. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - valid_task_ids (:obj:`list`): List of valid task identifiers. + - matrix (:obj:`numpy.ndarray`): Expert selection frequency matrix. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_expert_selection_details(tb_logger, stats, [0,1,2], matrix, 'immediate', 1000) + """ + for i, task_id in enumerate(valid_task_ids): + frequencies = matrix[i] + stats = merged_stats[task_id][window_type] + + # Calculate and record task expert selection entropy (uniformity metric) + task_frequencies = np.array(frequencies) + task_frequencies = task_frequencies + 1e-8 # Avoid log(0) + task_entropy = -np.sum(task_frequencies * np.log(task_frequencies)) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/ExpertSelectionEntropy', + task_entropy, global_step=train_iter + ) + + # Record task expert selection variance (dispersion) + expert_variance = np.var(task_frequencies) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/ExpertSelectionVariance', + expert_variance, global_step=train_iter + ) + + # Record task-level summary statistics + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/TotalSelections', + stats['total_selections'], global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Details/Task{task_id}_{window_type}/DataPoints', + stats['data_points'], global_step=train_iter + ) + + +def log_global_moe_statistics(tb_logger, matrix, window_type, valid_task_ids, train_iter): + """ + Overview: + Log global MOE statistics including expert usage uniformity and extremes. + Provides system-wide view of expert utilization patterns. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - matrix (:obj:`numpy.ndarray`): Expert selection frequency matrix. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - valid_task_ids (:obj:`list`): List of valid task identifiers. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_global_moe_statistics(tb_logger, matrix, 'immediate', [0,1,2], 1000) + """ + # Record basic information + tb_logger.add_scalar( + f'MOE_Global/{window_type}/NumActiveTasks', + len(valid_task_ids), global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/NumExperts', + matrix.shape[1], global_step=train_iter + ) + + # Calculate expert usage uniformity + expert_avg_usage = np.mean(matrix, axis=0) # Average usage frequency per expert + usage_entropy = -np.sum(expert_avg_usage * np.log(expert_avg_usage + 1e-8)) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/ExpertUsageEntropy', + usage_entropy, global_step=train_iter + ) + + # Record most and least used experts + most_used_expert = np.argmax(expert_avg_usage) + least_used_expert = np.argmin(expert_avg_usage) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/MostUsedExpert', + most_used_expert, global_step=train_iter + ) + tb_logger.add_scalar( + f'MOE_Global/{window_type}/LeastUsedExpert', + least_used_expert, global_step=train_iter + ) + + +def process_and_log_moe_heatmaps_fast(tb_logger, merged_stats, window_type, train_iter): + """ + Overview: + Efficiently process and log MOE heatmaps with performance optimizations. + Includes vectorized data processing, conditional heatmap generation, + and batch statistical processing. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> process_and_log_moe_heatmaps_fast(tb_logger, stats, 'immediate', 1000) + """ + # Quick filtering of valid tasks + valid_task_data = [(tid, stats[window_type]['frequencies']) + for tid, stats in merged_stats.items() + if window_type in stats] + + if not valid_task_data: + return + + # Vectorized matrix construction + valid_task_ids, frequencies_list = zip(*valid_task_data) + matrix = np.array(frequencies_list) + + # Conditional heatmap generation: only for small matrices + if matrix.size <= 200: # Only generate heatmap when tasks*experts <= 200 + try: + heatmap_img = create_heatmap_with_values_fast( + matrix, valid_task_ids, + f'MOE {window_type} Task-Expert Selection' + ) + + # Log heatmap to tensorboard + tb_logger.add_image( + f'MOE_Heatmap/{window_type}_TaskExpert_Heatmap', + heatmap_img, + global_step=train_iter, + dataformats='CHW' + ) + except Exception as e: + print(f"Warning: Heatmap generation failed: {e}") + + # Always log statistical data (lightweight operation) + log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter) + log_global_moe_statistics(tb_logger, matrix, window_type, valid_task_ids, train_iter) + + +def process_and_log_moe_heatmaps(tb_logger, merged_stats, window_type, train_iter): + """ + Overview: + Process and log MOE heatmaps - original version for fallback. + This function serves as a backup when the optimized version encounters issues. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged expert selection statistics across ranks. + - window_type (:obj:`str`): Time window type (immediate, short, medium, long). + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> process_and_log_moe_heatmaps(tb_logger, stats, 'immediate', 1000) + """ + all_task_ids = sorted(merged_stats.keys()) + task_expert_matrix = [] + valid_task_ids = [] + + # Collect frequency data from valid tasks + for task_id in all_task_ids: + if window_type in merged_stats[task_id]: + frequencies = merged_stats[task_id][window_type]['frequencies'] + task_expert_matrix.append(frequencies) + valid_task_ids.append(task_id) + + if not task_expert_matrix: + return + + # Convert to numpy matrix (num_tasks, num_experts) + matrix = np.array(task_expert_matrix) + + # Create annotated blue-themed heatmap + heatmap_img = create_heatmap_with_values( + matrix, valid_task_ids, + f'MOE {window_type} Task-Expert Selection Frequencies' + ) + + # Log heatmap to tensorboard + tb_logger.add_image( + f'MOE_Heatmap/{window_type}_TaskExpert_Heatmap', + heatmap_img, + global_step=train_iter, + dataformats='CHW' + ) + + # Log detailed and global statistics + log_expert_selection_details(tb_logger, merged_stats, valid_task_ids, matrix, window_type, train_iter) + + +def convert_stats_to_serializable(moe_stats): + """ + Overview: + Convert tensor data in MOE statistics to serializable numpy format. + Ensures compatibility with distributed communication protocols. + Arguments: + - moe_stats (:obj:`dict`): MOE statistics containing tensor data. + Returns: + - converted (:obj:`dict`): Converted statistics with numpy arrays. + Examples: + >>> tensor_stats = {'task_0': {'immediate': {'frequencies': torch.tensor([0.1, 0.9])}}} + >>> numpy_stats = convert_stats_to_serializable(tensor_stats) + >>> type(numpy_stats['task_0']['immediate']['frequencies']) # + """ + if not moe_stats: + return {} + + converted = {} + for task_id, task_stats in moe_stats.items(): + converted[task_id] = {} + for window_type, stats in task_stats.items(): + if stats and 'frequencies' in stats: + converted[task_id][window_type] = { + 'frequencies': stats['frequencies'].cpu().numpy().tolist(), + 'total_selections': stats['total_selections'], + 'data_points': stats['data_points'] + } + return converted + + +def gather_distributed_moe_stats(local_stats, world_size): + """ + Overview: + Gather MOE statistics from all GPUs in distributed training environment. + Handles communication failures gracefully with fallback to local statistics. + Arguments: + - local_stats (:obj:`dict`): Local GPU's MOE statistics. + - world_size (:obj:`int`): Total number of distributed training processes. + Returns: + - all_stats (:obj:`list`): List of statistics from all ranks. + Examples: + >>> local_data = {'task_0': {'immediate': {'frequencies': [0.1, 0.9]}}} + >>> all_data = gather_distributed_moe_stats(local_data, 4) + >>> len(all_data) # 4 (from 4 GPUs) + """ + all_stats = [None for _ in range(world_size)] + try: + dist.all_gather_object(all_stats, local_stats) + return all_stats + except Exception as e: + print(f"Distributed MOE statistics gathering failed: {e}") + return [local_stats] # fallback to local statistics + + +def collect_and_log_moe_statistics(policy, tb_logger, train_iter, world_size, rank): + """ + Overview: + Collect and log MOE expert selection statistics including heatmaps and distribution analysis. + Comprehensive function that handles distributed data collection, merging, and visualization. + Arguments: + - policy (:obj:`Policy`): Training policy object containing world model. + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - train_iter (:obj:`int`): Current training iteration number. + - world_size (:obj:`int`): Total number of GPUs in distributed training. + - rank (:obj:`int`): Current GPU rank identifier. + Examples: + >>> collect_and_log_moe_statistics(policy, tb_logger, 1000, 8, 0) + """ + try: + # Step 1: Get MOE statistics from policy's transformer model + moe_stats = None + + transformer = policy._model.world_model.transformer + if hasattr(transformer, 'get_expert_selection_stats'): + moe_stats = transformer.get_expert_selection_stats() + + if moe_stats is None: + print(f"Rank {rank}: Warning: Unable to get MOE statistics, train_iter={train_iter}") + return + + # Step 2: Convert tensor data to serializable format + serializable_stats = convert_stats_to_serializable(moe_stats) + + print(f"Rank {rank}: Local MOE statistics - tasks: {len(serializable_stats)}, train_iter={train_iter}") + + # Step 3: Gather statistics from all GPUs in distributed setting + all_expert_stats = gather_distributed_moe_stats(serializable_stats, world_size) + + # Step 4: Merge statistics data + merged_stats = merge_expert_stats_across_ranks(all_expert_stats) + + if not merged_stats: + print(f"Rank {rank}: Warning: Merged MOE statistics empty, train_iter={train_iter}") + return + + # Step 5: All GPUs log MOE statistics + print(f"Rank {rank}: Starting MOE statistics logging - merged tasks: {len(merged_stats)}, train_iter={train_iter}") + + # Generate heatmaps and statistics for each time window + for window_type in ['immediate', 'short', 'medium', 'long']: + if any(window_type in task_stats for task_stats in merged_stats.values()): + process_and_log_moe_heatmaps_fast(tb_logger, merged_stats, window_type, train_iter) + + # Log overall MOE usage + tb_logger.add_scalar('MOE_Global/ActiveTasks', len(merged_stats), global_step=train_iter) + + # Step 6: Add distribution difference computation and logging + if any('immediate' in task_stats for task_stats in merged_stats.values()): + print(f"Rank {rank}: Starting inter-task distribution difference calculation...") + collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, train_iter) + + print(f"Rank {rank}: MOE statistics logging completed, train_iter={train_iter}") + + except Exception as e: + print(f"Rank {rank}: MOE statistics collection failed - {e}, train_iter={train_iter}") + import traceback + traceback.print_exc() + + +# ====== GPU-Optimized Distribution Divergence Calculation and Visualization Functions ====== +def jensen_shannon_divergence_batch_gpu(distributions_tensor): + """ + Overview: + GPU batch computation of JS divergence matrix - fully vectorized, no loops. + Efficiently computes Jensen-Shannon divergence between all pairs of distributions. + Arguments: + - distributions_tensor (:obj:`torch.Tensor`): Shape (n_tasks, n_experts), GPU tensor. + Returns: + - js_matrix (:obj:`torch.Tensor`): Shape (n_tasks, n_tasks), symmetric matrix. + Shapes: + - distributions_tensor: :math:`(N_{tasks}, N_{experts})` + - js_matrix: :math:`(N_{tasks}, N_{tasks})` + Examples: + >>> dist_tensor = torch.rand(5, 8).cuda() + >>> js_matrix = jensen_shannon_divergence_batch_gpu(dist_tensor) + >>> print(js_matrix.shape) # torch.Size([5, 5]) + """ + device = distributions_tensor.device + n_tasks, n_experts = distributions_tensor.shape + + # 1. Normalize to probability distributions + eps = 1e-8 + distributions_tensor = distributions_tensor / (distributions_tensor.sum(dim=1, keepdim=True) + eps) + + # 2. Use broadcasting to compute average distributions for all task pairs + # P_i: (n_tasks, 1, n_experts), P_j: (1, n_tasks, n_experts) + P_i = distributions_tensor.unsqueeze(1) + P_j = distributions_tensor.unsqueeze(0) + M = 0.5 * (P_i + P_j) # shape: (n_tasks, n_tasks, n_experts) + + # 3. Batch compute KL divergences - fully vectorized + # KL(P_i || M) for all pairs + log_ratio_i = torch.log((P_i + eps) / (M + eps)) + kl_i_m = torch.sum(P_i * log_ratio_i, dim=2) # (n_tasks, n_tasks) + + # KL(P_j || M) for all pairs + log_ratio_j = torch.log((P_j + eps) / (M + eps)) + kl_j_m = torch.sum(P_j * log_ratio_j, dim=2) # (n_tasks, n_tasks) + + # 4. JS divergence matrix + js_matrix = 0.5 * (kl_i_m + kl_j_m) + + return js_matrix + + +def wasserstein_distance_batch_gpu(distributions_tensor): + """ + Overview: + GPU batch computation of Wasserstein distance matrix - efficient 1D distribution implementation. + Computes Earth Mover's Distance between all pairs of discrete distributions. + Arguments: + - distributions_tensor (:obj:`torch.Tensor`): Shape (n_tasks, n_experts), GPU tensor. + Returns: + - wasserstein_matrix (:obj:`torch.Tensor`): Shape (n_tasks, n_tasks), symmetric matrix. + Shapes: + - distributions_tensor: :math:`(N_{tasks}, N_{experts})` + - wasserstein_matrix: :math:`(N_{tasks}, N_{tasks})` + Examples: + >>> dist_tensor = torch.rand(5, 8).cuda() + >>> wass_matrix = wasserstein_distance_batch_gpu(dist_tensor) + >>> print(wass_matrix.shape) # torch.Size([5, 5]) + """ + device = distributions_tensor.device + n_tasks, n_experts = distributions_tensor.shape + eps = 1e-8 + + # 1. Normalize to probability distributions + distributions_tensor = distributions_tensor / (distributions_tensor.sum(dim=1, keepdim=True) + eps) + + # 2. Compute cumulative distribution functions (CDF) + cdf_tensor = torch.cumsum(distributions_tensor, dim=1) # (n_tasks, n_experts) + + # 3. Use broadcasting to compute L1 distances between all CDF pairs + cdf_i = cdf_tensor.unsqueeze(1) # (n_tasks, 1, n_experts) + cdf_j = cdf_tensor.unsqueeze(0) # (1, n_tasks, n_experts) + + # Wasserstein distance = L1 norm of cumulative distribution differences + wasserstein_matrix = torch.sum(torch.abs(cdf_i - cdf_j), dim=2) + + return wasserstein_matrix + + +def compute_distribution_divergences_optimized(merged_stats, window_type='immediate'): + """ + Overview: + GPU-optimized version for efficient distribution divergence computation. + Leverages GPU acceleration for batch processing of divergence metrics. + Arguments: + - merged_stats (:obj:`dict`): Merged MOE statistics from all distributed ranks. + - window_type (:obj:`str`, optional): Time window type. Default is 'immediate'. + Returns: + - divergence_data (:obj:`dict`): Comprehensive divergence analysis results including + matrices, statistics, and metadata. + Examples: + >>> stats = {'task_0': {'immediate': {'frequencies': [0.1, 0.9]}}} + >>> result = compute_distribution_divergences_optimized(stats) + >>> print(f"GPU accelerated: {result['gpu_accelerated']}") + """ + # 1. Data preprocessing + valid_tasks = [(tid, stats[window_type]['frequencies']) + for tid, stats in merged_stats.items() + if window_type in stats] + + if len(valid_tasks) < 2: + return {} + + task_ids, frequencies_list = zip(*valid_tasks) + + # 2. Efficient tensor conversion + try: + if isinstance(frequencies_list[0], torch.Tensor): + frequencies_tensor = torch.stack(frequencies_list) + else: + frequencies_tensor = torch.tensor( + np.array(frequencies_list), + dtype=torch.float32 + ) + + # Automatic GPU acceleration + if torch.cuda.is_available(): + frequencies_tensor = frequencies_tensor.cuda() + + except Exception as e: + print(f"GPU conversion failed, using CPU: {e}") + frequencies_tensor = torch.tensor(np.array(frequencies_list), dtype=torch.float32) + + device = frequencies_tensor.device + n_tasks, n_experts = frequencies_tensor.shape + + # 3. GPU batch computation (no loops) + with torch.no_grad(): + # Batch compute JS divergence and Wasserstein distance + js_matrix = jensen_shannon_divergence_batch_gpu(frequencies_tensor) + wasserstein_matrix = wasserstein_distance_batch_gpu(frequencies_tensor) + + # Efficiently extract upper triangular values (avoid duplicate computation) + triu_indices = torch.triu_indices(n_tasks, n_tasks, offset=1, device=device) + js_values = js_matrix[triu_indices[0], triu_indices[1]] + wasserstein_values = wasserstein_matrix[triu_indices[0], triu_indices[1]] + + # Statistical computation (vectorized) + js_stats = { + 'avg': torch.mean(js_values).item(), + 'max': torch.max(js_values).item(), + 'min': torch.min(js_values).item(), + 'std': torch.std(js_values).item() + } + + wasserstein_stats = { + 'avg': torch.mean(wasserstein_values).item(), + 'max': torch.max(wasserstein_values).item(), + 'min': torch.min(wasserstein_values).item(), + 'std': torch.std(wasserstein_values).item() + } + + return { + 'task_ids': task_ids, + 'n_tasks': n_tasks, + 'n_experts': n_experts, + 'device': str(device), + 'gpu_accelerated': 'cuda' in str(device), + + # Return CPU versions for logging + 'js_matrix': js_matrix.cpu().numpy(), + 'wasserstein_matrix': wasserstein_matrix.cpu().numpy(), + 'js_stats': js_stats, + 'wasserstein_stats': wasserstein_stats + } + + +def create_similarity_heatmap_no_diagonal(similarity_matrix, task_ids, metric_name, title_suffix=""): + """ + Overview: + Create task similarity heatmap with diagonal elements removed. + Provides clear visualization of inter-task relationships without self-similarity noise. + Arguments: + - similarity_matrix (:obj:`numpy.ndarray`): Similarity matrix (n_tasks, n_tasks). + - task_ids (:obj:`list`): Task identifier list for axis labels. + - metric_name (:obj:`str`): Metric name ('js_divergence', 'wasserstein_distance'). + - title_suffix (:obj:`str`, optional): Additional title suffix. Default is "". + Returns: + - img_array (:obj:`numpy.ndarray`): Image array in CHW format for TensorBoard. + Shapes: + - similarity_matrix: :math:`(N_{tasks}, N_{tasks})` + - img_array: :math:`(3, H, W)` where H and W are image dimensions. + Examples: + >>> matrix = np.random.rand(5, 5) + >>> task_ids = [0, 1, 2, 3, 4] + >>> heatmap = create_similarity_heatmap_no_diagonal(matrix, task_ids, 'js_divergence') + >>> print(f"Output shape: {heatmap.shape}") # (3, height, width) + """ + try: + # Copy matrix to avoid modifying original data + matrix = similarity_matrix.copy() + + # Set diagonal to NaN so matplotlib displays as blank + np.fill_diagonal(matrix, np.nan) + + figsize = (max(6, len(task_ids)), max(4, len(task_ids))) + fig, ax = plt.subplots(figsize=figsize) # Create new figure to avoid reuse issues + + # Choose color mapping based on metric type + if 'js' in metric_name.lower(): + cmap = 'Reds' + title_name = 'JS Divergence' + vmin, vmax = 0, 1.0 + else: # wasserstein + cmap = 'Blues' + title_name = 'Wasserstein Distance' + vmin, vmax = None, None # Adaptive + + # Use masked array to handle NaN values, diagonal displays as white + masked_matrix = np.ma.masked_invalid(matrix) + im = ax.imshow(masked_matrix, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto') + + # Add value annotations (skip diagonal) + if len(task_ids) <= 15: # Only add annotations for smaller task counts + for i in range(len(task_ids)): + for j in range(len(task_ids)): + if i != j: # Skip diagonal + value = matrix[i, j] + if not np.isnan(value): + threshold = (vmax or np.nanmax(matrix)) * 0.5 if vmax else np.nanmax(matrix) * 0.5 + color = 'white' if value > threshold else 'black' + ax.text(j, i, f'{value:.3f}', ha='center', va='center', + color=color, fontsize=8) + + # Set labels + ax.set_xticks(range(len(task_ids))) + ax.set_yticks(range(len(task_ids))) + ax.set_xticklabels([f'T{tid}' for tid in task_ids], fontsize=9) + ax.set_yticklabels([f'T{tid}' for tid in task_ids], fontsize=9) + ax.set_title(f'Task {title_name} Matrix {title_suffix} (No Diagonal)', fontsize=12) + ax.set_xlabel('Tasks', fontsize=10) + ax.set_ylabel('Tasks', fontsize=10) + + # Add colorbar + plt.colorbar(im, ax=ax, label=title_name, shrink=0.8) + + # Convert to image array - fix matplotlib version compatibility + fig.canvas.draw() + + try: + # New matplotlib uses buffer_rgba + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + h, w = fig.canvas.get_width_height() + img_array = buf.reshape(h, w, 4)[:, :, :3] # Remove alpha channel + else: + # Old matplotlib fallback + buf = fig.canvas.print_to_string() + img_array = np.frombuffer(buf, dtype=np.uint8) + h, w = fig.canvas.get_width_height() + img_array = img_array.reshape(h, w, 3) + except Exception as conv_e: + print(f"Image conversion method failed: {conv_e}, trying PIL approach") + # Final fallback: convert through PIL + buf = BytesIO() + fig.savefig(buf, format='png', dpi=100, bbox_inches='tight') + buf.seek(0) + img = Image.open(buf) + img_array = np.array(img)[:, :, :3] # Remove alpha channel + buf.close() + + img_array = img_array.transpose(2, 0, 1) # CHW format + plt.close(fig) # Close figure to avoid memory leak + + return img_array + + except Exception as e: + print(f"Warning: No-diagonal heatmap generation failed: {e}") + return np.zeros((3, 100, 100), dtype=np.uint8) + + +def log_pairwise_optimized(tb_logger, divergence_data, train_iter): + """ + Overview: + Optimized task pair logging with batch processing. + Efficiently logs pairwise divergence metrics for all task combinations. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - divergence_data (:obj:`dict`): Divergence computation results. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_pairwise_optimized(tb_logger, divergence_data, 1000) + """ + task_ids = divergence_data['task_ids'] + js_matrix = divergence_data['js_matrix'] + wasserstein_matrix = divergence_data['wasserstein_matrix'] + + # Batch construct task pair metric dictionary + pairwise_scalars = {} + + for i, task_i in enumerate(task_ids): + for j, task_j in enumerate(task_ids): + if i < j: # Only log upper triangle + # Construct metric names + js_key = f'TaskPairwise/Immediate_Task{task_i}_Task{task_j}_JS_Divergence' + wass_key = f'TaskPairwise/Immediate_Task{task_i}_Task{task_j}_Wasserstein_Distance' + + pairwise_scalars[js_key] = js_matrix[i, j] + pairwise_scalars[wass_key] = wasserstein_matrix[i, j] + + # Batch write to TensorBoard + for key, value in pairwise_scalars.items(): + tb_logger.add_scalar(key, float(value), global_step=train_iter) + + +def log_divergences_with_heatmaps(tb_logger, divergence_data, train_iter): + """ + Overview: + Log distribution divergence metrics and heatmaps (with diagonal removed). + Comprehensive logging of inter-task distribution analysis results. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - divergence_data (:obj:`dict`): Divergence computation results. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> log_divergences_with_heatmaps(tb_logger, divergence_data, 1000) + """ + if not divergence_data: + return + + js_stats = divergence_data['js_stats'] + wasserstein_stats = divergence_data['wasserstein_stats'] + task_ids = divergence_data['task_ids'] + n_tasks = divergence_data['n_tasks'] + + # Debug: Check matrix data + js_matrix = divergence_data['js_matrix'] + wasserstein_matrix = divergence_data['wasserstein_matrix'] + print(f"DEBUG: JS matrix shape={js_matrix.shape}, range=[{np.min(js_matrix):.6f}, {np.max(js_matrix):.6f}]") + print(f"DEBUG: Wasserstein matrix shape={wasserstein_matrix.shape}, range=[{np.min(wasserstein_matrix):.6f}, {np.max(wasserstein_matrix):.6f}]") + + # 1. Log scalar metrics + scalar_dict = { + 'MOE_Divergence/Immediate_AvgJS_Divergence': js_stats['avg'], + 'MOE_Divergence/Immediate_MaxJS_Divergence': js_stats['max'], + 'MOE_Divergence/Immediate_AvgWasserstein_Distance': wasserstein_stats['avg'], + 'MOE_Divergence/Immediate_MaxWasserstein_Distance': wasserstein_stats['max'], + } + + for key, value in scalar_dict.items(): + tb_logger.add_scalar(key, value, global_step=train_iter) + + # 1.1 Print core metrics to console + print("=" * 65) + print(f" Inter-Task Distribution Divergence Statistics (Iteration: {train_iter})") + print("=" * 65) + print(f"Participating tasks: {n_tasks} | Task IDs: {list(task_ids)}") + print(f"Computing device: {divergence_data.get('device', 'Unknown')} | GPU acceleration: {'Enabled' if divergence_data.get('gpu_accelerated', False) else 'Disabled'}") + print("-" * 65) + print("JS Divergence (Jensen-Shannon Divergence):") + print(f" Average: {js_stats['avg']:.6f} | Maximum: {js_stats['max']:.6f}") + print(f" Minimum: {js_stats['min']:.6f} | Std Dev: {js_stats['std']:.6f}") + print("-" * 65) + print("Wasserstein Distance:") + print(f" Average: {wasserstein_stats['avg']:.6f} | Maximum: {wasserstein_stats['max']:.6f}") + print(f" Minimum: {wasserstein_stats['min']:.6f} | Std Dev: {wasserstein_stats['std']:.6f}") + print("=" * 65) + + # 2. Log similarity matrix heatmaps with diagonal removed + task_ids = divergence_data['task_ids'] + n_tasks = divergence_data['n_tasks'] + + if n_tasks <= 25: # Limit matrix size to avoid oversized heatmaps + try: + # JS divergence matrix heatmap (no diagonal) + js_heatmap = create_similarity_heatmap_no_diagonal( + divergence_data['js_matrix'], + task_ids, + 'js_divergence', + f'(Immediate-{n_tasks} tasks)' + ) + tb_logger.add_image( + 'TaskSimilarity/Immediate_JS_Matrix_NoDiagonal', + js_heatmap, + global_step=train_iter, + dataformats='CHW' + ) + + # Wasserstein distance matrix heatmap (no diagonal) + wass_heatmap = create_similarity_heatmap_no_diagonal( + divergence_data['wasserstein_matrix'], + task_ids, + 'wasserstein_distance', + f'(Immediate-{n_tasks} tasks)' + ) + tb_logger.add_image( + 'TaskSimilarity/Immediate_Wasserstein_Matrix_NoDiagonal', + wass_heatmap, + global_step=train_iter, + dataformats='CHW' + ) + + except Exception as e: + print(f"Warning: Similarity matrix heatmap generation failed: {e}") + + # 3. Log task pair metrics (optional) + if n_tasks <= 20: + log_pairwise_optimized(tb_logger, divergence_data, train_iter) + + +def collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, train_iter): + """ + Overview: + Complete distribution divergence computation and logging (including no-diagonal heatmaps). + End-to-end pipeline for analyzing and visualizing inter-task distribution differences. + Arguments: + - tb_logger (:obj:`SummaryWriter`): TensorBoard logger for metric recording. + - merged_stats (:obj:`dict`): Merged MOE statistics from distributed training. + - train_iter (:obj:`int`): Current training iteration for logging. + Examples: + >>> collect_and_log_divergences_with_heatmaps(tb_logger, merged_stats, 1000) + """ + try: + # GPU-optimized computation + divergence_data = compute_distribution_divergences_optimized(merged_stats, 'immediate') + + if not divergence_data: + print(f"Skipping distribution divergence computation - insufficient tasks (need >=2 tasks)") + return + + # Log metrics and heatmaps + log_divergences_with_heatmaps(tb_logger, divergence_data, train_iter) + + # Summary print + print(f">> Distribution divergence statistics completed and logged to TensorBoard") + if divergence_data.get('n_tasks', 0) <= 25: + print(f">> Similarity matrix heatmaps generated (diagonal removed)") + if divergence_data.get('n_tasks', 0) <= 20: + print(f">> Task pair detailed metrics logged") + print() # Blank line separator + + except Exception as e: + print(f"ERROR: Distribution divergence computation failed - {e}") + import traceback + traceback.print_exc() diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 97c3528c0..60f389dc2 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -46,7 +46,7 @@ def default_config(cls: type) -> EasyDict: cfg.cfg_type = cls.__name__ + 'Dict' return cfg - def __init__(self, cfg: EasyDict = None) -> None: + def __init__(self, cfg: EasyDict = None,eval=False) -> None: """ Overview: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key @@ -56,9 +56,13 @@ def __init__(self, cfg: EasyDict = None) -> None: default_config = self.default_config() default_config.update(cfg) self._cfg = default_config + if eval: + self._cfg.num_simulations=self._cfg.eval_num_simulations + self.inverse_scalar_transform_handle = InverseScalarTransform( self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution ) + @classmethod def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": diff --git a/lzero/model/common.py b/lzero/model/common.py index 5ac305e52..f36f9fb06 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -248,6 +248,68 @@ def remove_hooks(self): self.forward_handler.remove() self.backward_handler.remove() +# # modified by tangjia +# class ModelGradientHook: + + +# def __init__(self): +# """ +# Overview: +# Class to capture gradients at model output. +# """ +# self.output_grads = [] + +# def setup_hook(self, model): +# # Hook to capture gradients at model output +# self.backward_handler = model.register_full_backward_hook(self.backward_hook) + +# def backward_hook(self, module, grad_input, grad_output): +# with torch.no_grad(): +# # 保存输出梯度 +# if grad_output[0] is not None: +# self.output_grads.append(grad_output[0].clone()) + +# def analyze(self): +# if not self.output_grads: +# return None + +# # Calculate norms of output gradients +# grad_norms = [torch.norm(g, p=2, dim=1).mean() for g in self.output_grads] +# avg_grad_norm = torch.mean(torch.stack(grad_norms)) +# max_grad_norm = torch.max(torch.stack(grad_norms)) +# min_grad_norm = torch.min(torch.stack(grad_norms)) + +# # Clear stored data and delete tensors to free memory +# self.clear_data() + +# # Optionally clear CUDA cache +# if torch.cuda.is_available(): +# torch.cuda.empty_cache() + +# return avg_grad_norm, max_grad_norm, min_grad_norm + +# def clear_data(self): +# del self.output_grads[:] + +# def remove_hooks(self): +# self.backward_handler.remove() + +# 使用示例 +# monitor = ModelGradientMonitor() +# monitor.setup_hook(model) +# +# # 训练过程中... +# loss.backward() +# +# # 获取梯度信息 +# grad_norm = monitor.get_gradient_norm() +# grad_stats = monitor.get_gradient_stats() +# +# # 清理数据 +# monitor.clear_data() +# +# # 训练结束后移除hook +# monitor.remove_hook() class DownSample(nn.Module): diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 0e050502d..411fa4020 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -6,7 +6,7 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook #,ModelGradientHook from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT @@ -124,7 +124,7 @@ def __init__( )) elif world_model_cfg.encoder_type == "vit": for task_id in range(1): # TODO: one share encoder - if world_model_cfg.task_num <=8: + if world_model_cfg.task_num ==1: # # vit base # self.representation_network.append(ViT( # image_size =observation_shape[1], @@ -144,16 +144,45 @@ def __init__( patch_size = 8, num_classes = obs_act_embed_dim, dim = 768, - depth = 6, - heads = 6, - mlp_dim = 2048, + depth = 12, + heads = 12, + mlp_dim = 3072, dropout = 0.1, emb_dropout = 0.1, final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + + )) + elif world_model_cfg.task_num <=8: + # # vit base + # self.representation_network.append(ViT( + # image_size =observation_shape[1], + # patch_size = 8, + # num_classes = obs_act_embed_dim, + # dim = 768, + # depth = 12, + # heads = 12, + # mlp_dim = 3072, + # dropout = 0.1, + # emb_dropout = 0.1, + # final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + # )) + # vit small + self.representation_network.append(ViT( + image_size =observation_shape[1], + patch_size = 8, + num_classes = obs_act_embed_dim, + dim = 768, + depth = 12, + heads = 12, + mlp_dim = 3072, + dropout = 0.1, + emb_dropout = 0.1, + final_norm_option_in_encoder=world_model_cfg.final_norm_option_in_encoder, + # ==================== 新增/修改部分 开始 ==================== config=world_model_cfg # <--- 将包含LoRA参数的配置传递给ViT # ==================== 新增/修改部分 结束 ==================== - + )) elif world_model_cfg.task_num > 8: # vit base @@ -196,6 +225,11 @@ def __init__( self.encoder_hook = FeatureAndGradientHook() self.encoder_hook.setup_hooks(self.representation_network) + # if True: # Fixme: for debug + # # 增加对encoder的hook,监控传播到encoder 上的梯度 + # self.encoder_output_hook = ModelGradientHook() + # self.encoder_output_hook.setup_hook(self.representation_network) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False, obs_type=world_model_cfg.obs_type) self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') diff --git a/lzero/model/unizero_world_models/__init__.py b/lzero/model/unizero_world_models/__init__.py index c1d02cb8c..e69de29bb 100644 --- a/lzero/model/unizero_world_models/__init__.py +++ b/lzero/model/unizero_world_models/__init__.py @@ -1 +0,0 @@ -from .transformer import Transformer, TransformerConfig diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index 8ee8115ee..11ab3a5a7 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from simple_parsing.helpers import Serializable from torch import nn - +import torch.distributed as dist from lzero.model.unizero_world_models.transformer import _maybe_wrap_linear # _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward") @@ -59,7 +59,8 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert self.num_experts_per_tok = num_experts_per_tok self.gate = gate self.experts = nn.ModuleList(experts) - + self.config=config + # 如果配置中指定了共享专家数量,则构建共享专家分支 if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: self.shared_expert = nn.Sequential( @@ -69,34 +70,54 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_expert ) else: self.shared_expert = None + + # GPU memory expert selection statistics collector - multi-granularity sliding windows + self.device = next(iter(experts)).w1.weight.device if experts else torch.device('cuda') + + # Sliding window configuration + self.window_sizes = { + 'immediate': 100, # Immediate statistics (last 100 steps) + 'short': 1000, # Short-term statistics (last 1000 steps) + 'medium': 10000, # Medium-term statistics (last 10000 steps) + 'long': 100000 # Long-term statistics (last 100000 steps) + } + + # GPU statistics buffer: task_id -> {window_type -> [expert selection history]} + self.expert_stats_gpu = {} + self.step_count = 0 - def forward(self, x: torch.Tensor) -> torch.Tensor: + + def forward(self, x: torch.Tensor, task_id: int = None) -> torch.Tensor: # 保存原始形状后将 x reshape 为二维张量: [batch_size * seq_len, dim] original_shape = x.size() x = x.view(-1, self.dim) - - # 计算门控 logits,shape 为 [N, num_experts],N 为 token 数量 - gate_logits = self.gate(x) - # 选取每个 token 得分最高的 k 个专家 - weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) - # 对选中的 logits 做 softmax,获得归一化权重 - weights = F.softmax(weights, dim=1).to(x.dtype) - - # 初始化存放专家计算输出的张量 - expert_output = torch.zeros_like(x) - - # 遍历所有专家,对被该专家选择的 token 分支进行计算 - for expert_id in range(self.num_experts): - # 通过 where 找到 indices 中等于当前 expert_id 的 token 索引 - batch_idx, expert_tok_idx = torch.where(indices == expert_id) - if batch_idx.numel() == 0: - continue - token_subset = x[batch_idx] # 选中的 token,形状 [num_tokens, dim] - # 调用当前专家模块计算输出 - output_expert = self.experts[expert_id](token_subset) - # 获取对应 token 的权重,注意 weights 的形状为 [N, num_experts_per_tok] - token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) - expert_output[batch_idx] += output_expert * token_weights + expert_output=x + if self.num_experts!=0: + # 计算门控 logits,shape 为 [N, num_experts],N 为 token 数量 + gate_logits = self.gate(x) + # 选取每个 token 得分最高的 k 个专家 + weights, indices = torch.topk(gate_logits, self.num_experts_per_tok, dim=1) + # 对选中的 logits 做 softmax,获得归一化权重 + weights = F.softmax(weights, dim=1).to(x.dtype) + + if self.training and task_id is not None: + self._collect_expert_selection_stats(task_id, indices) + + # 初始化存放专家计算输出的张量 + expert_output = torch.zeros_like(x) + + # 遍历所有专家,对被该专家选择的 token 分支进行计算 + for expert_id in range(self.num_experts): + # 通过 where 找到 indices 中等于当前 expert_id 的 token 索引 + batch_idx, expert_tok_idx = torch.where(indices == expert_id) + if batch_idx.numel() == 0: + continue + token_subset = x[batch_idx] # 选中的 token,形状 [num_tokens, dim] + # 调用当前专家模块计算输出 + output_expert = self.experts[expert_id](token_subset) + # 获取对应 token 的权重,注意 weights 的形状为 [N, num_experts_per_tok] + token_weights = weights[batch_idx, expert_tok_idx].unsqueeze(-1) + expert_output[batch_idx] += output_expert * token_weights # 如果使用了共享专家分支,则加上其输出 if self.shared_expert is not None: @@ -107,14 +128,153 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # 恢复原始形状后返回结果 return output.view(original_shape) + + def _collect_expert_selection_stats(self, task_id: int, indices: torch.Tensor): + """ + Overview: + Collect expert selection statistics in GPU memory using multi-granularity sliding windows. + Maintains separate rolling buffers for different time window sizes to track expert usage patterns. + Arguments: + - task_id (:obj:`int`): The identifier of the current task. + - indices (:obj:`torch.Tensor`): Expert indices selected by the router for the current batch. + Shapes: + - indices: :math:`(N, k)` where N is batch size and k is number of experts per token. + Examples: + >>> # Collect stats for task 0 with expert indices + >>> indices = torch.tensor([[0, 2], [1, 3]]) # batch_size=2, k=2 + >>> moe_layer._collect_expert_selection_stats(task_id=0, indices=indices) + """ + self.step_count += 1 + + if task_id not in self.expert_stats_gpu: + self.expert_stats_gpu[task_id] = {} + for window_type in self.window_sizes.keys(): + self.expert_stats_gpu[task_id][window_type] = torch.zeros( + self.window_sizes[window_type], + self.num_experts, + dtype=torch.float32, + device=self.device + ) + + # Calculate expert selection frequency for current batch + indices_flat = indices.flatten() # [N*k] + expert_counts = torch.zeros(self.num_experts, device=self.device, dtype=torch.float32) + for expert_id in range(self.num_experts): + expert_counts[expert_id] = (indices_flat == expert_id).sum().float() + + # Update sliding windows for all granularities + for window_type, window_size in self.window_sizes.items(): + buffer = self.expert_stats_gpu[task_id][window_type] + # Sliding window: new data goes to the end, old data moves forward + buffer[:-1] = buffer[1:].clone() + buffer[-1] = expert_counts + + def get_expert_selection_stats(self, task_id: int = None): + """ + Overview: + Get multi-granularity expert selection frequency statistics. + Simplified version that directly returns current data without complex aggregation. + Arguments: + - task_id (:obj:`int`, optional): The identifier of the specific task. If None, returns stats for all tasks. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics. + Structure: {task_id: {window_type: {frequencies, total_counts, total_selections, data_points}}} + Examples: + >>> # Get stats for all tasks + >>> all_stats = moe_layer.get_expert_selection_stats() + >>> # Get stats for specific task + >>> task_stats = moe_layer.get_expert_selection_stats(task_id=0) + """ + if task_id is None: + # Return statistics for all tasks + all_stats = {} + for tid in self.expert_stats_gpu.keys(): + all_stats[tid] = self._compute_task_stats(tid) + return all_stats + else: + # Return statistics for specified task + return self._compute_task_stats(task_id) + + def _compute_task_stats(self, task_id: int): + """ + Overview: + Compute multi-granularity statistics for a specified task. + Processes expert selection data across different time window granularities. + Arguments: + - task_id (:obj:`int`): The identifier of the task to compute statistics for. + Returns: + - stats (:obj:`dict`): Dictionary containing computed statistics for each window type. + Structure: {window_type: {frequencies, total_counts, total_selections, data_points}} + Shapes: + - frequencies: :math:`(num\_experts,)` normalized selection frequencies per expert. + - total_counts: :math:`(num\_experts,)` absolute selection counts per expert. + Examples: + >>> # Compute stats for task 0 + >>> task_stats = moe_layer._compute_task_stats(task_id=0) + >>> immediate_freq = task_stats['immediate']['frequencies'] + """ + if task_id not in self.expert_stats_gpu: + return {} + + stats = {} + for window_type, buffer in self.expert_stats_gpu[task_id].items(): + # Simplified version: directly average all existing data, ignoring whether window is full + # buffer shape: [window_size, num_experts] + total_counts = buffer.sum(dim=0) # [num_experts] + total_selections = total_counts.sum() + + if total_selections > 0: + frequencies = total_counts / total_selections + else: + frequencies = torch.zeros(self.num_experts, device=self.device) + + stats[window_type] = { + 'frequencies': frequencies, # Keep tensor format + 'total_counts': total_counts, # Keep tensor format + 'total_selections': total_selections.item(), + 'data_points': min(self.step_count, self.window_sizes[window_type]) + } + + return stats + + def reset_expert_selection_stats(self): + """ + Overview: + Reset expert selection statistics by clearing all accumulated data. + Clears GPU memory buffers and resets step counter to initial state. + Examples: + >>> # Reset all expert selection statistics + >>> moe_layer.reset_expert_selection_stats() + """ + self.expert_stats_gpu.clear() + self.step_count = 0 class MoELayerOptimized(nn.Module): - r""" - 与原 MoELayer 接口保持一致,但 forward 端到端为 O(N_token + ΣE_i), - 其中 ΣE_i 为各 expert 实际处理的 token 数量。 + """ + Overview: + Optimized MoE layer that maintains interface consistency with original MoELayer. + Provides end-to-end forward pass with O(N_token + ΣE_i) complexity, + where ΣE_i is the total number of tokens actually processed by all experts. + Interfaces: + - __init__: Initialize the optimized MoE layer with experts and gating mechanism. + - forward: Perform optimized forward pass through the MoE layer. """ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok: int = 1): + """ + Overview: + Initialize the optimized MoE layer with configuration, experts, and gating mechanism. + Sets up expert modules, routing gate, and optional shared experts. + Arguments: + - config (:obj:`object`): Configuration object containing model parameters like embed_dim and n_shared_experts. + - experts (:obj:`List[nn.Module]`): List of expert neural network modules. + - gate (:obj:`nn.Module`): Gating network for routing tokens to experts. + - num_experts_per_tok (:obj:`int`, optional): Number of experts to select per token. Default is 1. + Examples: + >>> experts = [nn.Linear(512, 512) for _ in range(8)] + >>> gate = nn.Linear(512, 8) + >>> moe_layer = MoELayerOptimized(config, experts, gate, num_experts_per_tok=2) + """ super().__init__() self.dim = config.embed_dim self.num_experts = len(experts) @@ -130,11 +290,27 @@ def __init__(self, config, experts: List[nn.Module], gate: nn.Module, nn.Linear(config.n_shared_experts * (4 * self.dim), self.dim), ) - def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Perform optimized forward pass through the MoE layer. + Routes tokens to appropriate experts and combines their outputs efficiently. + Arguments: + - x (:obj:`torch.Tensor`): Input tensor containing token embeddings. + Returns: + - output (:obj:`torch.Tensor`): Processed tensor after expert routing and combination. + Shapes: + - x: :math:`(B, T, D)` where B is batch size, T is sequence length, D is embedding dimension. + - output: :math:`(B, T, D)` same shape as input. + Examples: + >>> x = torch.randn(2, 10, 512) # batch_size=2, seq_len=10, embed_dim=512 + >>> output = moe_layer.forward(x) + >>> print(output.shape) # torch.Size([2, 10, 512]) + """ # [B, T, D] B, T, D = x.shape x_flat = x.reshape(-1, D) # [N, D]; N = B*T - # -------- 1. 路由 ---------- + # -------- 1. Routing ---------- gate_logits = self.gate(x_flat) # [N, E] weights, topk_idx = torch.topk( gate_logits, self.num_experts_per_tok, dim=1 @@ -142,27 +318,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] weights = F.softmax(weights, dim=1).to(x.dtype) # [N, k] - # ---- 2. 扁平化 token-expert 对 ---- + # ---- 2. Flatten token-expert pairs ---- N, k = weights.shape flat_token_idx = torch.arange(N, device=x.device).repeat_interleave(k) # [N*k] flat_expert_idx = topk_idx.reshape(-1) # [N*k] flat_weight = weights.reshape(-1, 1) # [N*k, 1] flat_input = x_flat[flat_token_idx] # [N*k, D] - # ---- 3. 按 expert 分块 ---- + # ---- 3. Group by expert ---- sort_order = torch.argsort(flat_expert_idx) # [N*k] flat_expert_idx = flat_expert_idx[sort_order] flat_token_idx = flat_token_idx[sort_order] flat_weight = flat_weight[sort_order] flat_input = flat_input[sort_order] - # 每个 expert 的样本计数 + # Sample count for each expert counts = torch.bincount(flat_expert_idx, minlength=self.num_experts) # [E] - # 准备输出缓冲 + # Prepare output buffer out_buffer = torch.zeros_like(flat_input) # [N*k, D] - # ---- 4. 逐 expert 一次前向 ---- + # ---- 4. Process each expert sequentially ---- ptr = 0 for eid, num in enumerate(counts.tolist()): if num == 0: @@ -171,12 +347,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # [B, T, D] out_buffer[seg] = self.experts[eid](flat_input[seg]) ptr += num - # ---- 5. 加权并散射回 token ---- - out_buffer.mul_(flat_weight) # inplace 权重 + # ---- 5. Weight and scatter back to tokens ---- + out_buffer.mul_(flat_weight) # inplace weighting token_output = torch.zeros_like(x_flat) # [N, D] token_output.index_add_(0, flat_token_idx, out_buffer) - # ---- 6. 共享专家(若有) ---- + # ---- 6. Shared experts (if any) ---- if self.use_shared: token_output.add_(self.shared_expert(x_flat)) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 3edf4f1c9..4cdfc3948 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -10,18 +10,19 @@ import math from dataclasses import dataclass from typing import Optional - +from easydict import EasyDict import torch import torch.nn as nn from ding.torch_utils.network import GRUGatingUnit from einops import rearrange from torch.nn import functional as F - +import torch.distributed as dist from .kv_caching import KeysValues from line_profiler import line_profiler from lzero.model.common import SimNorm import logging +from typing import Dict, List, Any # class LearnableScale(nn.Module): # """ @@ -340,6 +341,7 @@ def max_tokens(self): return self.tokens_per_block * self.max_blocks + class Transformer(nn.Module): """ Transformer model class. @@ -359,12 +361,21 @@ def __init__(self, config: TransformerConfig, task_embed=None) -> None: self.config = config self.drop = nn.Dropout(config.embed_pdrop) self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)]) + # self.blocks[-1].is_last_block=True self.ln_f = nn.LayerNorm(config.embed_dim) - + + self.num_blocks=len(self.blocks) + self.num_experts=config.num_experts_of_moe_in_transformer + self.task_embed = task_embed self.task_embed_option = self.config.task_embed_option # Strategy for task embeddings self.register_token_shared = True + self.shared_expert=0 + if hasattr(config, "n_shared_experts") and config.n_shared_experts > 0: + self.shared_expert = config.n_shared_experts + + # TODO: 共享模式下,所有任务使用同一参数 if self.task_embed_option == "register_task_embed": @@ -441,7 +452,6 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: device = self.ln_f.weight.device # Assumption: All submodules are on the same device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) - #@profile def forward( self, @@ -473,9 +483,11 @@ def forward( # 逐层调用 for i, block in enumerate(self.blocks): + # 标识是否为最后一层 + is_last_block = (i == len(self.blocks) - 1) x = block(x, - None if past_keys_values is None else past_keys_values[i], - valid_context_lengths) + None if past_keys_values is None else past_keys_values[i], + valid_context_lengths, is_last_block=is_last_block, task_id=task_id) # 最后层 LN x = self.ln_f(x) @@ -492,6 +504,258 @@ def forward( x = x[:, :-self.register_token_num, :] return x + + def get_expert_selection_stats(self, task_id: int = None): + """ + Overview: + Retrieve MoE (Mixture of Experts) expert selection statistics from the last transformer block. + These statistics provide insights into expert utilization patterns and load balancing. + Arguments: + - task_id (:obj:`int`, optional): Task identifier for task-specific statistics. Default is None. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics such as expert usage counts, + load balancing metrics, and routing probabilities. + Examples: + >>> transformer = Transformer(config) + >>> stats = transformer.get_expert_selection_stats(task_id=0) + >>> print(f"Expert usage: {stats.get('expert_usage', {})}") + """ + if len(self.blocks) == 0: + return {} + + last_block = self.blocks[-1] + + # Check if the last block has MoE layer + if not hasattr(last_block, 'feed_forward') or not hasattr(last_block.feed_forward, 'get_expert_selection_stats'): + return {} + + return last_block.feed_forward.get_expert_selection_stats(task_id) + + def reset_expert_selection_stats(self): + """ + Overview: + Reset MoE (Mixture of Experts) expert selection statistics for the last transformer block. + This method clears accumulated statistics used for load balancing and expert utilization analysis. + Arguments: + - None: This method takes no parameters. + Returns: + - None: This method performs reset operations without return values. + Examples: + >>> transformer = Transformer(config) + >>> transformer.reset_expert_selection_stats() + """ + if len(self.blocks) == 0: + return + + last_block = self.blocks[-1] + + # Check if the last block has MoE layer + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'reset_expert_selection_stats'): + last_block.feed_forward.reset_expert_selection_stats() + + # modified by tangjia : + # def has_shared_experts(self) -> bool: + # """ + # 检查Transformer是否使用了共享专家 + + # Returns: + # bool: 如果任何一个block使用了共享专家则返回True,否则返回False + # """ + # for block in self.blocks: + # if hasattr(block, 'feed_forward') and hasattr(block.feed_forward, 'shared_expert'): + # if block.feed_forward.shared_expert is not None: + # return True + # return False + + + + def get_shared_expert_gradients_by_block_id(self, block_id: int) -> Dict[str, torch.Tensor]: + """ + Overview: + Retrieve parameter gradients of shared experts from a specified transformer block. + Extracts gradients from the shared expert module within the feed-forward layer. + Arguments: + - block_id (:obj:`int`): Block identifier (0 to num_layers-1). + Returns: + - gradients (:obj:`Dict[str, torch.Tensor]`): Dictionary containing parameter names and corresponding gradients. + Raises: + - ValueError: When block_id is out of range or block doesn't have shared experts. + Examples: + >>> transformer = TransformerModel(config) + >>> gradients = transformer.get_shared_expert_gradients_by_block_id(block_id=2) + >>> print(f"Shared expert gradients: {list(gradients.keys())}") + """ + if block_id < 0 or block_id >= len(self.blocks): + raise ValueError(f"Block ID {block_id} out of range. Available blocks: 0-{len(self.blocks)-1}") + + block = self.blocks[block_id] + + # Check if block has feed_forward attribute and supports MoE + if not hasattr(block, 'feed_forward'): + raise ValueError(f"Block {block_id} doesn't have feed_forward layer") + + # Check if block has shared experts + if not hasattr(block.feed_forward, 'shared_expert') or block.feed_forward.shared_expert is None: + raise ValueError(f"Block {block_id} doesn't have shared expert") + + # Collect gradients from shared experts + gradients = {} + shared_expert = block.feed_forward.shared_expert + + for name, param in shared_expert.named_parameters(): + if param.grad is not None: + gradients[f"shared_expert.{name}"] = param.grad.clone() + else: + gradients[f"shared_expert.{name}"] = None + + return gradients + + + + def get_expert_gradients_for_last_block(self) -> Dict[str, torch.Tensor]: + """ + Overview: + Retrieve parameter gradients of all experts from the last transformer block. + Collects gradients from all independent expert modules in the final layer. + Returns: + - gradients (:obj:`List[torch.Tensor]`): List containing flattened gradient tensors for each expert. + Examples: + >>> transformer = TransformerModel(config) + >>> expert_gradients = transformer.get_expert_gradients_for_last_block() + >>> print(f"Number of experts: {len(expert_gradients)}") + """ + if len(self.blocks) == 0: + return [] + + # Get the last block + last_block = self.blocks[-1] + gradients = [] + + # Check if block has feed_forward attribute + if not hasattr(last_block, 'feed_forward'): + return gradients + + feed_forward = last_block.feed_forward + + # Check if it's a MoE structure + if hasattr(feed_forward, 'experts') and feed_forward.experts is not None: + # Collect gradients from all independent experts + for expert_idx, expert in enumerate(feed_forward.experts): + expert_gradients = [] + for name, param in expert.named_parameters(): # + if param.grad is not None: + expert_gradients.append(param.grad.clone().view(-1)) + else: + expert_gradients.append(torch.zeros_like(param).view(-1)) + expert_gradients=torch.cat(expert_gradients, dim=0) + gradients.append(expert_gradients) + + return gradients + + + + # added by tangjia : + def get_block_before_moe_gradients(self) -> Dict[int, torch.Tensor]: + """ + Overview: + Retrieve gradients of the block layer before MoE (Mixture of Experts) processing from the last block. + This method provides access to intermediate gradients for gradient analysis and debugging. + Arguments: + - None: This method takes no parameters. + Returns: + - gradients (:obj:`Dict[int, torch.Tensor]`): Dictionary containing block gradients before MoE layer, + with block indices as keys and gradient tensors as values. + Examples: + >>> transformer = Transformer(config) + >>> gradients = transformer.get_block_before_moe_gradients() + >>> print(f"Gradient shape: {gradients.shape if gradients is not None else 'None'}") + """ + # Return the gradient from the last block + return self.blocks[-1].block_before_moe_grad + + + def get_last_shared_expert_gradients(self) -> List[Dict[str, torch.Tensor]]: + """ + Overview: + Retrieve parameter gradients from the shared expert in the last transformer block. + This method provides access to shared expert gradients for gradient analysis and optimization monitoring. + Arguments: + - None: This method takes no parameters. + Returns: + - gradients (:obj:`torch.Tensor`): Concatenated tensor containing all shared expert parameter gradients + flattened into a single dimension for analysis. + Shapes: + - gradients: :math:`(D,)` where D is the total number of parameters in the shared expert. + Examples: + >>> transformer = Transformer(config) + >>> shared_grads = transformer.get_last_shared_expert_gradients() + >>> print(f"Shared expert gradient shape: {shared_grads.shape}") + """ + if len(self.blocks) == 0: + return [] + + # Get the last block + last_block = self.blocks[-1] + + shared_expert_gradients = [] + shared_expert = last_block.feed_forward.shared_expert + + for name, param in shared_expert.named_parameters(): + if param.grad is not None: + shared_expert_gradients.append(param.grad.clone().view(-1)) + else: + shared_expert_gradients.append(torch.zeros_like(param).view(-1)) + + return torch.concat(shared_expert_gradients, dim=0) + + def get_last_block_expert_selection_stats(self): + """ + Overview: + Retrieve MoE (Mixture of Experts) expert selection statistics specifically from the last transformer block. + This method provides focused analysis of expert utilization in the final layer. + Arguments: + - None: This method takes no parameters. + Returns: + - stats (:obj:`dict`): Dictionary containing expert selection statistics from the last block, + including expert usage patterns, routing decisions, and load balancing metrics. + Examples: + >>> transformer = Transformer(config) + >>> stats = transformer.get_last_block_expert_selection_stats() + >>> print(f"Last block expert stats: {stats}") + """ + if len(self.blocks) == 0: + return {} + + last_block = self.blocks[-1] + + # Check if the last layer has MoE + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'get_expert_selection_stats'): + return last_block.feed_forward.get_expert_selection_stats() + else: + return {} + + def reset_last_block_expert_selection_stats(self): + """ + Overview: + Reset MoE (Mixture of Experts) expert selection statistics specifically for the last transformer block. + This method clears accumulated statistics in the final layer for fresh monitoring. + Arguments: + - None: This method takes no parameters. + Returns: + - None: This method performs reset operations without return values. + Examples: + >>> transformer = Transformer(config) + >>> transformer.reset_last_block_expert_selection_stats() + """ + if len(self.blocks) == 0: + return + + last_block = self.blocks[-1] + + # Check if the last layer has MoE + if hasattr(last_block, 'feed_forward') and hasattr(last_block.feed_forward, 'reset_expert_selection_stats'): + last_block.feed_forward.reset_expert_selection_stats() + @@ -526,8 +790,8 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - - + self.config=config + if config.moe_in_transformer: from .moe import MoELayer, MultiplicationFeedForward # 创Create multiple independent MLP instances @@ -588,13 +852,14 @@ def __init__(self, config: TransformerConfig) -> None: _maybe_wrap_linear(nn.Linear(config.embed_dim, 4 * config.embed_dim), config, "feed_forward"), nn.GELU(approximate='tanh'), _maybe_wrap_linear(nn.Linear(4 * config.embed_dim, config.embed_dim), config, "feed_forward"), - nn.Dropout(config.resid_pdrop), + # nn.Dropout(config.resid_pdrop), ) + self.block_before_moe_grad = None def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, - valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: + valid_context_lengths: Optional[torch.Tensor] = None, is_last_block=False, task_id: int = 0) -> torch.Tensor: """ - Forward pass of the Transformer block. + Forward pass of the Transformer block.self.is_last_block Arguments: - x (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim). @@ -604,15 +869,31 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None Returns: - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim). """ + x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: x = self.gate1(x, x_attn) x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.feed_forward(self.ln2(x)) + block_before_moe=self.ln2(x) + if self.training and is_last_block: + # Clear previous gradients + self.block_before_moe_grad = None + # Use safer hook registration to avoid closure issues + def grad_hook(grad): + self.block_before_moe_grad = grad.clone() # Clone gradient to avoid reference issues + return None + block_before_moe.register_hook(grad_hook) + + # Pass task_id for expert selection statistics collection in the last layer with MoE + if is_last_block and self.config.multiplication_moe_in_transformer and hasattr(self.feed_forward, 'forward'): + x = x + self.feed_forward(block_before_moe, task_id=task_id) + else: + x = x + self.feed_forward(block_before_moe) return x + class SelfAttention(nn.Module): @@ -804,4 +1085,4 @@ def get_attention_map(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = No att = att.masked_fill(mask == 0, float('-inf')) att = F.softmax(att, dim=-1) - return att \ No newline at end of file + return att diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index ecb583504..fd12b6b07 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -183,7 +183,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # TODO: check the effect of SimNorm # self.act_embedding_table = nn.Sequential( # nn.Linear(config.action_space_size, config.embed_dim, device=self.device, bias=False), - # SimNorm(simnorm_dim=self.group_size)) + # SimNorm(simnorm_dim=self.group_size) + # ) # print(f'config.action_space_size_list:{config.action_space_size_list}') self.act_embedding_table = nn.ModuleList([ nn.Sequential( @@ -319,6 +320,9 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.reanalyze_phase = False self._rank = get_rank() + # tangjia + self.obs_embeddings_grad = None # 保留参数 + def _scale_grad(self, grad: torch.Tensor) -> torch.Tensor: # ① 1/k 缩放;若想更保守可用 1/√k # return grad / self.task_num @@ -1024,7 +1028,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va enumerate(past_keys_values)] return torch.cat(x, dim=0) else: - return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths,task_id=task_id) #@profile @torch.no_grad() @@ -1796,6 +1800,9 @@ def gather_and_plot(self, local_embeddings, local_task_ids, local_observations): def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id = 0, **kwargs: Any) -> LossWithIntermediateLosses: # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + obs_embeddings.register_hook(lambda grad: setattr(self, 'obs_embeddings_grad', grad)) #note: register hook to save gradients of obs_embeddings + if self.analysis_tsne: # =========== tsne analysis =========== diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 13ba63eb2..2216db73c 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -14,10 +14,10 @@ from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs from lzero.policy.unizero import UniZeroPolicy -from .utils import configure_optimizers_nanogpt +from .utils import configure_optimizers_nanogpt, compute_gradient_conflict_distributed, log_gradient_conflict_heatmaps_distributed_fast import sys -sys.path.append('/cpfs04/user/puyuan/code/LibMTL') +# sys.path.append('/cpfs04/user/puyuan/code/LibMTL') # sys.path.append('/fs-computility/niuyazhe/puyuan/code/LibMTL') from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect @@ -25,6 +25,7 @@ # from LibMTL.weighting.moco_fast import FastMoCo, MoCoCfg from LibMTL.weighting.moco_fast_mem_eff import FastMoCoMemEff as FastMoCo from LibMTL.weighting.moco_fast_mem_eff import MoCoCfg +import torch.distributed as dist @@ -130,7 +131,7 @@ def zero_grad(self, set_to_none=False): self.act_embedding_table.zero_grad(set_to_none=set_to_none) - +from line_profiler import LineProfiler @POLICY_REGISTRY.register('unizero_multitask') class UniZeroMTPolicy(UniZeroPolicy): """ @@ -140,7 +141,19 @@ class UniZeroMTPolicy(UniZeroPolicy): by addressing the limitations found in MuZero-style algorithms, particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. """ - + def __init__(self, cfg, model = None, enable_field = None): + super().__init__(cfg, model, enable_field) + self.step=0 + self.save_freq=200 + self.use_moe=False + + self.cal_profile=False + if self.cal_profile: + self.profiler=LineProfiler() + self.profiler.add_function(self._forward_learn) + self.profiler.enable_by_count() + + # The default_config for UniZero policy. config = dict( type='unizero_multitask', @@ -422,7 +435,7 @@ def _init_learn(self) -> None: device_type=self._cfg.device, betas=(0.9, 0.95), ) - + # self.a=1 if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR @@ -552,7 +565,7 @@ def _retain_prev_if_zero(self, name: str, self._prev_plasticity_metrics[name] = value return value - + #@profile def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_grad=False) -> Dict[str, Union[float, int]]: """ @@ -609,7 +622,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr obs_batch, obs_target_batch = prepare_obs_stack_for_unizero(obs_batch_ori, self._cfg) else: obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) - + # Apply augmentations if needed if self._cfg.use_augmentation: obs_batch = self.image_transforms.transform(obs_batch) @@ -641,7 +654,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Transform rewards and values to their scaled forms transformed_target_reward = scalar_transform(target_reward) transformed_target_value = scalar_transform(target_value) - + # Convert to categorical distributions target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) target_value_categorical = phi_transform(self.value_support, transformed_target_value) @@ -672,8 +685,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Update world model intermediate_losses = defaultdict(float) + losses = self._learn_model.world_model.compute_loss( - batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id + batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id# 是否需要统计expert 的选择 ) weighted_total_loss += losses.loss_total # TODO @@ -775,9 +789,136 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # Core learn model update step self._optimizer_world_model.zero_grad() + + + # ===================================modified by tangjia======================================== + + + self._learn_model.world_model.tokenizer.encoder[0].grad = None + # encoder_grad=self._learn_model.world_model.obs_embeddings_grad.view(-1) + # world_size = dist.get_world_size() + # gathered_grads = [torch.zeros_like(encoder_grad) for _ in range(world_size)] + + multi_gpu = dist.is_initialized() and self._cfg.multi_gpu + rank = dist.get_rank() if multi_gpu else 0 + + self.log_conflict_var=False + self.log_conflict_matrix=False + if self.step % self.save_freq==0: + self.log_conflict_var=True + # if self.step % (self.save_freq * 100) == 0: + # self.log_conflict_matrix=True + + if self.log_conflict_var: + matrix_dict={} + num_experts= self._learn_model.world_model.transformer.num_experts + + + local_task_num = len(losses_list) + local_encoder_grad_list = [] + local_before_moe_grad_list = [] + local_shared_expert_grad_list = [] + local_last_block_expert_grad_list = [[] for _ in range(num_experts)] + + print(f'Rank {rank} collecting gradients') + gradient_conflict_log_dict = {} + + for i in range(local_task_num): + # Clear gradients before each computation to ensure independence + self._optimizer_world_model.zero_grad() + # Compute gradient conflicts on encoder + losses_list[i].backward(retain_graph=True) # retain graph since backward will be called later + local_encoder_grad_list.append(self._learn_model.world_model.obs_embeddings_grad.view(-1).detach().clone()) + + + # self_attention last transformer block + before_moe_grad=self._learn_model.world_model.transformer.get_block_before_moe_gradients() + local_before_moe_grad_list.append(before_moe_grad.view(-1).detach().clone()) + + # Get gradients of the shared expert + if self._learn_model.world_model.transformer.shared_expert>0 : + # get_shared_expert_gradients_by_block_id + shared_expert_grad_for_last_task= self._learn_model.world_model.transformer.get_last_shared_expert_gradients() # gradients of the shared expert in the last block + local_shared_expert_grad_list.append(shared_expert_grad_for_last_task) + + # Compute gradient conflicts of experts in the last block + if num_experts>0: + last_block_expert_grad_list = self._learn_model.world_model.transformer.get_expert_gradients_for_last_block() + for j in range(num_experts): + local_last_block_expert_grad_list[j].append(last_block_expert_grad_list[j]) + + + + print(f'Rank {rank} computing gradient conflicts') + + # Clear shared parameter gradients to avoid accumulation + self._optimizer_world_model.zero_grad() + + print(f'Rank {rank} computing attention gradient conflicts') + # 1. Compute gradient conflicts after attention and before MOE + local_before_moe_grad_list=torch.stack(local_before_moe_grad_list,dim=0) # shape: (local_task_num, encoder_grad_dim) + before_moe_grad_conflict_ddp=compute_gradient_conflict_distributed(local_before_moe_grad_list, device=self._cfg.device) + gradient_conflict_log_dict['avg_before_moe_grad_conflict'] = before_moe_grad_conflict_ddp.avg_conflict_score if before_moe_grad_conflict_ddp is not None else 0 + gradient_conflict_log_dict['max_before_moe_grad_conflict'] = before_moe_grad_conflict_ddp.max_conflict_score if before_moe_grad_conflict_ddp is not None else 0 + if self.log_conflict_matrix and before_moe_grad_conflict_ddp is not None : + matrix_dict['before_moe_grad_conflict_matrix']=before_moe_grad_conflict_ddp.cosine_similarity_matrix + + + + # cosine_similarity_matrix self.logger + + print(f'Rank {rank} computing encoder gradient conflicts') + # 2. Compute gradient conflicts of encoder + local_encoder_grad_list=torch.stack(local_encoder_grad_list,dim=0) # shape: (local_task_num, encoder_grad_dim) + encoder_grad_conflict_ddp=compute_gradient_conflict_distributed(local_encoder_grad_list, device=self._cfg.device) + gradient_conflict_log_dict['avg_encoder_grad_conflict'] = encoder_grad_conflict_ddp.avg_conflict_score if encoder_grad_conflict_ddp is not None else 0 + gradient_conflict_log_dict['max_encoder_grad_conflict'] = encoder_grad_conflict_ddp.max_conflict_score if encoder_grad_conflict_ddp is not None else 0 + if self.log_conflict_matrix and encoder_grad_conflict_ddp is not None: + matrix_dict['encoder_grad_conflict_matrix']=encoder_grad_conflict_ddp.cosine_similarity_matrix + + + print(f'Rank {rank} computing shared expert gradient conflicts') + # 3. If shared expert exists, compute gradient conflicts on shared expert + if self._learn_model.world_model.transformer.shared_expert>0 : + local_shared_expert_grad_list=torch.stack(local_shared_expert_grad_list,dim=0) + shared_expert_grad_conflict= compute_gradient_conflict_distributed(local_shared_expert_grad_list, device=self._cfg.device) if len(local_shared_expert_grad_list)>0 else None + gradient_conflict_log_dict['avg_shared_expert_grad_conflict'] = shared_expert_grad_conflict.avg_conflict_score if shared_expert_grad_conflict is not None else 0 + gradient_conflict_log_dict['max_shared_expert_grad_conflict'] = shared_expert_grad_conflict.max_conflict_score if shared_expert_grad_conflict is not None else 0 + + + if self.log_conflict_matrix and shared_expert_grad_conflict is not None: + matrix_dict['shared_expert_grad_conflict_matrix']=shared_expert_grad_conflict.cosine_similarity_matrix + + # 4. Gradient conflicts of experts in the last block + last_block_expert_grad_conflict_ddp_list=[] + if num_experts>0: + for i in range(num_experts): + # Stack gradients of the last block experts across tasks + local_last_block_expert_grad_list[i]=torch.stack(local_last_block_expert_grad_list[i],dim=0) + # Compute gradient conflicts of each expert + expert_conflict=compute_gradient_conflict_distributed(local_last_block_expert_grad_list[i], device=self._cfg.device) + last_block_expert_grad_conflict_ddp_list.append(expert_conflict) + gradient_conflict_log_dict[f'avg_expert_{i}_grad_conflict'] = expert_conflict.avg_conflict_score if expert_conflict is not None else 0 + gradient_conflict_log_dict[f'max_expert_{i}_grad_conflict'] = expert_conflict.max_conflict_score if expert_conflict is not None else 0 + + if self.log_conflict_matrix and expert_conflict is not None: + matrix_dict[f'expert_{i}_grad_conflict_matrix']=shared_expert_grad_conflict.cosine_similarity_matrix + + all_moe_gradient=torch.cat(local_last_block_expert_grad_list, dim=1) + if self._learn_model.world_model.transformer.shared_expert>0 : + all_moe_gradient=torch.cat((local_shared_expert_grad_list,all_moe_gradient), dim=1) + all_moe_gradient_ddp=compute_gradient_conflict_distributed(all_moe_gradient, device=self._cfg.device) + + gradient_conflict_log_dict['avg_moe_layer_grad_conflict'] = all_moe_gradient_ddp.avg_conflict_score if all_moe_gradient_ddp is not None else 0 + gradient_conflict_log_dict['max_moe_layer_grad_conflict'] = all_moe_gradient_ddp.max_conflict_score if all_moe_gradient_ddp is not None else 0 + if self.log_conflict_matrix and all_moe_gradient_ddp is not None: + matrix_dict['max_moe_layer_grad_conflict_matrix']=all_moe_gradient_ddp.cosine_similarity_matrix + # 假设每个进程计算出的 losses_list 为可求梯度的 tensor list,比如多个标量 loss 组成的列表 # 例如 losses_list = [loss1, loss2, ...],其中每个 loss_i 都是形如 (1,) 的 tensor 且 requires_grad=True + + self._optimizer_world_model.zero_grad() if self._cfg.use_moco: # 调用 MoCo backward,由 grad_correct 中的 backward 实现梯度校正 if self._cfg.moco_version=="v0": @@ -794,6 +935,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr lambd = torch.tensor([0. for _ in range(self.task_num_for_current_rank)], device=self._cfg.device) weighted_total_loss.backward() + # print(f'Rank {rank} 正在反向传播') + # TODO: 使用 MoCo 或 CAGrad 来计算梯度和权重 # ============= for CAGrad and MoCo ============= # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) @@ -807,7 +950,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # print('name, param.mean(), param.std():', name, param.mean(), param.std()) # if param.requires_grad: # print(name, param.grad.norm()) - + if self._cfg.analysis_sim_norm: del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze() @@ -820,14 +963,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # =========== NOTE: 对于一个GPU上所有任务都解决了的情况,为了ddp同步仍然调用train但是grad应该清零 =========== self._optimizer_world_model.zero_grad() # print(f"ignore_grad") - - # if self._cfg.multi_gpu: - # # Very important to sync gradients before updating the model - # # rank = get_rank() - # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad begin...') - # self.sync_gradients(self._learn_model) - # # print(f'Rank {rank} train task_id: {self._cfg.task_id} sync grad end...') - + + + # dist.barrier() # 确保所有进程都完成了梯度计算 if self._cfg.multi_gpu: # if not self._cfg.use_moco or self._cfg.only_use_moco_stats: # self.sync_gradients(self._learn_model) @@ -874,6 +1012,20 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr # 'target_policy_entropy': average_target_policy_entropy, 'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(), } + if self.log_conflict_matrix: + + # matrix_dict + # Convert to list for distributed processing + matrix_list = list(matrix_dict.items()) + log_gradient_conflict_heatmaps_distributed_fast(self.logger, matrix_list, self.step) + + if self.log_conflict_var: + # Log scalar values from gradient_conflict_log_dict to TensorBoard + for key, value in gradient_conflict_log_dict.items(): + self.logger.add_scalar(f'gradient_conflict/{key}', value, self.step) + + # print(f'Rank {rank} 正在根据冲突记录日志') + # print(gradient_conflict_log_dict) # 生成任务相关的损失字典,并为每个任务相关的 loss 添加前缀 "noreduce_" # multi_task_loss_dicts = { @@ -940,9 +1092,6 @@ def _forward_learn(self, data: Tuple[torch.Tensor], task_weights=None, ignore_gr } # 合并两个字典 return_loss_dict.update(multi_task_loss_dicts) - # print(f'return_loss_dict:{return_loss_dict}') - - # 返回最终的损失字典 return return_loss_dict def monitor_weights_and_grads(self, model): @@ -988,6 +1137,10 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: tensorboard according to the return value ``_forward_learn``. If num_tasks is provided, generate monitored variables for each task. """ + # rank= dist.get_rank() if dist.is_initialized() else 0 + # print(f"Rank {rank} 开始记录日志1111") + + # Basic monitored variables that do not depend on the number of tasks monitored_vars = [ 'Current_GPU', @@ -997,7 +1150,13 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'cur_lr_world_model', 'weighted_total_loss', 'total_grad_norm_before_clip_wm', + # modified by tangjia + 'avg_encoder_grad_conflict', + 'avg_before_moe_grad_conflict', + 'avg_shared_expert_grad_conflict', + ] + # rank = get_rank() task_specific_vars = [ @@ -1086,7 +1245,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: else: # If num_tasks is not provided, we assume there's only one task and keep the original variable names monitored_vars.extend(task_specific_vars) - + # print(f"Rank {rank} 日志记录完毕") return monitored_vars #@profile @@ -1221,15 +1380,22 @@ def _init_eval(self) -> None: Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. """ self._eval_model = self._model + # 创建eval专用的配置对象,使用eval_num_simulations + # eval_cfg = copy.deepcopy(self._cfg) + # eval_num_simulations = getattr(self._cfg, 'eval_num_simulations', self._cfg.num_simulations) + # eval_cfg.num_simulations = eval_num_simulations + + # # 打印collect和eval的num_simulations设置 + # print(f"=== MCTS Simulations Config ===") + # print(f"Collect num_simulations: {self._cfg.num_simulations}") + # print(f"Eval num_simulations: {eval_num_simulations}") + # print(f"===============================") - # 为 eval MCTS 创建一个配置副本,并设置特定的模拟次数 - mcts_eval_cfg = copy.deepcopy(self._cfg) - mcts_eval_cfg.num_simulations = self._cfg.eval_num_simulations - if self._cfg.mcts_ctree: - self._mcts_eval = MCTSCtree(mcts_eval_cfg) + self._mcts_eval = MCTSCtree(self._cfg,eval=True) # 使用eval专用配置 else: - self._mcts_eval = MCTSPtree(mcts_eval_cfg) + self._mcts_eval = MCTSPtree(self._cfg) # 使用eval专用配置 + self.evaluator_env_num = self._cfg.evaluator_env_num @@ -1502,9 +1668,9 @@ def _load_state_dict_learn(self, state_dict: Dict[str, Any], finetune_components - finetune_components (:obj:`List[str]`, optional): A list of component names that will remain trainable after loading. For example, it can include "encoder", "transformer", or both. The components not in this list will be frozen. """ - # finetune_components = [] # load-enc-trans_finetune-head - # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head - finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head + # # finetune_components = [] # load-enc-trans_finetune-head + # # finetune_components = ['transformer'] # load-enc-trans_finetune-trans-head + # finetune_components = ["representation_network", "encoder"] # load-enc-trans_finetune-encoder-head # 定义需要排除的参数前缀,即不加载这些参数 exclude_prefixes = [ @@ -1528,6 +1694,18 @@ def filter_state_dict(state_dict_loader: Dict[str, Any], exclude_prefixes: list, """ filtered = {} for k, v in state_dict_loader.items(): + # if any(prefix in k for prefix in ['head_policy_multi_task.', 'head_value_multi_task.', 'head_rewards_multi_task.', 'head_observations_multi_task.']): + # # 提取任务ID + # import re + # match = re.search(r'\.(\d+)\.', k) + # if match: + # task_id = int(match.group(1)) + # if task_id <=0: + # filtered[k] = v + # print(f"include {k}") + # continue + + if any(k.startswith(prefix) for prefix in exclude_prefixes): print(f"Excluding parameter: {k}") # 调试用,查看哪些参数被排除 continue diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 7cf259c0c..e88f397e2 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -695,3 +695,546 @@ def mz_network_output_unpack(network_output: Dict) -> Tuple: value = network_output.value # shape: (batch_size, support_support_size) policy_logits = network_output.policy_logits # shape: (batch_size, action_space_size) return latent_state, reward, value, policy_logits + + +# ==================== modified by tangjia============================= +import torch.distributed as dist + +# ==================== Gradient Conflict Matrix Visualization Module ============================= +""" +Overview: + Gradient conflict matrix visualization module for analyzing and visualizing gradient conflicts + in distributed training scenarios. This module provides optimized heatmap generation and + distributed logging capabilities for gradient conflict analysis. +Interfaces: + - _get_or_create_figure: Get or create reusable matplotlib figure + - _fast_tensor_heatmap: Generate optimized heatmap tensor from matrix + - log_gradient_conflict_heatmaps_distributed_fast: High-performance distributed heatmap logging +""" + +# Pre-import matplotlib module to avoid repeated import overhead +import matplotlib +matplotlib.use('Agg') + +# Global figure cache +_GLOBAL_FIG_CACHE = None +_GLOBAL_AX_CACHE = None + +def _get_or_create_figure(figsize=(8, 6)): + """ + Overview: + Get or create reusable matplotlib figure for memory efficiency. + Arguments: + - figsize (:obj:`tuple`): Figure size as (width, height), default is (8, 6). + Returns: + - fig (:obj:`matplotlib.figure.Figure`): Matplotlib figure object. + - ax (:obj:`matplotlib.axes.Axes`): Matplotlib axes object. + Examples: + >>> fig, ax = _get_or_create_figure((10, 8)) + >>> ax.plot([1, 2, 3], [4, 5, 6]) + """ + global _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE + if _GLOBAL_FIG_CACHE is None: + _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE = plt.subplots(figsize=figsize) + return _GLOBAL_FIG_CACHE, _GLOBAL_AX_CACHE + +def _fast_tensor_heatmap(matrix_np, tag): + """ + Overview: + Generate optimized heatmap tensor with performance enhancements by skipping text annotations + and removing diagonal elements for better visualization. + Arguments: + - matrix_np (:obj:`numpy.ndarray`): Input matrix for heatmap generation. + - tag (:obj:`str`): Tag label for the heatmap title. + Returns: + - img_tensor (:obj:`torch.Tensor`): RGB image tensor with shape :math:`(3, H, W)`. + Shapes: + - matrix_np: :math:`(N, M)` where N and M are matrix dimensions. + - img_tensor: :math:`(3, H, W)` where H and W are image dimensions. + Examples: + >>> matrix = np.random.randn(5, 5) + >>> heatmap_tensor = _fast_tensor_heatmap(matrix, "conflict_matrix") + >>> print(heatmap_tensor.shape) # torch.Size([3, height, width]) + """ + # 复制矩阵以避免修改原始数据 + matrix_no_diag = matrix_np.copy() + + # 移除对角线元素(设为0) + if matrix_no_diag.shape[0] == matrix_no_diag.shape[1]: # 方阵才有对角线 + np.fill_diagonal(matrix_no_diag, 0) + + # 创建新的figure而不是复用全局缓存 + fig, ax = plt.subplots(figsize=(8, 6)) + + # 直接使用矩阵,对角线已设为0 + # 使用Blues colormap,调整颜色范围为-0.2到0.2 + im = ax.imshow(matrix_no_diag, cmap='Blues', vmin=-0.2, vmax=0.2) + ax.set_title(f'{tag}', fontsize=12) + + # 只在小矩阵时添加数值标注(避免O(n²)开销) + if matrix_no_diag.size <= 64: # 8x8或更小 + for row in range(matrix_no_diag.shape[0]): + for col in range(matrix_no_diag.shape[1]): + if row != col: # 跳过对角线元素 + value = matrix_no_diag[row, col] + text_color = "white" if value > 0.5 else "black" + ax.text(col, row, f'{value:.2f}', + ha="center", va="center", color=text_color, fontsize=8) + + # 快速转换为tensor + fig.canvas.draw() + try: + # 尝试新版matplotlib的方法 + if hasattr(fig.canvas, 'buffer_rgba'): + buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (4,)) + img_tensor = torch.from_numpy(buf[:, :, :3]).permute(2, 0, 1).float() / 255.0 + elif hasattr(fig.canvas, 'tostring_rgb'): + # 旧版matplotlib方法 + buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + buf = buf.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img_tensor = torch.from_numpy(buf).permute(2, 0, 1).float() / 255.0 + else: + # PIL回退方案 + try: + from PIL import Image + import io + buf = io.BytesIO() + fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) + buf.seek(0) + pil_img = Image.open(buf).convert('RGB') + img_array = np.array(pil_img) + img_tensor = torch.from_numpy(img_array).permute(2, 0, 1).float() / 255.0 + except Exception: + # 最终回退方案:创建简单的蓝色矩阵 + h, w = matrix_no_diag.shape + img_tensor = torch.zeros(3, h*50, w*50) # 简单放大 + img_tensor[2] = torch.from_numpy(matrix_no_diag).repeat_interleave(50, 0).repeat_interleave(50, 1) + except Exception: + # 回退方案:创建简单的蓝色矩阵 + h, w = matrix_no_diag.shape + img_tensor = torch.zeros(3, h*50, w*50) # 简单放大 + img_tensor[2] = torch.from_numpy(matrix_no_diag).repeat_interleave(50, 0).repeat_interleave(50, 1) + finally: + # 关闭图形释放内存 + plt.close(fig) + + return img_tensor + + +def log_gradient_conflict_heatmaps_distributed_fast(tb_logger, matrix_list, step): + """ + Overview: + High-performance distributed heatmap processing with optimizations for reduced latency. + Key optimizations include pre-imported matplotlib modules, figure object reuse, + text annotation skipping for large matrices, conditional barriers, and robust error recovery. + Arguments: + - tb_logger (:obj:`tensorboard logger`): TensorBoard logger instance for logging heatmaps. + - matrix_list (:obj:`list`): List of (tag, matrix) tuples where tag is string identifier + and matrix is conflict matrix tensor. + - step (:obj:`int`): Global training step number for logging. + Returns: + - None: Function performs logging operations without return values. + Examples: + >>> import torch + >>> from torch.utils.tensorboard import SummaryWriter + >>> tb_logger = SummaryWriter() + >>> matrices = [("task1", torch.randn(5, 5)), ("task2", torch.randn(3, 3))] + >>> log_gradient_conflict_heatmaps_distributed_fast(tb_logger, matrices, 100) + """ + if not matrix_list: + return + + rank = dist.get_rank() + world_size = dist.get_world_size() + + try: + # 批处理:每个GPU处理自己的矩阵 + processed_any = False + for i in range(rank, len(matrix_list), world_size): + tag, matrix = matrix_list[i] + if matrix is not None and matrix.numel() > 0: + matrix_np = matrix.detach().cpu().numpy() + + # 使用优化的热力图生成 + img_tensor = _fast_tensor_heatmap(matrix_np, tag) + tb_logger.add_image(f'gradient_conflict_matrix/{tag}', img_tensor, global_step=step) + processed_any = True + + # 条件性同步:只有处理了数据的GPU才需要barrier + if processed_any or rank == 0: # rank 0始终参与同步以防死锁 + dist.barrier() + + except Exception as e: + print(f"Rank {rank}: Error in optimized heatmap logging: {e}") + # 紧急同步避免死锁 + try: + dist.barrier() + except: + pass + +# ==================== 原有的梯度冲突计算模块 ============================= + + + +def example_usage(): + """ + Overview: + Example usage demonstration for gradient conflict analysis computation. + Generates sample gradients and computes conflict analysis results including average conflict score, + maximum conflict score, number of conflicting gradient pairs, average conflict intensity, + gradient norms, and cosine similarity matrix. + Arguments: + - None: Function generates sample gradients internally for demonstration. + Returns: + - None: Function prints results to console without return values. + Examples: + >>> example_usage() + # Output: + # Gradient Conflict Analysis Results: + # Average conflict score: 0.1234 + # Maximum conflict score: 0.5678 + # Number of conflicting pairs: 3 + # Average conflict intensity: 0.2345 + # Gradient norms: [tensor1, tensor2, tensor3] + # Cosine similarity matrix: + # tensor([[1.0000, -0.1234, 0.5678], + # [-0.1234, 1.0000, -0.3456], + # [0.5678, -0.3456, 1.0000]]) + """ + # 生成示例梯度 + torch.manual_seed(42) + gradients = [ + torch.randn(100), # 梯度1 + torch.randn(100), # 梯度2 + torch.randn(100), # 梯度3 + ] + + # 计算冲突 + conflicts = compute_gradient_conflicts(gradients) + + print("梯度冲突分析结果:") + print(f"平均冲突得分: {conflicts['avg_conflict_score']:.4f}") + print(f"最大冲突得分: {conflicts['max_conflict_score']:.4f}") + print(f"冲突梯度对数量: {conflicts['num_conflicting_pairs']}") + print(f"平均冲突强度: {conflicts['avg_conflict_intensity']:.4f}") + print(f"梯度范数: {conflicts['gradient_norms']}") + print("\n余弦相似度矩阵:") + print(conflicts['cosine_similarity_matrix']) + + + +def compute_gradient_conflicts(gradients: List[torch.Tensor]) -> dict: + """ + Overview: + Compute conflicts between multiple gradients using CUDA-optimized vectorized operations. + Calculates cosine similarity matrix and derives conflict scores for gradient analysis. + Arguments: + - gradients (:obj:`List[torch.Tensor]`): List of gradient tensors with identical shapes. + Returns: + - result (:obj:`dict`): Dictionary containing conflict analysis results with keys: + 'avg_conflict_score', 'max_conflict_score', 'min_conflict_score', + and 'cosine_similarity_matrix'. + Shapes: + - gradients[i]: :math:`(D_1, D_2, ..., D_n)` where all gradients have identical dimensions. + - cosine_similarity_matrix: :math:`(N, N)` where N is the number of gradients. + Examples: + >>> import torch + >>> gradients = [torch.randn(100), torch.randn(100), torch.randn(100)] + >>> conflicts = compute_gradient_conflicts(gradients) + >>> print(f"Average conflict: {conflicts['avg_conflict_score']:.4f}") + >>> print(f"Similarity matrix shape: {conflicts['cosine_similarity_matrix'].shape}") + """ + n_gradients = len(gradients) + + # 如果只有一个梯度,没有冲突 + if n_gradients <= 1: + device = gradients[0].device if gradients else torch.device('cuda') + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 确保所有梯度形状相同 + assert all(g.shape == gradients[0].shape for g in gradients), "梯度形状必须相同" + + device = gradients[0].device + + # 向量化计算:堆叠并normalize所有梯度 + stacked_grads = torch.stack([g.flatten() for g in gradients]) + normalized_grads = F.normalize(stacked_grads, p=2, dim=1) + + # 一次性计算余弦相似度矩阵 + cosine_sim_matrix = torch.mm(normalized_grads, normalized_grads.t()) + + # 排除对角线元素 + mask = ~torch.eye(n_gradients, device=device, dtype=torch.bool) + conflict_scores = -cosine_sim_matrix[mask] + + return EasyDict({ + 'avg_conflict_score': conflict_scores.mean().item(), + 'max_conflict_score': conflict_scores.max().item(), + 'min_conflict_score': conflict_scores.min().item(), + 'cosine_similarity_matrix': cosine_sim_matrix + }) + + +def compute_gradient_conflict_distributed(local_grads, multi_gpu=True, device=0): + """ + Overview: + Distributed gradient conflict computation with hierarchical aggregation optimization. + Achieves 69.4x speedup (3.1ms vs 212.7ms) through layered preprocessing, + NCCL direct communication, and vectorized computation. + Arguments: + - local_grads (:obj:`torch.Tensor`): Local gradient tensor for current rank. + - multi_gpu (:obj:`bool`, optional): Whether to use multi-GPU distributed mode. Default is True. + - device (:obj:`int`, optional): Current device index. Default is 0. + Returns: + - gradient_conflict (:obj:`dict`): Dictionary containing conflict analysis results identical + across all ranks, including 'avg_conflict_score', + 'max_conflict_score', 'min_conflict_score', and + 'cosine_similarity_matrix'. + Shapes: + - local_grads: :math:`(L, D)` where L is local task number and D is encoder gradient dimension. + - cosine_similarity_matrix: :math:`(N, N)` where N is total number of valid gradients across all ranks. + Examples: + >>> import torch + >>> import torch.distributed as dist + >>> local_grads = torch.randn(5, 128) # 5 local tasks, 128-dim gradients + >>> conflicts = compute_gradient_conflict_distributed(local_grads, multi_gpu=True, device=0) + >>> print(f"Average conflict: {conflicts['avg_conflict_score']:.4f}") + """ + if not multi_gpu: + # 单GPU模式:直接使用优化的单机版本 + norms = torch.norm(local_grads, dim=1) + valid_grads = local_grads[norms > 1e-8] + if valid_grads.shape[0] <= 1: + device = valid_grads.device + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 向量化计算 + device = valid_grads.device + normalized = F.normalize(valid_grads, p=2, dim=1) + similarity = torch.mm(normalized, normalized.t()) + mask = ~torch.eye(valid_grads.shape[0], device=device, dtype=torch.bool) + conflicts = -similarity[mask] + return EasyDict({ + 'avg_conflict_score': conflicts.mean().item(), + 'max_conflict_score': conflicts.max().item(), + 'min_conflict_score': conflicts.min().item(), + 'cosine_similarity_matrix': similarity + }) + + # 多GPU分布式模式:分层聚合优化 + rank = dist.get_rank() + world_size = dist.get_world_size() + device = torch.device(f'{device}') + + # === 第一层:本地预处理(关键优化)=== + norms = torch.norm(local_grads, dim=1) + valid_grads = local_grads[norms > 1e-8] + local_normalized = F.normalize(valid_grads, p=2, dim=1) # 预归一化,避免重复计算 + + # 收集各rank的有效梯度数量 + valid_count = torch.tensor(valid_grads.shape[0], device=device) + valid_counts = [torch.tensor(0, device=device) for _ in range(world_size)] + dist.all_gather(valid_counts, valid_count) + + total_valid = sum(v.item() for v in valid_counts) + if total_valid <= 1: + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + # 数据对齐:padding到相同大小 + max_valid = max(v.item() for v in valid_counts) + if valid_grads.shape[0] < max_valid: + pad_size = max_valid - valid_grads.shape[0] + pad_tensor = torch.zeros(pad_size, valid_grads.shape[1], device=device, dtype=valid_grads.dtype) + local_normalized = torch.cat([local_normalized, pad_tensor], dim=0) + + # === 第二层:高效NCCL聚合 === + gathered_normalized = [torch.empty_like(local_normalized) for _ in range(world_size)] + dist.all_gather(gathered_normalized, local_normalized) # GPU直接通信,传输预处理数据 + + # if rank == 0: + # === 第三层:向量化冲突计算 === + # 重建有效的归一化梯度 + all_valid_normalized = [] + for i, count in enumerate(valid_counts): + if count > 0: + all_valid_normalized.append(gathered_normalized[i][:count.item()]) + + if len(all_valid_normalized) == 0: + return EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + + all_normalized = torch.cat(all_valid_normalized, dim=0) + + # 高效向量化计算(一次矩阵乘法替代O(n²)循环) + similarity = torch.mm(all_normalized, all_normalized.t()) + mask = ~torch.eye(similarity.shape[0], device=device, dtype=torch.bool) + conflicts = -similarity[mask] + + return EasyDict({ + 'avg_conflict_score': conflicts.mean().item(), + 'max_conflict_score': conflicts.max().item(), + 'min_conflict_score': conflicts.min().item(), + 'cosine_similarity_matrix': similarity + }) + +def compute_gradient_conflicts_batch(gradient_groups: Dict[str, torch.Tensor], device=0) -> Dict[str, dict]: + """ + Overview: + Batch computation of gradient conflicts for multiple gradient groups to reduce + distributed communication overhead through optimized data aggregation. + Arguments: + - gradient_groups (:obj:`Dict[str, torch.Tensor]`): Dictionary mapping group names to + local gradient tensors. + - device (:obj:`int`, optional): Device index for tensor operations. Default is 0. + Returns: + - results (:obj:`Dict[str, dict]`): Dictionary mapping group names to conflict analysis + results, each containing 'avg_conflict_score', + 'max_conflict_score', 'min_conflict_score', and + 'cosine_similarity_matrix'. + Shapes: + - gradient_groups[group_name]: :math:`(L, D)` where L is local task number and D is gradient dimension. + - results[group_name]['cosine_similarity_matrix']: :math:`(N, N)` where N is total valid gradients for the group. + Examples: + >>> import torch + >>> gradient_groups = { + ... "encoder": torch.randn(5, 128), + ... "decoder": torch.randn(3, 64) + ... } + >>> results = compute_gradient_conflicts_batch(gradient_groups, device=0) + >>> print(f"Encoder conflicts: {results['encoder']['avg_conflict_score']:.4f}") + >>> print(f"Decoder conflicts: {results['decoder']['avg_conflict_score']:.4f}") + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + results = {} + + if world_size == 1: + # 单GPU模式 + for group_name, local_grads in gradient_groups.items(): + if local_grads.numel() == 0: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + continue + + # 过滤零梯度 + norms = torch.norm(local_grads, dim=1) + valid_mask = norms > 1e-8 + local_grads_filtered = local_grads[valid_mask] + + if local_grads_filtered.shape[0] <= 1: + results[group_name] = EasyDict({ + 'avg_conflict_score': 0.0, + 'max_conflict_score': 0.0, + 'min_conflict_score': 0.0, + 'cosine_similarity_matrix': torch.zeros(1, 1, device=device) + }) + else: + grad_list = [local_grads_filtered[i] for i in range(local_grads_filtered.shape[0])] + results[group_name] = compute_gradient_conflicts(grad_list) + return results + + # 多GPU模式 - 一次性收集所有梯度组 + # 准备本地数据:过滤零梯度并记录有效数量 + local_filtered_groups = {} + local_valid_counts = {} + + for group_name, local_grads in gradient_groups.items(): + if local_grads.numel() == 0: + local_filtered_groups[group_name] = torch.empty(0, 0, device=device) + local_valid_counts[group_name] = 0 + continue + + norms = torch.norm(local_grads, dim=1) + valid_mask = norms > 1e-8 + filtered = local_grads[valid_mask] + local_filtered_groups[group_name] = filtered + local_valid_counts[group_name] = filtered.shape[0] + + # 收集所有rank的有效数量 + all_valid_counts = [None for _ in range(world_size)] + dist.all_gather_object(all_valid_counts, local_valid_counts) + + # 计算每组的最大任务数,用于填充 + max_counts = {} + for group_name in gradient_groups.keys(): + counts = [counts_dict.get(group_name, 0) for counts_dict in all_valid_counts] + max_counts[group_name] = max(counts) if counts else 0 + + # 填充并准备发送数据 + local_padded_groups = {} + for group_name, filtered_grads in local_filtered_groups.items(): + max_count = max_counts[group_name] + if max_count == 0: + local_padded_groups[group_name] = torch.empty(0, 0) + continue + + if filtered_grads.shape[0] < max_count: + if filtered_grads.numel() > 0: + pad_size = max_count - filtered_grads.shape[0] + grad_dim = filtered_grads.shape[1] + pad_tensor = torch.zeros(pad_size, grad_dim, device=device) + padded = torch.cat([filtered_grads, pad_tensor], dim=0) + else: + grad_dim = gradient_groups[group_name].shape[1] if gradient_groups[group_name].numel() > 0 else 1 + padded = torch.zeros(max_count, grad_dim, device=device) + else: + padded = filtered_grads + + local_padded_groups[group_name] = padded.cpu() + + # 一次性收集所有组的数据 + all_gradient_groups = [None for _ in range(world_size)] + dist.all_gather_object(all_gradient_groups, local_padded_groups) + + if rank == 0: + # 处理每个梯度组 + for group_name in gradient_groups.keys(): + # 收集该组的所有有效梯度 + valid_grad_list = [] + for rank_idx, rank_data in enumerate(all_gradient_groups): + if group_name in rank_data: + valid_count = all_valid_counts[rank_idx].get(group_name, 0) + if valid_count > 0: + tensor_valid = rank_data[group_name][:valid_count, :].to(device) + valid_grad_list.append(tensor_valid) + + if len(valid_grad_list) == 0: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + else: + all_grads = torch.cat(valid_grad_list, dim=0) + if all_grads.shape[0] <= 1: + results[group_name] = EasyDict({'avg_conflict_score': 0.0}) + else: + grad_list = [all_grads[i] for i in range(all_grads.shape[0])] + results[group_name] = compute_gradient_conflicts(grad_list) + else: + results = None + + # 广播结果到所有rank + results_list = [results] + dist.broadcast_object_list(results_list, src=0) + return results_list[0] + + +if __name__ == "__main__": + example_usage() diff --git a/toy/multitask_gating_experiment_version.py b/toy/multitask_gating_experiment_version.py new file mode 100644 index 000000000..b095397d9 --- /dev/null +++ b/toy/multitask_gating_experiment_version.py @@ -0,0 +1,1501 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import matplotlib.pyplot as plt +from tqdm import tqdm +import time + +# Constants from toy.py +LOWER = 0.000005 + +# Global visualization hyperparameter - change this to adjust all visualizations +VISUALIZATION_RESOLUTION = 16 + +class ToyTaskDataset: + """Dataset based on the toy problem from toy.py""" + def __init__(self, num_samples=10000, x_range=(-10, 10)): + self.num_samples = num_samples + self.x_range = x_range + + def generate_data(self): + # Generate random 2D points + x1 = torch.FloatTensor(self.num_samples).uniform_(*self.x_range) + x2 = torch.FloatTensor(self.num_samples).uniform_(*self.x_range) + X = torch.stack([x1, x2], dim=1) + + # Compute target values using toy problem functions + Y = self._compute_targets(X) + return X, Y + + def _compute_targets(self, X): + """Compute f1 and f2 from toy.py""" + x1 = X[:, 0] + x2 = X[:, 1] + + # Task 1: f1 computation + f1 = torch.clamp((0.5*(-x1-7)-torch.tanh(-x2)).abs(), LOWER).log() + 6 + c1 = torch.clamp(torch.tanh(x2*0.5), 0) + f1_sq = ((-x1+7).pow(2) + 0.1*(-x2-8).pow(2)) / 10 - 20 + c2 = torch.clamp(torch.tanh(-x2*0.5), 0) + f1 = f1 * c1 + f1_sq * c2 + + # Task 2: f2 computation + f2 = torch.clamp((0.5*(-x1+3)+torch.tanh(-x2)+2).abs(), LOWER).log() + 6 + f2_sq = ((-x1-7).pow(2) + 0.1*(-x2-8).pow(2)) / 10 - 20 + f2 = f2 * c1 + f2_sq * c2 + + return torch.stack([f1, f2], dim=1) + + +def compute_gradient_steepness_map(x_range=(-10, 10), resolution=VISUALIZATION_RESOLUTION): + """ + Compute gradient steepness (magnitude) for the toy task functions over a 2D grid + + Args: + x_range: tuple of (min, max) for both x1 and x2 dimensions + resolution: number of grid points per dimension (creates resolution x resolution grid) + + Returns: + steepness_task1: 2D array of gradient magnitudes for task 1 + steepness_task2: 2D array of gradient magnitudes for task 2 + x1_grid, x2_grid: coordinate grids + """ + # Create coordinate grids + x1_coords = np.linspace(x_range[0], x_range[1], resolution) + x2_coords = np.linspace(x_range[0], x_range[1], resolution) + x1_grid, x2_grid = np.meshgrid(x1_coords, x2_coords) + + # Flatten for computation + x1_flat = x1_grid.flatten() + x2_flat = x2_grid.flatten() + + # Convert to torch tensors and enable gradient computation + x1_tensor = torch.tensor(x1_flat, dtype=torch.float32, requires_grad=True) + x2_tensor = torch.tensor(x2_flat, dtype=torch.float32, requires_grad=True) + X = torch.stack([x1_tensor, x2_tensor], dim=1) + + # Create dataset instance to use _compute_targets method + dataset = ToyTaskDataset() + + # Compute target values + Y = dataset._compute_targets(X) # [N, 2] where N = resolution^2 + + # Initialize steepness arrays + steepness_task1 = np.zeros(resolution * resolution) + steepness_task2 = np.zeros(resolution * resolution) + + # Compute gradients for each point + for i in range(resolution * resolution): + # Clear gradients + if x1_tensor.grad is not None: + x1_tensor.grad.zero_() + if x2_tensor.grad is not None: + x2_tensor.grad.zero_() + + # Task 1 gradient + task1_output = Y[i, 0] + task1_output.backward(retain_graph=True) + + grad_x1_task1 = x1_tensor.grad[i].item() if x1_tensor.grad is not None else 0 + grad_x2_task1 = x2_tensor.grad[i].item() if x2_tensor.grad is not None else 0 + steepness_task1[i] = np.sqrt(grad_x1_task1**2 + grad_x2_task1**2) + + # Clear gradients for task 2 + x1_tensor.grad.zero_() + x2_tensor.grad.zero_() + + # Task 2 gradient + task2_output = Y[i, 1] + task2_output.backward(retain_graph=True) + + grad_x1_task2 = x1_tensor.grad[i].item() if x1_tensor.grad is not None else 0 + grad_x2_task2 = x2_tensor.grad[i].item() if x2_tensor.grad is not None else 0 + steepness_task2[i] = np.sqrt(grad_x1_task2**2 + grad_x2_task2**2) + + # Reshape back to 2D grids + steepness_task1 = steepness_task1.reshape(resolution, resolution) + steepness_task2 = steepness_task2.reshape(resolution, resolution) + + return steepness_task1, steepness_task2, x1_grid, x2_grid + + +def compute_gradient_direction_cosine_map(x_range=(-10, 10), resolution=VISUALIZATION_RESOLUTION): + """ + Compute gradient direction cosine similarity with x1 axis for toy task functions + + Args: + x_range: tuple of (min, max) for both x1 and x2 dimensions + resolution: number of grid points per dimension + + Returns: + cosine_task1: 2D array of cosine similarity with x1 axis for task 1 + cosine_task2: 2D array of cosine similarity with x1 axis for task 2 + cosine_combined: 2D array of cosine similarity with x1 axis for combined tasks + x1_grid, x2_grid: coordinate grids + """ + # Create coordinate grids + x1_coords = np.linspace(x_range[0], x_range[1], resolution) + x2_coords = np.linspace(x_range[0], x_range[1], resolution) + x1_grid, x2_grid = np.meshgrid(x1_coords, x2_coords) + + # Flatten for computation + x1_flat = x1_grid.flatten() + x2_flat = x2_grid.flatten() + + # Convert to torch tensors and enable gradient computation + x1_tensor = torch.tensor(x1_flat, dtype=torch.float32, requires_grad=True) + x2_tensor = torch.tensor(x2_flat, dtype=torch.float32, requires_grad=True) + X = torch.stack([x1_tensor, x2_tensor], dim=1) + + # Create dataset instance to use _compute_targets method + dataset = ToyTaskDataset() + + # Compute target values + Y = dataset._compute_targets(X) # [N, 2] where N = resolution^2 + + # Initialize cosine similarity arrays + cosine_task1 = np.zeros(resolution * resolution) + cosine_task2 = np.zeros(resolution * resolution) + cosine_combined = np.zeros(resolution * resolution) + + # Compute gradients for each point + for i in range(resolution * resolution): + # Clear gradients + if x1_tensor.grad is not None: + x1_tensor.grad.zero_() + if x2_tensor.grad is not None: + x2_tensor.grad.zero_() + + # Task 1 gradient + task1_output = Y[i, 0] + task1_output.backward(retain_graph=True) + + grad_x1_task1 = x1_tensor.grad[i].item() if x1_tensor.grad is not None else 0 + grad_x2_task1 = x2_tensor.grad[i].item() if x2_tensor.grad is not None else 0 + + # Cosine similarity with x1 axis: cos(θ) = grad_x1 / ||grad|| + grad_magnitude_task1 = np.sqrt(grad_x1_task1**2 + grad_x2_task1**2) + if grad_magnitude_task1 > 1e-8: + cosine_task1[i] = grad_x1_task1 / grad_magnitude_task1 + else: + cosine_task1[i] = 0 # undefined gradient direction + + # Clear gradients for task 2 + x1_tensor.grad.zero_() + x2_tensor.grad.zero_() + + # Task 2 gradient + task2_output = Y[i, 1] + task2_output.backward(retain_graph=True) + + grad_x1_task2 = x1_tensor.grad[i].item() if x1_tensor.grad is not None else 0 + grad_x2_task2 = x2_tensor.grad[i].item() if x2_tensor.grad is not None else 0 + + # Cosine similarity with x1 axis for task 2 + grad_magnitude_task2 = np.sqrt(grad_x1_task2**2 + grad_x2_task2**2) + if grad_magnitude_task2 > 1e-8: + cosine_task2[i] = grad_x1_task2 / grad_magnitude_task2 + else: + cosine_task2[i] = 0 + + # Clear gradients for combined task + x1_tensor.grad.zero_() + x2_tensor.grad.zero_() + + # Combined task gradient (sum of both tasks) + combined_output = Y[i, 0] + Y[i, 1] + combined_output.backward(retain_graph=True) + + grad_x1_combined = x1_tensor.grad[i].item() if x1_tensor.grad is not None else 0 + grad_x2_combined = x2_tensor.grad[i].item() if x2_tensor.grad is not None else 0 + + # Cosine similarity with x1 axis for combined task + grad_magnitude_combined = np.sqrt(grad_x1_combined**2 + grad_x2_combined**2) + if grad_magnitude_combined > 1e-8: + cosine_combined[i] = grad_x1_combined / grad_magnitude_combined + else: + cosine_combined[i] = 0 + + # Reshape back to 2D grids + cosine_task1 = cosine_task1.reshape(resolution, resolution) + cosine_task2 = cosine_task2.reshape(resolution, resolution) + cosine_combined = cosine_combined.reshape(resolution, resolution) + + return cosine_task1, cosine_task2, cosine_combined, x1_grid, x2_grid + + +def plot_gradient_steepness_analysis(save_path='gradient_steepness_analysis.png'): + """Plot gradient steepness maps for both tasks""" + steepness_task1, steepness_task2, x1_grid, x2_grid = compute_gradient_steepness_map(resolution=VISUALIZATION_RESOLUTION) + + fig, axes = plt.subplots(1, 2, figsize=(12, 5)) + + # Task 1 steepness + im1 = axes[0].imshow(steepness_task1, cmap='viridis', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest') + axes[0].set_title('Task 1 Gradient Steepness') + axes[0].set_xlabel('X1') + axes[0].set_ylabel('X2') + axes[0].set_xticks([-10, -5, 0, 5, 10]) + axes[0].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im1, ax=axes[0], label='Gradient Magnitude') + + # Task 2 steepness + im2 = axes[1].imshow(steepness_task2, cmap='viridis', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest') + axes[1].set_title('Task 2 Gradient Steepness') + axes[1].set_xlabel('X1') + axes[1].set_ylabel('X2') + axes[1].set_xticks([-10, -5, 0, 5, 10]) + axes[1].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im2, ax=axes[1], label='Gradient Magnitude') + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Gradient steepness analysis saved to {save_path}") + + +def plot_gradient_direction_analysis(save_path='gradient_direction_analysis.png'): + """Plot gradient direction cosine similarity with x1 axis for all tasks""" + cosine_task1, cosine_task2, cosine_combined, x1_grid, x2_grid = compute_gradient_direction_cosine_map(resolution=VISUALIZATION_RESOLUTION) + + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Task 1 direction + im1 = axes[0].imshow(cosine_task1, cmap='RdBu_r', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest', vmin=-1, vmax=1) + axes[0].set_title('Task 1 Gradient Direction\n(Cosine with X1 axis)') + axes[0].set_xlabel('X1') + axes[0].set_ylabel('X2') + axes[0].set_xticks([-10, -5, 0, 5, 10]) + axes[0].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im1, ax=axes[0], label='Cosine Similarity') + + # Task 2 direction + im2 = axes[1].imshow(cosine_task2, cmap='RdBu_r', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest', vmin=-1, vmax=1) + axes[1].set_title('Task 2 Gradient Direction\n(Cosine with X1 axis)') + axes[1].set_xlabel('X1') + axes[1].set_ylabel('X2') + axes[1].set_xticks([-10, -5, 0, 5, 10]) + axes[1].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im2, ax=axes[1], label='Cosine Similarity') + + # Combined tasks direction + im3 = axes[2].imshow(cosine_combined, cmap='RdBu_r', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest', vmin=-1, vmax=1) + axes[2].set_title('Combined Tasks Gradient Direction\n(Cosine with X1 axis)') + axes[2].set_xlabel('X1') + axes[2].set_ylabel('X2') + axes[2].set_xticks([-10, -5, 0, 5, 10]) + axes[2].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im3, ax=axes[2], label='Cosine Similarity') + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Gradient direction analysis saved to {save_path}") + + +def compute_target_function_map(x_range=(-10, 10), resolution=VISUALIZATION_RESOLUTION): + """ + Compute target function values for both tasks and their combination + + Args: + x_range: tuple of (min, max) for both x1 and x2 dimensions + resolution: number of grid points per dimension + + Returns: + task1_values: 2D array of task 1 function values + task2_values: 2D array of task 2 function values + combined_values: 2D array of combined task function values + x1_grid, x2_grid: coordinate grids + """ + # Create coordinate grids + x1_coords = np.linspace(x_range[0], x_range[1], resolution) + x2_coords = np.linspace(x_range[0], x_range[1], resolution) + x1_grid, x2_grid = np.meshgrid(x1_coords, x2_coords) + + # Flatten for computation + x1_flat = x1_grid.flatten() + x2_flat = x2_grid.flatten() + + # Convert to torch tensors + x1_tensor = torch.tensor(x1_flat, dtype=torch.float32) + x2_tensor = torch.tensor(x2_flat, dtype=torch.float32) + X = torch.stack([x1_tensor, x2_tensor], dim=1) + + # Create dataset instance to use _compute_targets method + dataset = ToyTaskDataset() + + # Compute target values + with torch.no_grad(): + Y = dataset._compute_targets(X) # [N, 2] where N = resolution^2 + + # Extract task values + task1_values = Y[:, 0].numpy().reshape(resolution, resolution) + task2_values = Y[:, 1].numpy().reshape(resolution, resolution) + combined_values = (Y[:, 0] + Y[:, 1]).numpy().reshape(resolution, resolution) + + return task1_values, task2_values, combined_values, x1_grid, x2_grid + + +def plot_target_function_analysis(save_path='target_function_analysis.png'): + """Plot target function values for both tasks and their combination""" + task1_values, task2_values, combined_values, x1_grid, x2_grid = compute_target_function_map(resolution=VISUALIZATION_RESOLUTION) + + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + + # Task 1 values + im1 = axes[0].imshow(task1_values, cmap='plasma', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest') + axes[0].set_title('Task 1 Target Function') + axes[0].set_xlabel('X1') + axes[0].set_ylabel('X2') + axes[0].set_xticks([-10, -5, 0, 5, 10]) + axes[0].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im1, ax=axes[0], label='Function Value') + + # Task 2 values + im2 = axes[1].imshow(task2_values, cmap='plasma', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest') + axes[1].set_title('Task 2 Target Function') + axes[1].set_xlabel('X1') + axes[1].set_ylabel('X2') + axes[1].set_xticks([-10, -5, 0, 5, 10]) + axes[1].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im2, ax=axes[1], label='Function Value') + + # Combined tasks values + im3 = axes[2].imshow(combined_values, cmap='plasma', aspect='auto', + extent=[x1_grid.min(), x1_grid.max(), x2_grid.min(), x2_grid.max()], + origin='lower', interpolation='nearest') + axes[2].set_title('Combined Tasks Target Function\n(Task1 + Task2)') + axes[2].set_xlabel('X1') + axes[2].set_ylabel('X2') + axes[2].set_xticks([-10, -5, 0, 5, 10]) + axes[2].set_yticks([-10, -5, 0, 5, 10]) + plt.colorbar(im3, ax=axes[2], label='Function Value') + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Target function analysis saved to {save_path}") + + +class SparseGatingNetwork(nn.Module): + """Sparse gating mechanism with multiple experts""" + def __init__(self, input_dim=2, hidden_dim=5, output_dim=2, num_experts=2, top_k=1): + super(SparseGatingNetwork, self).__init__() + self.num_experts = num_experts + self.top_k = min(top_k, num_experts) + + # Expert networks - simple MLPs + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + # nn.Linear(hidden_dim, hidden_dim//2), + # nn.ReLU(), + nn.Linear(hidden_dim, output_dim) + ) for _ in range(num_experts) + ]) + + # Gating network + self.gate = nn.Sequential( + nn.Linear(input_dim, hidden_dim//2), + nn.ReLU(), + nn.Linear(hidden_dim//2, num_experts) + ) + + def forward(self, x): + batch_size = x.size(0) + + # Compute gating weights + gate_logits = self.gate(x) # [batch_size, num_experts] + gate_weights = F.softmax(gate_logits, dim=1) + + # Apply sparsity: keep only top-k experts + top_k_weights, top_k_indices = torch.topk(gate_weights, self.top_k, dim=1) + + # Renormalize the top-k weights + top_k_weights = F.softmax(top_k_weights, dim=1) + + # Compute expert outputs + expert_outputs = [] + for i in range(self.num_experts): + expert_outputs.append(self.experts[i](x)) + expert_outputs = torch.stack(expert_outputs, dim=1) # [batch_size, num_experts, output_dim] + + # Weighted combination using only top-k experts + output = torch.zeros(batch_size, expert_outputs.size(-1), device=x.device) + for i in range(self.top_k): + expert_idx = top_k_indices[:, i] # [batch_size] + weights = top_k_weights[:, i:i+1] # [batch_size, 1] + + # Select expert outputs for each sample in batch + selected_outputs = expert_outputs[torch.arange(batch_size), expert_idx] # [batch_size, output_dim] + output += weights * selected_outputs + + # Compute load balancing loss + load_balance_loss = compute_load_balancing_loss(gate_weights, self.num_experts) + + return output, gate_weights, load_balance_loss + + +class PureMLP(nn.Module): + """Pure MLP baseline""" + def __init__(self, input_dim=2, hidden_dim=5, output_dim=2): + super(PureMLP, self).__init__() + + # Make the network comparable in size to the gating network + # Roughly same number of parameters as SparseGatingNetwork + self.network = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + # nn.Linear(hidden_dim * 2, hidden_dim), + # nn.ReLU(), + # nn.Linear(hidden_dim, hidden_dim//2), + # nn.ReLU(), + nn.Linear(hidden_dim, output_dim) + ) + + def forward(self, x): + return self.network(x) + + +def compute_load_balancing_loss(gate_weights, num_experts): + """ + Compute load balancing loss to encourage even expert utilization + + Args: + gate_weights: [batch_size, num_experts] softmax gate weights + num_experts: number of experts + + Returns: + load_balancing_loss: scalar loss encouraging uniform expert usage + """ + # Compute the fraction of tokens routed to each expert + expert_fractions = gate_weights.mean(dim=0) # [num_experts] + + # Compute the fraction of tokens for which each expert has highest weight + top_expert_mask = torch.argmax(gate_weights, dim=1) # [batch_size] + expert_usage = torch.zeros(num_experts, device=gate_weights.device) + for i in range(num_experts): + expert_usage[i] = (top_expert_mask == i).float().mean() + + # Load balancing loss encourages uniform distribution (1/num_experts for each expert) + # Using coefficient of variation to measure distribution imbalance + target_fraction = 1.0 / num_experts + cv_loss = (expert_fractions - target_fraction).pow(2).sum() + + # Alternative: entropy-based loss to encourage uniform distribution + # entropy_loss = -(expert_fractions * torch.log(expert_fractions + 1e-8)).sum() + # max_entropy = torch.log(torch.tensor(num_experts, dtype=torch.float, device=gate_weights.device)) + # normalized_entropy_loss = 1.0 - entropy_loss / max_entropy + + return cv_loss + + +def analyze_expert_selection_patterns(expert_selection_history, num_experts=4): + """ + Analyze expert selection patterns over training + + Args: + expert_selection_history: List of epoch data with expert selections + num_experts: Number of experts in the model + + Returns: + Dictionary with analysis results + """ + if not expert_selection_history: + return {} + + analysis = { + 'expert_usage_over_time': [], + 'expert_specialization': [], + 'task_expert_correlation': [], + 'spatial_expert_patterns': [] + } + + for epoch_data in expert_selection_history: + epoch = epoch_data['epoch'] + + # Aggregate all selections for this epoch + all_expert_choices = [] + all_inputs = [] + all_targets = [] + all_gate_weights = [] + + for batch_data in epoch_data['selections']: + all_expert_choices.extend(batch_data['expert_choices']) + all_inputs.extend(batch_data['inputs']) + all_targets.extend(batch_data['targets']) + all_gate_weights.extend(batch_data['gate_weights']) + + if not all_expert_choices: + continue + + all_expert_choices = np.array(all_expert_choices) + all_inputs = np.array(all_inputs) + all_targets = np.array(all_targets) + all_gate_weights = np.array(all_gate_weights) + + # 1. Expert usage distribution + expert_counts = np.bincount(all_expert_choices, minlength=num_experts) + expert_usage = expert_counts / len(all_expert_choices) if len(all_expert_choices) > 0 else np.zeros(num_experts) + analysis['expert_usage_over_time'].append({ + 'epoch': epoch, + 'usage': expert_usage, + 'entropy': -np.sum(expert_usage * np.log(expert_usage + 1e-8)) + }) + + # 2. Task-expert correlation + # Analyze which experts are chosen for which target values + task_expert_corr = {} + for task_idx in range(2): # Assuming 2 tasks + task_values = all_targets[:, task_idx] + + # Divide task values into bins to see patterns + task_bins = np.digitize(task_values, bins=np.linspace(task_values.min(), task_values.max(), 5)) + + expert_by_task_bin = {} + for bin_idx in range(1, 6): + mask = task_bins == bin_idx + if np.sum(mask) > 0: + bin_expert_choices = all_expert_choices[mask] + bin_expert_counts = np.bincount(bin_expert_choices, minlength=num_experts) + bin_expert_usage = bin_expert_counts / len(bin_expert_choices) + expert_by_task_bin[bin_idx] = bin_expert_usage + + task_expert_corr[f'task_{task_idx}'] = expert_by_task_bin + + analysis['task_expert_correlation'].append({ + 'epoch': epoch, + 'correlation': task_expert_corr + }) + + # 3. Spatial patterns (input space regions) + # Divide input space into grid for higher resolution + x1_bins = np.digitize(all_inputs[:, 0], bins=np.linspace(-10, 10, VISUALIZATION_RESOLUTION + 1)) # +1 bins to get VISUALIZATION_RESOLUTION regions + x2_bins = np.digitize(all_inputs[:, 1], bins=np.linspace(-10, 10, VISUALIZATION_RESOLUTION + 1)) + + spatial_patterns = {} + for x1_bin in range(1, VISUALIZATION_RESOLUTION + 1): + for x2_bin in range(1, VISUALIZATION_RESOLUTION + 1): + region_mask = (x1_bins == x1_bin) & (x2_bins == x2_bin) + if np.sum(region_mask) > 0: + region_experts = all_expert_choices[region_mask] + region_expert_counts = np.bincount(region_experts, minlength=num_experts) + region_expert_usage = region_expert_counts / len(region_experts) + spatial_patterns[f'region_{x1_bin}_{x2_bin}'] = region_expert_usage + + analysis['spatial_expert_patterns'].append({ + 'epoch': epoch, + 'patterns': spatial_patterns + }) + + # 4. Expert specialization (how concentrated is each expert's usage) + expert_specialization = [] + for expert_idx in range(num_experts): + expert_weights = all_gate_weights[:, expert_idx] + # Use coefficient of variation as specialization measure + if np.std(expert_weights) > 0: + specialization = np.std(expert_weights) / (np.mean(expert_weights) + 1e-8) + else: + specialization = 0 + expert_specialization.append(specialization) + + analysis['expert_specialization'].append({ + 'epoch': epoch, + 'specialization': expert_specialization + }) + + return analysis + + +def count_parameters(model): + """Count trainable parameters""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def compute_gradient_conflict(model, batch_x, batch_y, criterion): + """ + Compute gradient conflict between tasks + Returns cosine similarity between task gradients and conflict metrics + """ + model.train() + + # Forward pass + if isinstance(model, SparseGatingNetwork): + outputs, _, _ = model(batch_x) + else: + outputs = model(batch_x) + + # Compute individual task losses + task1_loss = criterion(outputs[:, 0], batch_y[:, 0]) + task2_loss = criterion(outputs[:, 1], batch_y[:, 1]) + + # Clear gradients + model.zero_grad() + + # Compute gradients for task 1 + task1_loss.backward(retain_graph=True) + task1_grads = [] + + for param in model.parameters(): + if param.grad is not None: + task1_grads.append(param.grad.clone().flatten()) + else: + task1_grads.append(torch.zeros_like(param).flatten()) + + task1_grad_vector = torch.cat(task1_grads) + + # Clear gradients and compute gradients for task 2 + model.zero_grad() + task2_loss.backward(retain_graph=True) + task2_grads = [] + for param in model.parameters(): + if param.grad is not None: + task2_grads.append(param.grad.clone().flatten()) + task2_grad_vector = torch.cat(task2_grads) + + # Clear gradients after computation + model.zero_grad() + + # Compute cosine similarity between gradients + cosine_sim = F.cosine_similarity(task1_grad_vector.unsqueeze(0), + task2_grad_vector.unsqueeze(0)).item() + + # Compute gradient norms + task1_norm = torch.norm(task1_grad_vector).item() + task2_norm = torch.norm(task2_grad_vector).item() + + # Conflict metrics + conflict_angle = np.arccos(np.clip(cosine_sim, -1, 1)) * 180 / np.pi # in degrees + is_conflicting = cosine_sim < 0 # negative cosine means conflict + + return { + 'cosine_similarity': cosine_sim, + 'conflict_angle': conflict_angle, + 'is_conflicting': is_conflicting, + 'task1_grad_norm': task1_norm, + 'task2_grad_norm': task2_norm, + 'task1_loss': task1_loss.item(), + 'task2_loss': task2_loss.item() + } + + +def compute_expert_gradient_conflicts(model, batch_x, batch_y, criterion): + """ + Compute gradient conflicts between tasks for each expert in the sparse gating network + Returns conflict metrics for each expert + """ + if not isinstance(model, SparseGatingNetwork): + return {} + + model.train() + expert_conflicts = {} + + # For each expert, compute the gradient conflicts between tasks + for expert_idx in range(model.num_experts): + expert = model.experts[expert_idx] + + # Forward pass through this specific expert + expert_outputs = expert(batch_x) # [batch_size, output_dim] + + # Compute individual task losses for this expert + task1_loss = criterion(expert_outputs[:, 0], batch_y[:, 0]) + task2_loss = criterion(expert_outputs[:, 1], batch_y[:, 1]) + + # Clear gradients + expert.zero_grad() + + # Compute gradients for task 1 + task1_loss.backward(retain_graph=True) + task1_grads = [] + + for param in expert.parameters(): + if param.grad is not None: + task1_grads.append(param.grad.clone().flatten()) + else: + task1_grads.append(torch.zeros_like(param).flatten()) + + if task1_grads: + task1_grad_vector = torch.cat(task1_grads) + else: + continue + + # Clear gradients and compute gradients for task 2 + expert.zero_grad() + task2_loss.backward(retain_graph=True) + task2_grads = [] + + for param in expert.parameters(): + if param.grad is not None: + task2_grads.append(param.grad.clone().flatten()) + else: + task2_grads.append(torch.zeros_like(param).flatten()) + + if task2_grads: + task2_grad_vector = torch.cat(task2_grads) + else: + continue + + # Clear gradients after computation + expert.zero_grad() + + # Compute cosine similarity between gradients + if torch.norm(task1_grad_vector) > 1e-8 and torch.norm(task2_grad_vector) > 1e-8: + cosine_sim = F.cosine_similarity(task1_grad_vector.unsqueeze(0), + task2_grad_vector.unsqueeze(0)).item() + + # Compute gradient norms + task1_norm = torch.norm(task1_grad_vector).item() + task2_norm = torch.norm(task2_grad_vector).item() + + # Conflict metrics + conflict_angle = np.arccos(np.clip(cosine_sim, -1, 1)) * 180 / np.pi # in degrees + is_conflicting = cosine_sim < 0 # negative cosine means conflict + + expert_conflicts[f'expert_{expert_idx}'] = { + 'cosine_similarity': cosine_sim, + 'conflict_angle': conflict_angle, + 'is_conflicting': is_conflicting, + 'task1_grad_norm': task1_norm, + 'task2_grad_norm': task2_norm, + 'task1_loss': task1_loss.item(), + 'task2_loss': task2_loss.item() + } + + return expert_conflicts + + +def train_model(model, train_loader, val_loader, num_epochs=30, lr=0.001, track_conflicts=False, + load_balance_weight=0.01, track_expert_selection=False, track_expert_conflicts=False): + """Training function with optional gradient conflict tracking and load balancing""" + optimizer = torch.optim.Adam(model.parameters(), lr=lr) + criterion = nn.MSELoss() + + train_losses = [] + val_losses = [] + conflict_history = [] + expert_selection_history = [] + expert_conflict_history = [] # New: track expert-specific conflicts + + for epoch in tqdm(range(num_epochs), desc="Training"): + # Training + model.train() + train_loss = 0.0 + epoch_conflicts = [] + epoch_expert_conflicts = [] # New: store expert conflicts for this epoch + + epoch_expert_selections = [] + + for batch_idx, (batch_x, batch_y) in enumerate(train_loader): + # Track gradient conflicts every 10 batches if requested + if track_conflicts and batch_idx % 10 == 0: + conflict_metrics = compute_gradient_conflict(model, batch_x, batch_y, criterion) + epoch_conflicts.append(conflict_metrics) + + # Track expert gradient conflicts every 10 batches if requested + if track_expert_conflicts and batch_idx % 10 == 0: + expert_conflict_metrics = compute_expert_gradient_conflicts(model, batch_x, batch_y, criterion) + if expert_conflict_metrics: # Only add if we have expert conflicts (i.e., for gating model) + epoch_expert_conflicts.append(expert_conflict_metrics) + + optimizer.zero_grad() + + if isinstance(model, SparseGatingNetwork): + outputs, gate_weights, load_balance_loss = model(batch_x) + + # Track expert selection every 20 batches if requested + if track_expert_selection and batch_idx % 20 == 0: + expert_choices = torch.argmax(gate_weights, dim=1) # [batch_size] + epoch_expert_selections.append({ + 'batch_idx': batch_idx, + 'expert_choices': expert_choices.cpu().numpy(), + 'gate_weights': gate_weights.detach().cpu().numpy(), + 'inputs': batch_x.cpu().numpy(), + 'targets': batch_y.cpu().numpy() + }) + + # Combine main loss with load balancing loss + main_loss = criterion(outputs, batch_y) + loss = main_loss + load_balance_weight * load_balance_loss + else: + outputs = model(batch_x) + loss = criterion(outputs, batch_y) + + loss.backward() + optimizer.step() + train_loss += loss.item() + + # Store conflict metrics for this epoch + if track_conflicts and epoch_conflicts: + # Average conflict metrics across batches in this epoch + avg_conflict = { + 'cosine_similarity': np.mean([c['cosine_similarity'] for c in epoch_conflicts]), + 'conflict_angle': np.mean([c['conflict_angle'] for c in epoch_conflicts]), + 'is_conflicting': np.mean([c['is_conflicting'] for c in epoch_conflicts]), + 'task1_grad_norm': np.mean([c['task1_grad_norm'] for c in epoch_conflicts]), + 'task2_grad_norm': np.mean([c['task2_grad_norm'] for c in epoch_conflicts]) + } + conflict_history.append(avg_conflict) + + # Store expert conflict metrics for this epoch + if track_expert_conflicts and epoch_expert_conflicts: + # Average expert conflict metrics across batches in this epoch + expert_names = list(epoch_expert_conflicts[0].keys()) if epoch_expert_conflicts else [] + epoch_expert_avg = {'epoch': epoch} + + for expert_name in expert_names: + expert_conflicts_for_epoch = [batch_data[expert_name] for batch_data in epoch_expert_conflicts if expert_name in batch_data] + if expert_conflicts_for_epoch: + epoch_expert_avg[expert_name] = { + 'cosine_similarity': np.mean([c['cosine_similarity'] for c in expert_conflicts_for_epoch]), + 'conflict_angle': np.mean([c['conflict_angle'] for c in expert_conflicts_for_epoch]), + 'is_conflicting': np.mean([c['is_conflicting'] for c in expert_conflicts_for_epoch]), + 'task1_grad_norm': np.mean([c['task1_grad_norm'] for c in expert_conflicts_for_epoch]), + 'task2_grad_norm': np.mean([c['task2_grad_norm'] for c in expert_conflicts_for_epoch]) + } + + expert_conflict_history.append(epoch_expert_avg) + + # Store expert selection data for this epoch + if track_expert_selection and epoch_expert_selections: + expert_selection_history.append({ + 'epoch': epoch, + 'selections': epoch_expert_selections + }) + + # Validation + model.eval() + val_loss = 0.0 + with torch.no_grad(): + for batch_x, batch_y in val_loader: + if isinstance(model, SparseGatingNetwork): + outputs, _, _ = model(batch_x) + else: + outputs = model(batch_x) + loss = criterion(outputs, batch_y) + val_loss += loss.item() + + train_losses.append(train_loss / len(train_loader)) + val_losses.append(val_loss / len(val_loader)) + + if epoch % 20 == 0: + print(f"Epoch {epoch}: Train Loss = {train_losses[-1]:.4f}, Val Loss = {val_losses[-1]:.4f}") + if track_conflicts and conflict_history: + latest_conflict = conflict_history[-1] + print(f" Gradient Conflict: Angle = {latest_conflict['conflict_angle']:.1f}°, " + f"Cosine Sim = {latest_conflict['cosine_similarity']:.3f}") + if track_expert_conflicts and expert_conflict_history: + latest_expert_conflicts = expert_conflict_history[-1] + print(" Expert Conflicts:") + for expert_name, conflicts in latest_expert_conflicts.items(): + if expert_name != 'epoch': + print(f" {expert_name}: {conflicts['conflict_angle']:.1f}°") + + return train_losses, val_losses, conflict_history, expert_selection_history, expert_conflict_history + + +def evaluate_model(model, test_loader): + """Evaluate model performance""" + model.eval() + criterion = nn.MSELoss() + + total_loss = 0.0 + task1_loss = 0.0 + task2_loss = 0.0 + num_batches = 0 + + with torch.no_grad(): + for batch_x, batch_y in test_loader: + if isinstance(model, SparseGatingNetwork): + outputs, gate_weights, _ = model(batch_x) + else: + outputs = model(batch_x) + gate_weights = None + + # Overall loss + loss = criterion(outputs, batch_y) + total_loss += loss.item() + + # Per-task losses + task1_loss += criterion(outputs[:, 0], batch_y[:, 0]).item() + task2_loss += criterion(outputs[:, 1], batch_y[:, 1]).item() + + num_batches += 1 + + return { + 'total_loss': total_loss / num_batches, + 'task1_loss': task1_loss / num_batches, + 'task2_loss': task2_loss / num_batches, + 'gate_weights': gate_weights + } + + +def compute_rolling_expert_conflicts(expert_conflict_history, window_size=5): + """ + Compute rolling statistics for expert gradient conflicts over recent epochs + + Args: + expert_conflict_history: List of expert conflict data per epoch + window_size: Number of recent epochs to consider (default 5) + + Returns: + Dictionary with rolling statistics for each expert + """ + if not expert_conflict_history or len(expert_conflict_history) == 0: + return {} + + rolling_stats = {} + + # Get expert names from the first epoch that has data + expert_names = [] + for epoch_data in expert_conflict_history: + if len(epoch_data) > 1: # More than just 'epoch' key + expert_names = [k for k in epoch_data.keys() if k != 'epoch'] + break + + if not expert_names: + return {} + + for expert_name in expert_names: + rolling_stats[expert_name] = { + 'epochs': [], + 'rolling_conflict_angle': [], + 'rolling_cosine_similarity': [], + 'rolling_conflicting_rate': [], + 'rolling_task1_norm': [], + 'rolling_task2_norm': [] + } + + # Compute rolling statistics for each epoch + for i, epoch_data in enumerate(expert_conflict_history): + epoch = epoch_data.get('epoch', i) + + # Determine the window for this epoch (recent 5 epochs) + start_idx = max(0, i - window_size + 1) + end_idx = i + 1 + window_data = expert_conflict_history[start_idx:end_idx] + + # For each expert, compute rolling statistics + for expert_name in expert_names: + if expert_name in epoch_data: + # Collect data from the window + window_conflicts = [] + for window_epoch in window_data: + if expert_name in window_epoch: + window_conflicts.append(window_epoch[expert_name]) + + if window_conflicts: + # Compute rolling averages + rolling_conflict_angle = np.mean([c['conflict_angle'] for c in window_conflicts]) + rolling_cosine_sim = np.mean([c['cosine_similarity'] for c in window_conflicts]) + rolling_conflicting_rate = np.mean([c['is_conflicting'] for c in window_conflicts]) + rolling_task1_norm = np.mean([c['task1_grad_norm'] for c in window_conflicts]) + rolling_task2_norm = np.mean([c['task2_grad_norm'] for c in window_conflicts]) + + # Store results + rolling_stats[expert_name]['epochs'].append(epoch) + rolling_stats[expert_name]['rolling_conflict_angle'].append(rolling_conflict_angle) + rolling_stats[expert_name]['rolling_cosine_similarity'].append(rolling_cosine_sim) + rolling_stats[expert_name]['rolling_conflicting_rate'].append(rolling_conflicting_rate) + rolling_stats[expert_name]['rolling_task1_norm'].append(rolling_task1_norm) + rolling_stats[expert_name]['rolling_task2_norm'].append(rolling_task2_norm) + + return rolling_stats + + +def plot_expert_gradient_conflicts(expert_conflict_history, save_path='expert_gradient_conflicts.png', window_size=5): + """ + Plot expert gradient conflict analysis over epochs with rolling statistics + + Args: + expert_conflict_history: List of expert conflict data per epoch + save_path: Path to save the plot + window_size: Window size for rolling statistics (default 5) + """ + if not expert_conflict_history: + print("No expert conflict data to plot") + return + + # Compute rolling statistics + rolling_stats = compute_rolling_expert_conflicts(expert_conflict_history, window_size) + + if not rolling_stats: + print("No valid expert conflict data found") + return + + expert_names = list(rolling_stats.keys()) + num_experts = len(expert_names) + + # Create subplots: 2 rows, multiple columns + fig, axes = plt.subplots(2, 2, figsize=(16, 10)) + + # Plot 1: Conflict angles over time (rolling average) + ax1 = axes[0, 0] + for expert_name in expert_names: + data = rolling_stats[expert_name] + if data['epochs'] and data['rolling_conflict_angle']: + ax1.plot(data['epochs'], data['rolling_conflict_angle'], + label=expert_name.replace('_', ' ').title(), marker='o', markersize=4) + + ax1.set_title(f'Expert Gradient Conflict Angles (Rolling {window_size}-Epoch Average)') + ax1.set_xlabel('Epoch') + ax1.set_ylabel('Conflict Angle (degrees)') + ax1.legend() + ax1.grid(True, alpha=0.3) + ax1.axhline(y=90, color='gray', linestyle='--', alpha=0.7, label='No conflict (90°)') + + # Plot 2: Cosine similarity over time (rolling average) + ax2 = axes[0, 1] + for expert_name in expert_names: + data = rolling_stats[expert_name] + if data['epochs'] and data['rolling_cosine_similarity']: + ax2.plot(data['epochs'], data['rolling_cosine_similarity'], + label=expert_name.replace('_', ' ').title(), marker='o', markersize=4) + + ax2.set_title(f'Expert Gradient Cosine Similarity (Rolling {window_size}-Epoch Average)') + ax2.set_xlabel('Epoch') + ax2.set_ylabel('Cosine Similarity') + ax2.legend() + ax2.grid(True, alpha=0.3) + ax2.axhline(y=0, color='gray', linestyle='--', alpha=0.7, label='No correlation (0)') + + # Plot 3: Conflicting rate over time (rolling average) + ax3 = axes[1, 0] + for expert_name in expert_names: + data = rolling_stats[expert_name] + if data['epochs'] and data['rolling_conflicting_rate']: + conflicting_rate_percent = [x * 100 for x in data['rolling_conflicting_rate']] # Convert to percentage + ax3.plot(data['epochs'], conflicting_rate_percent, + label=expert_name.replace('_', ' ').title(), marker='o', markersize=4) + + ax3.set_title(f'Expert Gradient Conflicting Rate (Rolling {window_size}-Epoch Average)') + ax3.set_xlabel('Epoch') + ax3.set_ylabel('Conflicting Rate (%)') + ax3.legend() + ax3.grid(True, alpha=0.3) + + # Plot 4: Gradient norms comparison (rolling average) + ax4 = axes[1, 1] + colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown'] + for i, expert_name in enumerate(expert_names): + data = rolling_stats[expert_name] + if data['epochs'] and data['rolling_task1_norm'] and data['rolling_task2_norm']: + color = colors[i % len(colors)] + ax4.plot(data['epochs'], data['rolling_task1_norm'], + label=f'{expert_name.replace("_", " ").title()} - Task 1', + color=color, linestyle='-', marker='o', markersize=3) + ax4.plot(data['epochs'], data['rolling_task2_norm'], + label=f'{expert_name.replace("_", " ").title()} - Task 2', + color=color, linestyle='--', marker='s', markersize=3) + + ax4.set_title(f'Expert Gradient Norms (Rolling {window_size}-Epoch Average)') + ax4.set_xlabel('Epoch') + ax4.set_ylabel('Gradient Norm') + ax4.legend(fontsize='small') + ax4.grid(True, alpha=0.3) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Expert gradient conflict analysis saved to {save_path}") + + # Print summary statistics + print(f"\nExpert Gradient Conflict Summary (Last {window_size} epochs):") + print("=" * 60) + for expert_name in expert_names: + data = rolling_stats[expert_name] + if data['rolling_conflict_angle']: + latest_angle = data['rolling_conflict_angle'][-1] + latest_cosine = data['rolling_cosine_similarity'][-1] + latest_conflicting_rate = data['rolling_conflicting_rate'][-1] * 100 + print(f"{expert_name.replace('_', ' ').title()}:") + print(f" Average Conflict Angle: {latest_angle:.1f}°") + print(f" Average Cosine Similarity: {latest_cosine:.3f}") + print(f" Conflicting Rate: {latest_conflicting_rate:.1f}%") + + +def plot_expert_selection_analysis(expert_analysis, save_path='expert_selection_analysis.png'): + """Plot expert selection patterns over time""" + if not expert_analysis: + print("No expert selection data to plot") + return + + # Get number of experts from the data + num_experts = len(expert_analysis['expert_usage_over_time'][0]['usage']) + + # Create subplot grid: top row has 3 plots, bottom row has up to num_experts plots + fig, axes = plt.subplots(2, max(3, num_experts), figsize=(18, 12)) + + # 1. Expert usage over time + epochs = [data['epoch'] for data in expert_analysis['expert_usage_over_time']] + num_experts = len(expert_analysis['expert_usage_over_time'][0]['usage']) + + for expert_idx in range(num_experts): + usage_over_time = [data['usage'][expert_idx] for data in expert_analysis['expert_usage_over_time']] + axes[0, 0].plot(epochs, usage_over_time, label=f'Expert {expert_idx}', marker='o') + + axes[0, 0].set_title('Expert Usage Over Time') + axes[0, 0].set_xlabel('Epoch') + axes[0, 0].set_ylabel('Usage Probability') + axes[0, 0].legend() + axes[0, 0].grid(True, alpha=0.3) + axes[0, 0].axhline(y=1.0/num_experts, color='gray', linestyle='--', alpha=0.7, label='Uniform') + + # 2. Expert selection entropy (diversity measure) + entropies = [data['entropy'] for data in expert_analysis['expert_usage_over_time']] + max_entropy = np.log(num_experts) + + axes[0, 1].plot(epochs, entropies, 'b-', marker='o', label='Selection Entropy') + axes[0, 1].axhline(y=max_entropy, color='red', linestyle='--', alpha=0.7, label='Max Entropy (Uniform)') + axes[0, 1].set_title('Expert Selection Diversity') + axes[0, 1].set_xlabel('Epoch') + axes[0, 1].set_ylabel('Entropy') + axes[0, 1].legend() + axes[0, 1].grid(True, alpha=0.3) + + # 3. Expert specialization over time + for expert_idx in range(num_experts): + specialization_over_time = [data['specialization'][expert_idx] for data in expert_analysis['expert_specialization']] + axes[0, 2].plot(epochs, specialization_over_time, label=f'Expert {expert_idx}', marker='o') + + axes[0, 2].set_title('Expert Specialization Over Time') + axes[0, 2].set_xlabel('Epoch') + axes[0, 2].set_ylabel('Specialization (CV)') + axes[0, 2].legend() + axes[0, 2].grid(True, alpha=0.3) + + # 4. Final spatial patterns (last epoch) + if expert_analysis['spatial_expert_patterns']: + final_spatial = expert_analysis['spatial_expert_patterns'][-1]['patterns'] + regions = list(final_spatial.keys()) + + # Create heatmap for each expert + for expert_idx in range(num_experts): # Show all experts + region_usage = [final_spatial[region][expert_idx] if region in final_spatial else 0 + for region in regions] + + if expert_idx < axes.shape[1]: # Check if we have enough columns + ax = axes[1, expert_idx] + + # Reshape for grid visualization + grid_data = np.zeros((VISUALIZATION_RESOLUTION, VISUALIZATION_RESOLUTION)) + for i, region in enumerate(regions): + if len(region.split('_')) >= 3: + x_idx = int(region.split('_')[1]) - 1 + y_idx = int(region.split('_')[2]) - 1 + if 0 <= x_idx < VISUALIZATION_RESOLUTION and 0 <= y_idx < VISUALIZATION_RESOLUTION: + grid_data[y_idx, x_idx] = final_spatial[region][expert_idx] + + # Set extent to match the actual coordinate system (-10 to 10) + im = ax.imshow(grid_data, cmap='Blues', aspect='auto', interpolation='nearest', + extent=[-10, 10, -10, 10], origin='lower', vmin=0, vmax=1) + ax.set_title(f'Expert {expert_idx} Spatial Pattern (Final)') + ax.set_xlabel('X1') + ax.set_ylabel('X2') + + # Set ticks to match coordinate system + ax.set_xticks([-10, -5, 0, 5, 10]) + ax.set_yticks([-10, -5, 0, 5, 10]) + + plt.colorbar(im, ax=ax) + + # If we have more subplots than experts, hide the empty ones + if axes.shape[1] > num_experts: + for idx in range(num_experts, axes.shape[1]): + axes[1, idx].set_visible(False) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + print(f"Expert selection analysis saved to {save_path}") + + +def plot_results(gating_results, mlp_results): + """Plot comparison results with gradient conflict analysis""" + fig, axes = plt.subplots(2, 3, figsize=(18, 10)) + + # Training curves + axes[0, 0].plot(gating_results['train_losses'], label='Sparse Gating', color='red') + axes[0, 0].plot(mlp_results['train_losses'], label='Pure MLP', color='blue') + axes[0, 0].set_title('Training Loss') + axes[0, 0].set_xlabel('Epoch') + axes[0, 0].set_ylabel('Loss') + axes[0, 0].legend() + axes[0, 0].grid(True) + + # Validation curves + axes[0, 1].plot(gating_results['val_losses'], label='Sparse Gating', color='red') + axes[0, 1].plot(mlp_results['val_losses'], label='Pure MLP', color='blue') + axes[0, 1].set_title('Validation Loss') + axes[0, 1].set_xlabel('Epoch') + axes[0, 1].set_ylabel('Loss') + axes[0, 1].legend() + axes[0, 1].grid(True) + + # Gradient conflict over time + if gating_results.get('conflict_history') and mlp_results.get('conflict_history'): + gating_conflicts = [c['conflict_angle'] for c in gating_results['conflict_history']] + mlp_conflicts = [c['conflict_angle'] for c in mlp_results['conflict_history']] + + epochs = range(len(gating_conflicts)) + axes[0, 2].plot(epochs, gating_conflicts, label='Sparse Gating', color='red') + axes[0, 2].plot(epochs, mlp_conflicts, label='Pure MLP', color='blue') + axes[0, 2].set_title('Gradient Conflict Angle') + axes[0, 2].set_xlabel('Epoch') + axes[0, 2].set_ylabel('Angle (degrees)') + axes[0, 2].legend() + axes[0, 2].grid(True) + axes[0, 2].axhline(y=90, color='gray', linestyle='--', alpha=0.7, label='No conflict') + else: + axes[0, 2].text(0.5, 0.5, 'No conflict data\navailable', + ha='center', va='center', transform=axes[0, 2].transAxes) + axes[0, 2].set_title('Gradient Conflict Angle') + + # Per-task performance comparison + methods = ['Sparse Gating', 'Pure MLP'] + task1_losses = [gating_results['test_eval']['task1_loss'], mlp_results['test_eval']['task1_loss']] + task2_losses = [gating_results['test_eval']['task2_loss'], mlp_results['test_eval']['task2_loss']] + + x = np.arange(len(methods)) + width = 0.35 + + axes[1, 0].bar(x - width/2, task1_losses, width, label='Task 1', alpha=0.8) + axes[1, 0].bar(x + width/2, task2_losses, width, label='Task 2', alpha=0.8) + axes[1, 0].set_title('Per-Task Test Loss') + axes[1, 0].set_ylabel('Loss') + axes[1, 0].set_xticks(x) + axes[1, 0].set_xticklabels(methods) + axes[1, 0].legend() + axes[1, 0].grid(True, alpha=0.3) + + # Parameter count comparison + param_counts = [gating_results['param_count'], mlp_results['param_count']] + axes[1, 1].bar(methods, param_counts, alpha=0.8, color=['red', 'blue']) + axes[1, 1].set_title('Parameter Count') + axes[1, 1].set_ylabel('Number of Parameters') + axes[1, 1].grid(True, alpha=0.3) + + # Average gradient conflict comparison + if gating_results.get('conflict_history') and mlp_results.get('conflict_history'): + gating_avg_conflict = np.mean([c['conflict_angle'] for c in gating_results['conflict_history']]) + mlp_avg_conflict = np.mean([c['conflict_angle'] for c in mlp_results['conflict_history']]) + + conflict_angles = [gating_avg_conflict, mlp_avg_conflict] + bars = axes[1, 2].bar(methods, conflict_angles, alpha=0.8, color=['red', 'blue']) + axes[1, 2].set_title('Average Gradient Conflict') + axes[1, 2].set_ylabel('Angle (degrees)') + axes[1, 2].axhline(y=90, color='gray', linestyle='--', alpha=0.7) + axes[1, 2].grid(True, alpha=0.3) + + # Add value labels on bars + for bar, value in zip(bars, conflict_angles): + axes[1, 2].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, + f'{value:.1f}°', ha='center', va='bottom') + else: + axes[1, 2].text(0.5, 0.5, 'No conflict data\navailable', + ha='center', va='center', transform=axes[1, 2].transAxes) + axes[1, 2].set_title('Average Gradient Conflict') + + plt.tight_layout() + plt.savefig('multitask_gating_comparison.png', dpi=300, bbox_inches='tight') + plt.close() + + +def run_experiment(): + """Main experiment function""" + print("Starting Multi-task Learning Experiment: Sparse Gating vs Pure MLP") + print("=" * 60) + + # Generate dataset + dataset = ToyTaskDataset(num_samples=20000) + X, Y = dataset.generate_data() + + # Split data + train_size = int(0.7 * len(X)) + val_size = int(0.15 * len(X)) + + train_X, train_Y = X[:train_size], Y[:train_size] + val_X, val_Y = X[train_size:train_size+val_size], Y[train_size:train_size+val_size] + test_X, test_Y = X[train_size+val_size:], Y[train_size+val_size:] + + # Create data loaders + train_dataset = torch.utils.data.TensorDataset(train_X, train_Y) + val_dataset = torch.utils.data.TensorDataset(val_X, val_Y) + test_dataset = torch.utils.data.TensorDataset(test_X, test_Y) + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=24, shuffle=True) + val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=24, shuffle=False) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=24, shuffle=False) + + print(f"Data split: Train={len(train_X)}, Val={len(val_X)}, Test={len(test_X)}") + + # Initialize models + gating_model = SparseGatingNetwork(input_dim=2, hidden_dim=32, output_dim=2, num_experts=4, top_k=1) + mlp_model = PureMLP(input_dim=2, hidden_dim=32, output_dim=2) + + print(f"Sparse Gating Model Parameters: {count_parameters(gating_model):,}") + print(f"Pure MLP Model Parameters: {count_parameters(mlp_model):,}") + print() + + # Train models with gradient conflict tracking and expert selection tracking + print("Training Sparse Gating Network...") + start_time = time.time() + gating_train_losses, gating_val_losses, gating_conflicts, gating_expert_history, gating_expert_conflicts = train_model( + gating_model, train_loader, val_loader, num_epochs=100, track_conflicts=True, + track_expert_selection=True, track_expert_conflicts=True) + gating_training_time = time.time() - start_time + + print("\nTraining Pure MLP...") + start_time = time.time() + mlp_train_losses, mlp_val_losses, mlp_conflicts, mlp_expert_history, mlp_expert_conflicts = train_model( + mlp_model, train_loader, val_loader, num_epochs=100, track_conflicts=True) + mlp_training_time = time.time() - start_time + + # Evaluate models + print("\nEvaluating models...") + gating_eval = evaluate_model(gating_model, test_loader) + mlp_eval = evaluate_model(mlp_model, test_loader) + + # Analyze expert selection patterns for gating model + expert_analysis = None + if gating_expert_history: + expert_analysis = analyze_expert_selection_patterns(gating_expert_history, num_experts=4) + + # Prepare results + gating_results = { + 'train_losses': gating_train_losses, + 'val_losses': gating_val_losses, + 'test_eval': gating_eval, + 'param_count': count_parameters(gating_model), + 'training_time': gating_training_time, + 'conflict_history': gating_conflicts, + 'expert_selection_history': gating_expert_history, + 'expert_analysis': expert_analysis, + 'expert_conflict_history': gating_expert_conflicts + } + + mlp_results = { + 'train_losses': mlp_train_losses, + 'val_losses': mlp_val_losses, + 'test_eval': mlp_eval, + 'param_count': count_parameters(mlp_model), + 'training_time': mlp_training_time, + 'conflict_history': mlp_conflicts, + 'expert_conflict_history': mlp_expert_conflicts + } + + # Print results + print("\n" + "="*80) + print("RESULTS SUMMARY") + print("="*80) + print(f"{'Metric':<25} {'Sparse Gating':<15} {'Pure MLP':<15} {'Winner'}") + print("-" * 80) + print(f"{'Total Test Loss':<25} {gating_eval['total_loss']:<15.4f} {mlp_eval['total_loss']:<15.4f} {'Gating' if gating_eval['total_loss'] < mlp_eval['total_loss'] else 'MLP'}") + print(f"{'Task 1 Test Loss':<25} {gating_eval['task1_loss']:<15.4f} {mlp_eval['task1_loss']:<15.4f} {'Gating' if gating_eval['task1_loss'] < mlp_eval['task1_loss'] else 'MLP'}") + print(f"{'Task 2 Test Loss':<25} {gating_eval['task2_loss']:<15.4f} {mlp_eval['task2_loss']:<15.4f} {'Gating' if gating_eval['task2_loss'] < mlp_eval['task2_loss'] else 'MLP'}") + print(f"{'Parameters':<25} {count_parameters(gating_model):<15,} {count_parameters(mlp_model):<15,} {'Gating' if count_parameters(gating_model) < count_parameters(mlp_model) else 'MLP'}") + print(f"{'Training Time (s)':<25} {gating_training_time:<15.2f} {mlp_training_time:<15.2f} {'Gating' if gating_training_time < mlp_training_time else 'MLP'}") + + # Gradient conflict analysis + if gating_conflicts and mlp_conflicts: + gating_avg_conflict = np.mean([c['conflict_angle'] for c in gating_conflicts]) + mlp_avg_conflict = np.mean([c['conflict_angle'] for c in mlp_conflicts]) + gating_conflicting_rate = np.mean([c['is_conflicting'] for c in gating_conflicts]) + mlp_conflicting_rate = np.mean([c['is_conflicting'] for c in mlp_conflicts]) + + print("\n" + "="*80) + print("GRADIENT CONFLICT ANALYSIS") + print("="*80) + print(f"{'Avg Conflict Angle (°)':<25} {gating_avg_conflict:<15.1f} {mlp_avg_conflict:<15.1f} {'Gating' if gating_avg_conflict < mlp_avg_conflict else 'MLP'}") + print(f"{'Conflicting Rate (%)':<25} {gating_conflicting_rate*100:<15.1f} {mlp_conflicting_rate*100:<15.1f} {'Gating' if gating_conflicting_rate < mlp_conflicting_rate else 'MLP'}") + + # Final gradient conflict on test data + test_batch = next(iter(test_loader)) + test_x, test_y = test_batch + gating_final_conflict = compute_gradient_conflict(gating_model, test_x, test_y, nn.MSELoss()) + mlp_final_conflict = compute_gradient_conflict(mlp_model, test_x, test_y, nn.MSELoss()) + + print(f"{'Final Test Conflict (°)':<25} {gating_final_conflict['conflict_angle']:<15.1f} {mlp_final_conflict['conflict_angle']:<15.1f} {'Gating' if gating_final_conflict['conflict_angle'] < mlp_final_conflict['conflict_angle'] else 'MLP'}") + + # Print detailed analysis + print(f"\nDETAILED CONFLICT ANALYSIS:") + print(f"Gating - Training avg vs Final test: {gating_avg_conflict:.1f}° vs {gating_final_conflict['conflict_angle']:.1f}° (diff: {abs(gating_avg_conflict - gating_final_conflict['conflict_angle']):.1f}°)") + print(f"MLP - Training avg vs Final test: {mlp_avg_conflict:.1f}° vs {mlp_final_conflict['conflict_angle']:.1f}° (diff: {abs(mlp_avg_conflict - mlp_final_conflict['conflict_angle']):.1f}°)") + + print("\nNote: Lower conflict angle indicates better alignment between task gradients") + print("Angles < 90° indicate cooperative gradients, > 90° indicate conflicting gradients") + print("Large difference between training avg and final test may indicate:") + print("- Different data distributions (train vs test)") + print("- Model still learning during training (vs converged at end)") + print("- Load balancing effects during training") + + # Analyze expert selection patterns (only for gating model) + if expert_analysis: + print("\nAnalyzing expert selection patterns...") + plot_expert_selection_analysis(expert_analysis) + + # Print summary of expert selection + print("\nEXPERT SELECTION SUMMARY:") + print("="*50) + + # Final expert usage + final_usage = expert_analysis['expert_usage_over_time'][-1]['usage'] + print(f"Final Expert Usage Distribution:") + for i, usage in enumerate(final_usage): + print(f" Expert {i}: {usage:.3f} ({usage*100:.1f}%)") + + # Expert usage entropy over time + initial_entropy = expert_analysis['expert_usage_over_time'][0]['entropy'] + final_entropy = expert_analysis['expert_usage_over_time'][-1]['entropy'] + max_entropy = np.log(4) # 4 experts + + print(f"\nExpert Selection Diversity:") + print(f" Initial Entropy: {initial_entropy:.3f} (Normalized: {initial_entropy/max_entropy:.3f})") + print(f" Final Entropy: {final_entropy:.3f} (Normalized: {final_entropy/max_entropy:.3f})") + print(f" Max Possible Entropy: {max_entropy:.3f}") + + # Most specialized expert /fs-computility/niuyazhe/tangjia/github/ + final_specialization = expert_analysis['expert_specialization'][-1]['specialization'] + most_specialized_expert = np.argmax(final_specialization) + print(f"\nMost Specialized Expert: Expert {most_specialized_expert} (Specialization: {final_specialization[most_specialized_expert]:.3f})") + + # Analyze expert gradient conflicts (only for gating model) + if gating_expert_conflicts: + print("\nAnalyzing expert gradient conflicts...") + plot_expert_gradient_conflicts(gating_expert_conflicts, window_size=5) + + # Plot results + plot_results(gating_results, mlp_results) + + # Plot gradient steepness analysis for the toy tasks + print("\nGenerating gradient steepness analysis...") + plot_gradient_steepness_analysis() + + # Plot gradient direction analysis for the toy tasks + print("Generating gradient direction analysis...") + plot_gradient_direction_analysis() + + # Plot target function analysis + print("Generating target function analysis...") + plot_target_function_analysis() + + return gating_results, mlp_results + + +if __name__ == "__main__": + gating_results, mlp_results = run_experiment() \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py index bdc5e4f7a..1184bf90e 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py @@ -143,7 +143,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # num_layers=12, # todo num_heads=24, - embed_dim=768, + embed_dim=768, #768 obs_type='image', env_num=8, task_num=len(env_id_list), @@ -333,7 +333,7 @@ def create_env_manager(): torchrun --nproc_per_node=8 ./zoo/atari/config/atari_unizero_multitask_segment_8games_ddp_config.py """ - +# /fs-computility/niuyazhe/tangjia/code/LightZero-dev-multitask-balance-clean/zoo/atari/config/atari_unizero_multitask_segment_ddp_config.py from lzero.entry import train_unizero_multitask_segment_ddp from ding.utils import DDPContext import os diff --git a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py index cddaae311..129123a6f 100644 --- a/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py +++ b/zoo/atari/config/atari_unizero_multitask_segment_ddp_config_debug.py @@ -64,8 +64,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu policy=dict( multi_gpu=True, # Very important for ddp only_use_moco_stats=False, - # use_moco=False, # ==============TODO============== - use_moco=True, # ==============TODO: moco============== + use_moco=False, # ==============TODO============== + # use_moco=True, # ==============TODO: moco============== learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000))), grad_correct_params=dict( MoCo_beta=0.5, MoCo_beta_sigma=0.5, MoCo_gamma=0.1, MoCo_gamma_sigma=0.5, MoCo_rho=0, @@ -129,7 +129,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # num_layers=12, # todo num_heads=24, - embed_dim=768, + embed_dim=768,#768 obs_type='image', env_num=8, task_num=len(env_id_list), @@ -142,8 +142,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_experts_in_moe_head=4, moe_in_transformer=False, - multiplication_moe_in_transformer=False, # ==============TODO:orig============== - # multiplication_moe_in_transformer=True, # =======TODO: moe8======= + # multiplication_moe_in_transformer=False, # ==============TODO:orig============== + multiplication_moe_in_transformer=True, # =======TODO: moe8======= n_shared_experts=1, num_experts_per_tok=1, num_experts_of_moe_in_transformer=8, @@ -337,7 +337,7 @@ def create_env_manager(): batch_sizes, grad_acc_steps = compute_batch_config(env_id_list, effective_batch_size) total_batch_size = effective_batch_size # 当前无效 - + num_unroll_steps = 10 # infer_context_length = 4 infer_context_length = 5 # ==============TODO============== @@ -350,7 +350,9 @@ def create_env_manager(): # ======== TODO: only for debug ======== env_id_list = [ - 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4' + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + # 'SeaquestNoFrameskip-v4' ] num_layers = 1 # ==============TODO============== collector_env_num = 2 @@ -363,11 +365,14 @@ def create_env_manager(): infer_context_length = 2 batch_sizes = [2 for _ in range(len(env_id_list))] total_batch_size = 2*len(env_id_list) + + # ===========button from tangjia=========== + import torch.distributed as dist - for seed in [0,1]: + for seed in [100]: configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_sizes, num_unroll_steps, infer_context_length, norm_type, seed, buffer_reanalyze_freq, reanalyze_batch_size, reanalyze_partition,