4 changes: 3 additions & 1 deletion examples/experimental/config/eval_config.yaml
@@ -3,7 +3,8 @@ test_dataset_size: 10_000 # Number of test scenarios to evaluate on

# Environment settings
train_dir: data/processed/training
test_dir: data/processed/validation
test_dir: data/processed/validation
file_prefix: null # Controls the file prefix used when searching for files in SceneDataLoader; defaults to tfrecord (WOMD data)

num_worlds: 50 # Number of parallel environments for evaluation
max_controlled_agents: 64 # Maximum number of agents controlled by the model.
@@ -26,6 +27,7 @@ obs_radius: 50.0 # Visibility radius of the agents
init_roadgraph: False
render_3d: True

action_type: "discrete"
# Number of discretizations in the action space
# Note: Make sure this matches the discretization the policy
# has been trained with
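The new file_prefix key selects which scene files a data loader picks up. A minimal sketch of the intended behavior, assuming prefix-based filename matching (the glob below is illustrative, not SceneDataLoader's actual implementation):

from pathlib import Path
from typing import List, Optional

def list_scene_files(data_dir: str, file_prefix: Optional[str]) -> List[Path]:
    # A null file_prefix falls back to "tfrecord", matching WOMD scene files.
    prefix = "tfrecord" if file_prefix is None else file_prefix
    return sorted(Path(data_dir).glob(f"{prefix}*"))

womd_files = list_scene_files("data/processed/validation", None)        # WOMD default
nuplan_files = list_scene_files("data/processed/validation", "nuplan")  # nuPlan data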
37 changes: 37 additions & 0 deletions examples/experimental/config/expert_replay_config.yaml
@@ -0,0 +1,37 @@
res_path: examples/experimental/dataframes # Store dataframes here
test_dataset_size: 300 # Number of test scenarios to evaluate on

# Environment settings
train_dir: data/processed/training
test_dir: data/processed/validation
file_prefix: nuplan

num_worlds: 100 # Number of parallel environments for evaluation
max_controlled_agents: 64 # Maximum number of agents controlled by the model.
ego_state: true
road_map_obs: true
partner_obs: true
norm_obs: true
remove_non_vehicles: true # If false, all agents are included (vehicles, pedestrians, cyclists)
lidar_obs: false # NOTE: Setting this to true currently turns off the other observation types
reward_type: "weighted_combination"
collision_weight: -0.75
off_road_weight: -0.75
goal_achieved_weight: 1.0
dynamics_model: "delta_local"
collision_behavior: "ignore" # Options: "remove", "stop"
dist_to_goal_threshold: 2.0
polyline_reduction_threshold: 0.1 # Rate at which to sample points from the polyline (0 uses all of the closest points, 1 is maximum sparsity); needs to be balanced with kMaxAgentMapObservationsCount
sampling_seed: 42 # If given, the set of scenes to sample from is deterministic; if None, it is random
obs_radius: 50.0 # Visibility radius of the agents
init_roadgraph: False
render_3d: True

action_type: "continuous"
# Number of discretizations in the action space
# Note: Make sure this matches the discretization the policy
# has been trained with
action_space_steer_disc: 13
action_space_accel_disc: 7

device: "cuda" # Options: "cpu", "cuda"
33 changes: 21 additions & 12 deletions examples/experimental/eval_utils.py
@@ -35,14 +35,18 @@ def __call__(self, obs, deterministic=False):
)
return random_action, None, None, None

class ExpertReplayPolicy:
def __init__(self):
pass

def load_policy(path_to_cpt, model_name, device, env=None):
"""Load a policy from a given path."""

# Load the saved checkpoint
if model_name == "random_baseline":
return RandomPolicy(env.action_space.n)

if model_name == "expert_replay":
return ExpertReplayPolicy()
else: # Load a trained model
saved_cpt = torch.load(
f=f"{path_to_cpt}/{model_name}.pt",
@@ -110,22 +114,26 @@ def rollout(

control_mask = env.cont_agent_mask
live_agent_mask = control_mask.clone()

expert_actions, _, _, _ = env.get_expert_actions()

for time_step in range(episode_len):

print(f't: {time_step}')

# Get actions for active agents
if live_agent_mask.any():
action, _, _, _ = policy(
next_obs[live_agent_mask], deterministic=deterministic
)

# Insert actions into a template
action_template = torch.zeros(
(num_worlds, max_agent_count), dtype=torch.int64, device=device
)
action_template[live_agent_mask] = action.to(device)
if isinstance(policy, ExpertReplayPolicy):
action_template = expert_actions[:, :, time_step, :]
else:
action, _, _, _ = policy(
next_obs[live_agent_mask], deterministic=deterministic
)

# Insert actions into a template
action_template = torch.zeros(
(num_worlds, max_agent_count), dtype=torch.int64, device=device
)
action_template[live_agent_mask] = action.to(device)

# Step the environment
env.step_dynamics(action_template)
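Isolating the dispatch above: expert replay indexes a pre-fetched action tensor per step, while every other policy is queried and scattered into an int64 template of discrete action indices. The shape comment is an assumption based on how get_expert_actions() is indexed in this diff:

import torch

def select_actions(policy, expert_actions, next_obs, live_agent_mask,
                   num_worlds, max_agent_count, time_step, device,
                   deterministic=False):
    if isinstance(policy, ExpertReplayPolicy):
        # Assumed shape: (num_worlds, max_agent_count, episode_len, action_dim);
        # continuous expert actions bypass the int64 template entirely.
        return expert_actions[:, :, time_step, :]
    action, _, _, _ = policy(next_obs[live_agent_mask], deterministic=deterministic)
    template = torch.zeros((num_worlds, max_agent_count), dtype=torch.int64, device=device)
    template[live_agent_mask] = action.to(device)
    return template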
@@ -274,7 +282,8 @@ def make_env(config, train_loader, render_3d=False):
data_loader=train_loader,
max_cont_agents=config.max_controlled_agents,
device=config.device,
render_config=render_config
render_config=render_config,
action_type=config.action_type,
)

return env
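With action_type now threaded through, both evaluation modes can share make_env; a usage sketch (the checkpoint directory "models" is hypothetical):

env = make_env(config, train_loader, render_3d=config.render_3d)
policy = load_policy("models", "expert_replay", device=config.device, env=env)
# env and policy then feed into rollout() above.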
2 changes: 2 additions & 0 deletions examples/experimental/get_model_performance.py
@@ -64,6 +64,7 @@ def set_seed(seed: int):
else 1000,
sample_with_replacement=False,
shuffle=False,
file_prefix=eval_config.file_prefix,
)

test_loader = SceneDataLoader(
@@ -74,6 +75,7 @@ def set_seed(seed: int):
else 1000,
sample_with_replacement=False,
shuffle=True,
file_prefix=eval_config.file_prefix,
)
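Since eval_config.yaml sets file_prefix: null, the loader is expected to fall back to its tfrecord default (per the comment in that file). A hypothetical guard for configs that omit the key entirely:

file_prefix = getattr(eval_config, "file_prefix", None) or "tfrecord"  # hypothetical fallback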

# Rollouts
11 changes: 6 additions & 5 deletions gpudrive/visualize/core.py
@@ -10,7 +10,6 @@
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from matplotlib.colors import ListedColormap
from jaxlib.xla_extension import ArrayImpl
import numpy as np
import madrona_gpudrive
from gpudrive.visualize import utils
@@ -78,10 +77,12 @@ def initialize_static_scenario_data(self, controlled_agent_mask):
)
self.controlled_agent_mask = controlled_agent_mask

if isinstance(controlled_agent_mask, ArrayImpl):
self.controlled_agent_mask = torch.from_numpy(
np.array(controlled_agent_mask)
)
if self.backend == "jax":
from jaxlib.xla_extension import ArrayImpl
if isinstance(controlled_agent_mask, ArrayImpl):
self.controlled_agent_mask = torch.from_numpy(
np.array(controlled_agent_mask)
)

self.controlled_agent_mask = self.controlled_agent_mask.to(self.device)
