diff --git a/baselines/ppo/config/ppo_base_puffer.yaml b/baselines/ppo/config/ppo_base_puffer.yaml
index 9f985667a..8f8538847 100644
--- a/baselines/ppo/config/ppo_base_puffer.yaml
+++ b/baselines/ppo/config/ppo_base_puffer.yaml
@@ -38,13 +38,13 @@ environment: # Overrides default environment configs (see pygpudrive/env/config.
 
 wandb:
   entity: ""
-  project: "gpudrive"
-  group: "test"
+  project: "adv_filter"
+  group: "testing"
   mode: "online" # Options: online, offline, disabled
   tags: ["ppo", "ff"]
 
 train:
-  exp_id: PPO # Set dynamically in the script if needed
+  exp_id: adv_filter # Set dynamically in the script if needed
   seed: 42
   cpu_offload: false
   device: "cuda" # Dynamically set to cuda if available, else cpu
@@ -63,7 +63,7 @@ train:
   torch_deterministic: false
   total_timesteps: 1_000_000_000
   batch_size: 131_072
-  minibatch_size: 8192
+  num_minibatches: 16
   learning_rate: 3e-4
   anneal_lr: false
   gamma: 0.99
@@ -78,6 +78,10 @@ train:
   max_grad_norm: 0.5
   target_kl: null
   log_window: 1000
+  # Advantage filtering
+  apply_advantage_filter: true
+  initial_th_factor: 0.01
+  beta: 0.25
 
 # # # Network # # #
 network:
diff --git a/baselines/ppo/ppo_pufferlib.py b/baselines/ppo/ppo_pufferlib.py
index 587fb96c9..6f0e2a550 100644
--- a/baselines/ppo/ppo_pufferlib.py
+++ b/baselines/ppo/ppo_pufferlib.py
@@ -144,7 +144,7 @@ def sweep(args, project="PPO", sweep_name="my_sweep"):
                     "max": 1e-1,
                 },
                 "batch_size": {"values": [512, 1024, 2048]},
-                "minibatch_size": {"values": [128, 256, 512]},
+                "num_minibatches": {"values": [4, 8, 16]},
             },
         ),
         project=project,
@@ -186,9 +186,13 @@ def run(
     ent_coef: Annotated[Optional[float], typer.Option(help="Entropy coefficient")] = None,
     update_epochs: Annotated[Optional[int], typer.Option(help="The number of epochs for updating the policy")] = None,
     batch_size: Annotated[Optional[int], typer.Option(help="The batch size for training")] = None,
-    minibatch_size: Annotated[Optional[int], typer.Option(help="The minibatch size for training")] = None,
+    num_minibatches: Annotated[Optional[int], typer.Option(help="The number of minibatches for training")] = None,
     gamma: Annotated[Optional[float], typer.Option(help="The discount factor for rewards")] = None,
     vf_coef: Annotated[Optional[float], typer.Option(help="Weight for vf_loss")] = None,
+    # Advantage filtering
+    apply_advantage_filter: Annotated[Optional[int], typer.Option(help="Whether to use the advantage filter (0 or 1)")] = None,
+    initial_th_factor: Annotated[Optional[float], typer.Option(help="Advantage-filter threshold as a fraction of the max-advantage EWMA")] = None,
+    beta: Annotated[Optional[float], typer.Option(help="EWMA weight for the advantage filter's max-advantage tracker")] = None,
     # Wandb logging options
     project: Annotated[Optional[str], typer.Option(help="WandB project name")] = None,
     entity: Annotated[Optional[str], typer.Option(help="WandB entity name")] = None,
@@ -238,10 +242,13 @@ def run(
         "ent_coef": ent_coef,
         "update_epochs": update_epochs,
         "batch_size": batch_size,
-        "minibatch_size": minibatch_size,
+        "num_minibatches": num_minibatches,
         "render": None if render is None else bool(render),
         "gamma": gamma,
         "vf_coef": vf_coef,
+        "apply_advantage_filter": apply_advantage_filter,
+        "initial_th_factor": initial_th_factor,
+        "beta": beta,
     }
     config.train.update(
         {k: v for k, v in train_config.items() if v is not None}
diff --git a/gpudrive/integrations/puffer/ppo.py b/gpudrive/integrations/puffer/ppo.py
index bdc65ed45..9c9bb8b4e 100644
--- a/gpudrive/integrations/puffer/ppo.py
+++ b/gpudrive/integrations/puffer/ppo.py
@@ -12,6 +12,7 @@
 import random
 import psutil
 import time
+import warnings
 
 from threading import Thread
 from collections import defaultdict, deque
@@ -33,6 +34,60 @@
 from gpudrive.integrations.puffer.logging import print_dashboard, abbreviate
 
 
+class AdvantageFilter:
+    """
+    Filter rollout transitions by advantage magnitude.
+
+    This implementation is based on Algorithm 1 in "Robust Autonomy Emerges from Self-Play"
+    (https://arxiv.org/abs/2502.03349). The key idea is to discard transitions with
+    low-magnitude advantages to focus training on the most informative samples.
+
+    The filtering threshold η is a fixed fraction of an EWMA of the maximum advantage
+    magnitude observed so far, making it invariant to the overall reward scale.
+    """
+
+    def __init__(self, beta=0.25, initial_th_factor=0.01):
+        """
+        Args:
+            beta: EWMA weight given to the newest maximum advantage
+            initial_th_factor: Filter threshold as a fraction of the max-advantage EWMA
+        """
+        self.beta = beta
+        self.threshold_factor = initial_th_factor
+        self.max_advantage_ewma = None
+
+    def filter(self, advantages_np):
+        """
+        Filter transitions based on advantage magnitude.
+
+        Args:
+            advantages_np: Numpy array of advantages
+
+        Returns:
+            mask: Boolean mask where True indicates transitions to keep
+            threshold: Current filtering threshold (η)
+        """
+        # Get new max advantage
+        max_advantage = float(np.max(np.abs(advantages_np)))
+
+        # Update the EWMA of the max advantage
+        if self.max_advantage_ewma is None:
+            self.max_advantage_ewma = max_advantage
+        else:
+            self.max_advantage_ewma = (
+                self.beta * max_advantage
+                + (1 - self.beta) * self.max_advantage_ewma
+            )
+
+        # Update the filtering threshold
+        threshold = self.threshold_factor * self.max_advantage_ewma
+
+        # Create mask of transitions to keep (where |advantage| >= threshold)
+        mask = np.abs(advantages_np) >= threshold
+
+        return mask, threshold
+
+
 def create(config, vecenv, policy, optimizer=None, wandb=None):
     seed_everything(config.seed, config.torch_deterministic)
     profile = Profile()
@@ -62,7 +117,7 @@ def create(config, vecenv, policy, optimizer=None, wandb=None):
     experience = Experience(
         config.batch_size,
         config.bptt_horizon,
-        config.minibatch_size,
+        config.num_minibatches,
         obs_shape,
         obs_dtype,
         atn_shape,
@@ -235,14 +290,58 @@ def train(data):
     losses = data.losses
 
     with profile.train_misc:
+        # Get the sorted indices for training data
         idxs = experience.sort_training_data()
         dones_np = experience.dones_np[idxs]
         values_np = experience.values_np[idxs]
         rewards_np = experience.rewards_np[idxs]
+
+        # Compute GAE advantages
         advantages_np = compute_gae(
             dones_np, values_np, rewards_np, config.gamma, config.gae_lambda
         )
-        experience.flatten_batch(advantages_np)
+
+        filter_mask = None
+        if config.apply_advantage_filter:
+            if config.bptt_horizon > 1:
+                raise ValueError(
+                    "Advantage filtering cannot be used with LSTM (bptt_horizon > 1)"
+                )
+
+            # Initialize the advantage filter if not already created
+            if not hasattr(data, "advantage_filter"):
+                data.advantage_filter = AdvantageFilter(
+                    beta=config.beta,
+                    initial_th_factor=config.initial_th_factor,
+                )
+
+            # Get mask of transitions to keep based on advantage magnitude
+            filter_mask, threshold = data.advantage_filter.filter(
+                advantages_np
+            )
+
+            # Log stats
+            num_total = len(advantages_np)
+            num_kept = filter_mask.sum()
+            percent_kept = 100 * num_kept / num_total if num_total > 0 else 0
+
+            data.filtering_stats = {
+                "advantage_filtering/threshold (η)": threshold,
+                "advantage_filtering/percent_kept": percent_kept,
+                "advantage_filtering/max_advantage": float(
+                    np.max(np.abs(advantages_np))
+                ),
+                "advantage_filtering/num_kept": int(num_kept),
+                "advantage_filtering/total": int(num_total),
+            }
+
+            data.msg = (
+                f"Advantage filtering: kept {num_kept}/{num_total} transitions "
+                f"({percent_kept:.1f}%, threshold (η)={threshold:.4f})"
+            )
+
+        # Prepare batch of transitions for model updating
+        experience.flatten_batch(advantages_np, filter_mask)
 
     # Optimizing the policy and value network
     num_update_iters = config.update_epochs * experience.num_minibatches
@@ -381,22 +480,27 @@ def train(data):
         and data.global_step > 0
         and time.perf_counter() - data.last_log_time > 3.0
     ):
-        data.last_log_time = time.perf_counter()
-        data.wandb.log(
-            {
-                "performance/controlled_agent_sps": profile.controlled_agent_sps,
-                "performance/controlled_agent_sps_env": profile.controlled_agent_sps_env,
-                "performance/pad_agent_sps": profile.pad_agent_sps,
-                "performance/pad_agent_sps_env": profile.pad_agent_sps_env,
-                "global_step": data.global_step,
-                "performance/epoch": data.epoch,
-                "performance/uptime": profile.uptime,
-                "train/learning_rate": data.optimizer.param_groups[0]["lr"],
-                **{f"metrics/{k}": v for k, v in data.stats.items()},
-                **{f"train/{k}": v for k, v in data.losses.items()},
-            }
-        )
+
+        # Create log dictionary with existing metrics
+        log_dict = {
+            "performance/controlled_agent_sps": profile.controlled_agent_sps,
+            "performance/controlled_agent_sps_env": profile.controlled_agent_sps_env,
+            "performance/pad_agent_sps": profile.pad_agent_sps,
+            "performance/pad_agent_sps_env": profile.pad_agent_sps_env,
+            "global_step": data.global_step,
+            "performance/epoch": data.epoch,
+            "performance/uptime": profile.uptime,
+            "train/learning_rate": data.optimizer.param_groups[0]["lr"],
+            **{f"metrics/{k}": v for k, v in data.stats.items()},
+            **{f"train/{k}": v for k, v in data.losses.items()},
+        }
+
+        # Add advantage filtering metrics if available
+        if hasattr(data, 'filtering_stats') and data.filtering_stats:
+            log_dict.update(data.filtering_stats)
+
+        data.wandb.log(log_dict)
 
         if bool(data.stats):
             data.wandb.log({
@@ -404,7 +508,6 @@ def train(data):
             })
     # fmt: on
-
     if data.epoch % config.checkpoint_interval == 0 or done_training:
         save_checkpoint(data)
         data.msg = f"Checkpoint saved at update {data.epoch}"
@@ -526,7 +629,6 @@ def make_losses():
         explained_variance=0,
     )
 
-
 class Experience:
     """Flat tensor storage (buffer) and array views for faster indexing."""
 
@@ -534,7 +636,7 @@ def __init__(
         self,
         batch_size,
         bptt_horizon,
-        minibatch_size,
+        num_minibatches,
         obs_shape,
         obs_dtype,
         atn_shape,
@@ -543,8 +645,8 @@ def __init__(
         lstm=None,
         lstm_total_agents=0,
     ):
-        if minibatch_size is None:
-            minibatch_size = batch_size
+        if num_minibatches is None:
+            num_minibatches = 1
 
         obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype]
         pin = device == "cuda" and cpu_offload
@@ -579,8 +681,8 @@ def __init__(
             self.lstm_h = torch.zeros(shape).to(device)
             self.lstm_c = torch.zeros(shape).to(device)
 
-        num_minibatches = batch_size / minibatch_size
         self.num_minibatches = int(num_minibatches)
+        minibatch_size = batch_size // num_minibatches
         if self.num_minibatches != num_minibatches:
             raise ValueError("batch_size must be divisible by minibatch_size")
 
@@ -601,9 +703,11 @@ def __init__(
 
     @property
    def full(self):
+        """Check if the buffer is full."""
        return self.ptr >= self.batch_size
 
    def store(self, obs, value, action, logprob, reward, done, env_id, mask):
+        """Store a batch of transitions in the buffer."""
        # Mask learner and Ensure indices do not exceed batch size
        ptr = self.ptr
         indices = torch.where(mask)[0].cpu().numpy()[: self.batch_size - ptr]
@@ -643,27 +747,103 @@ def sort_training_data(self):
         self.step = 0
         return idxs
 
-    def flatten_batch(self, advantages_np):
-        advantages = torch.from_numpy(advantages_np).to(self.device)
-        b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
-        self.b_actions = self.actions.to(self.device, non_blocking=True)
-        self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
-        self.b_dones = self.dones.to(self.device, non_blocking=True)
-        self.b_values = self.values.to(self.device, non_blocking=True)
-        self.b_advantages = (
-            advantages.reshape(
-                self.minibatch_rows, self.num_minibatches, self.bptt_horizon
+    def flatten_batch(self, advantages_np, filter_mask=None):
+        """Prepare the batch of transitions for model updating."""
+
+        if filter_mask is not None:
+
+            # Get the indices of transitions to keep
+            kept_indices = np.nonzero(filter_mask)[0]
+            total_kept = len(kept_indices)
+
+            # Determine how many transitions per minibatch (floor division)
+            transitions_per_mb = total_kept // self.num_minibatches
+
+            # Enforce a minimum number of transitions per minibatch
+            if transitions_per_mb < 32:
+                transitions_per_mb = 64
+
+                warnings.warn(f"Low advantage-filtering retention rate: only kept {len(kept_indices)} / {len(advantages_np)} transitions; using a minimum of {transitions_per_mb} per minibatch.\nConsider lowering the advantage threshold factor or increasing the batch_size.", UserWarning)
+
+            # If there are fewer kept transitions than minibatches, sample with replacement
+            if total_kept < self.num_minibatches:
+                kept_indices = np.random.choice(kept_indices, self.num_minibatches, replace=True)
+                total_kept = len(kept_indices)
+
+            # Calculate total transitions to use (divisible by num_minibatches)
+            transitions_to_use = transitions_per_mb * self.num_minibatches
+
+            np.random.shuffle(kept_indices)
+            kept_indices = kept_indices[:transitions_to_use]
+            filtered_idxs = kept_indices.copy()
+
+            # Reshape to (minibatch_rows, num_minibatches, bptt_horizon)
+            minibatch_rows_filtered = transitions_to_use // (self.num_minibatches * self.bptt_horizon)
+            filtered_idxs = filtered_idxs.reshape(
+                minibatch_rows_filtered,
+                self.num_minibatches,
+                self.bptt_horizon
             )
-            .transpose(0, 1)
-            .reshape(self.num_minibatches, self.minibatch_size)
-        )
-        self.returns_np = advantages_np + self.values_np
-        self.b_obs = self.obs[self.b_idxs_obs]
-        self.b_actions = self.b_actions[b_idxs].contiguous()
-        self.b_logprobs = self.b_logprobs[b_idxs]
-        self.b_dones = self.b_dones[b_idxs]
-        self.b_values = self.b_values[b_flat]
-        self.b_returns = self.b_advantages + self.b_values
+            filtered_idxs = np.transpose(filtered_idxs, (1, 0, 2))
+
+            # Update minibatch indices
+            self.b_idxs_obs = torch.as_tensor(filtered_idxs).to(self.obs.device).long()
+            self.b_idxs = self.b_idxs_obs.to(self.device)
+            self.b_idxs_flat = self.b_idxs.reshape(self.num_minibatches, -1)
+
+            # Get advantages for the filtered transitions
+            advantages = torch.from_numpy(advantages_np).to(self.device)
+
+            # The rest of the processing is similar to the original code
+            b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
+            self.b_actions = self.actions.to(self.device, non_blocking=True)
+            self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
+            self.b_dones = self.dones.to(self.device, non_blocking=True)
+            self.b_values = self.values.to(self.device, non_blocking=True)
+
+            # Reshape advantages to match the filtered structure
+            filtered_advantages = advantages[kept_indices[:transitions_to_use]]
+            self.b_advantages = filtered_advantages.reshape(
+                self.num_minibatches, -1
+            )
+
+            # Compute returns
+            self.returns_np = advantages_np + self.values_np
+
+            # Get observations, actions, etc. based on filtered indices
+            self.b_obs = self.obs[self.b_idxs_obs]
+            self.b_actions = self.b_actions[b_idxs].contiguous()
+            self.b_logprobs = self.b_logprobs[b_idxs]
+            self.b_dones = self.b_dones[b_idxs]
+            self.b_values = self.b_values[b_flat]
+            self.b_returns = self.b_advantages + self.b_values
+
+        else:
+            # Original implementation for when no filtering is applied
+            advantages = torch.from_numpy(advantages_np).to(self.device)
+
+            # Get the respective indices for all minibatches
+            b_idxs, b_flat = self.b_idxs, self.b_idxs_flat
+            self.b_actions = self.actions.to(self.device, non_blocking=True)
+            self.b_logprobs = self.logprobs.to(self.device, non_blocking=True)
+            self.b_dones = self.dones.to(self.device, non_blocking=True)
+            self.b_values = self.values.to(self.device, non_blocking=True)
+            self.b_advantages = (
+                advantages.reshape(
+                    self.minibatch_rows, self.num_minibatches, self.bptt_horizon
+                )
+                .transpose(0, 1)
+                .reshape(self.num_minibatches, self.minibatch_size)
+            )
+
+            # Re-order the transitions based on the sorted indices
+            self.returns_np = advantages_np + self.values_np
+            self.b_obs = self.obs[self.b_idxs_obs]
+            self.b_actions = self.b_actions[b_idxs].contiguous()
+            self.b_logprobs = self.b_logprobs[b_idxs]
+            self.b_dones = self.b_dones[b_idxs]
+            self.b_values = self.b_values[b_flat]
+            self.b_returns = self.b_advantages + self.b_values
 
 
 class Utilization(Thread):
@@ -716,6 +896,7 @@ def save_checkpoint(data, save_checkpoint_to_wandb=True):
         "action_dim": data.uncompiled_policy.action_dim,
         "exp_id": config.exp_id,
         "num_params": config.network["num_parameters"],
+        "config": data.config.to_dict(),
     }
 
     torch.save(state, model_path)
diff --git a/gpudrive/utils/generate_sbatch.py b/gpudrive/utils/generate_sbatch.py
index 95850cca3..0118e5359 100644
--- a/gpudrive/utils/generate_sbatch.py
+++ b/gpudrive/utils/generate_sbatch.py
@@ -252,21 +252,21 @@ def save_script(filename, file_path, fields, params, param_order=None):
         "memory": 70,
         "job_name": group,
     }
-    
+
     hyperparams = {
-        "group": [group], # Group name
+        "group": [group],  # Group name
         "num_worlds": [800],
-        "resample_scenes": [1], # Yes
+        "resample_scenes": [1],  # Yes
         "k_unique_scenes": [800],
         "resample_interval": [5_000_000],
         "total_timesteps": [4_000_000_000],
         "resample_dataset_size": [10_000],
         "batch_size": [524288],
-        "minibatch_size": [16384],
+        "num_minibatches": [16],
         "update_epochs": [4],
         "ent_coef": [0.001, 0.003, 0.0001],
         "render": [0],
-        #"seed": [42, 3],
+        # "seed": [42, 3],
     }
 
     save_script(
@@ -285,7 +285,7 @@ def save_script(filename, file_path, fields, params, param_order=None):
     # "total_timesteps": [3_000_000_000],
     # "resample_dataset_size": [1000],
    # "batch_size": [262_144, 524_288],
-    # "minibatch_size": [16_384],
+    # "num_minibatches": [16],
     # "update_epochs": [2, 4, 5],
     # "ent_coef": [0.0001, 0.001, 0.003],
     # "learning_rate": [1e-4, 3e-4],
@@ -299,6 +299,3 @@ def save_script(filename, file_path, fields, params, param_order=None):
     #     fields=fields,
     #     params=hyperparams,
     # )
-
-
-
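
Reviewer note: below is a minimal, self-contained sketch of the filtering rule introduced in this diff (the threshold η is a fixed fraction of an EWMA of the maximum |advantage|). The helper name filter_advantages and the toy advantage array are illustrative only; beta=0.25 and initial_th_factor=0.01 mirror the new config defaults.

import numpy as np

def filter_advantages(advantages, max_adv_ewma, beta=0.25, th_factor=0.01):
    """One filtering step: update the EWMA of max |A| and build a keep-mask.

    Mirrors AdvantageFilter.filter(): threshold = th_factor * EWMA(max |A|),
    and transitions with |A| >= threshold are kept.
    """
    max_adv = float(np.max(np.abs(advantages)))
    # Initialize the EWMA on the first call, otherwise blend in the new maximum
    max_adv_ewma = max_adv if max_adv_ewma is None else beta * max_adv + (1 - beta) * max_adv_ewma
    threshold = th_factor * max_adv_ewma
    mask = np.abs(advantages) >= threshold
    return mask, threshold, max_adv_ewma

# Toy example (values are made up for illustration)
rng = np.random.default_rng(0)
advantages = rng.normal(scale=0.5, size=8192)
mask, eta, ewma = filter_advantages(advantages, max_adv_ewma=None)
print(f"kept {mask.sum()}/{mask.size} transitions, eta={eta:.4f}")

With these defaults, η is 1% of the running max-|advantage| estimate, so the kept fraction adapts to the reward scale rather than relying on an absolute cutoff. Note also that with the config change, minibatch_size is now derived in Experience as batch_size // num_minibatches (131_072 // 16 = 8_192, matching the previous minibatch_size setting).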