Add human log likelihood regularizer #75
Conversation
Just testing greptile on things to develop a feel for it
Greptile Overview

Greptile Summary

This PR adds human log likelihood regularization to encourage the policy to mimic human driving behavior from expert trajectories. The implementation uses an inverse bicycle model to infer continuous steering and acceleration actions from trajectory data, stores them in binary format, and samples them during training to compute an auxiliary loss term.

Key Changes:
Issues Found:
Confidence Score: 3/5
Important Files Changed

File Analysis
Sequence Diagram

```mermaid
sequenceDiagram
    participant Init as Environment Init
    participant C as C/C++ (drive.h)
    participant Py as Python (drive.py)
    participant Train as Training Loop (pufferl.py)
    participant Disk as Disk Storage

    Note over Init,Disk: Initialization Phase
    Init->>Py: __init__(bptt_horizon, max_expert_samples)
    Py->>Py: _save_expert_data()
    Py->>C: vec_collect_expert_data()
    C->>C: c_reset() for each env
    loop t=0 to TRAJECTORY_LENGTH-1
        C->>C: Read expert_accel[t], expert_steering[t]
        C->>C: Discretize continuous actions
        C->>C: Store observations
        C->>C: c_step() to advance env
    end
    C->>C: c_reset() to restore state
    C-->>Py: Return discrete_actions, continuous_actions, observations
    Py->>Py: Create bptt_horizon length sequences
    Py->>Disk: Save expert_actions_discrete.pt
    Py->>Disk: Save expert_actions_continuous.pt
    Py->>Disk: Save expert_observations.pt

    Note over Train,Disk: Training Phase (per minibatch)
    Train->>Py: sample_expert_data(n_samples)
    Py->>Disk: Load expert data files (CPU)
    Py->>Py: Random sample n_samples
    Py-->>Train: Return actions, observations
    Train->>Train: Move to GPU
    Train->>Train: policy(human_obs, human_state)
    Train->>Train: sample_logits(logits, human_actions)
    Train->>Train: Compute human_log_prob
    Train->>Train: loss = pg_loss + vf_loss - ent_loss - human_ll_coef * human_loss
    Train->>Train: Backward pass & optimize
```
6 files reviewed, 3 comments
```python
# Add log likelihood loss of human actions under current policy.
# 1: Sample a batch of human actions and observations from dataset
# Shape: [n_samples, bptt_horizon, feature_dim]
discrete_human_actions, continuous_human_actions, human_observations = (
    self.vecenv.driver_env.sample_expert_data(n_samples=config["human_samples"], return_both=True)
)
discrete_human_actions = discrete_human_actions.to(device)
continuous_human_actions = continuous_human_actions.to(device)
human_observations = human_observations.to(device)

# Use helper function to compute realism metrics
realism_metrics = self.vecenv.driver_env.compute_realism_metrics(
    discrete_human_actions, continuous_human_actions
)
self.realism.update(realism_metrics)

# Select appropriate action type for training
use_continuous = self.vecenv.driver_env._action_type_flag == 1
human_actions = continuous_human_actions if use_continuous else discrete_human_actions
human_observations = human_observations.to(device)

# 2: Compute the log-likelihood of human actions under the current policy,
# given the corresponding human observations. A higher likelihood indicates
# that the policy behaves more like a human under the same observations.
human_state = dict(
    action=human_actions,
    lstm_h=None,
    lstm_c=None,
)

human_logits, _ = self.policy(human_observations, human_state)

_, human_log_prob, human_entropy = pufferlib.pytorch.sample_logits(
    logits=human_logits, action=human_actions
)

self.realism["human_log_prob"] = human_log_prob.mean().item()
```
style: human log likelihood computation happens every minibatch but samples from the same full dataset each time - this creates unnecessary disk I/O overhead. Consider sampling once per epoch outside the minibatch loop and reusing those samples across all minibatches within that epoch
Path: pufferlib/pufferl.py, Lines: 416–452
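A minimal sketch of the suggested fix, assuming a trainer with an outer epoch loop and an inner minibatch loop; `driver_env.sample_expert_data` and `config["human_samples"]` follow the snippet above, while the helper name and loop structure are hypothetical:

```python
def sample_human_batch_for_epoch(driver_env, config, device):
    """Sample expert data once per epoch so it can be reused across minibatches."""
    discrete, continuous, obs = driver_env.sample_expert_data(
        n_samples=config["human_samples"], return_both=True
    )
    # One disk read and one host-to-device copy per epoch instead of one per minibatch.
    return discrete.to(device), continuous.to(device), obs.to(device)


# Usage inside the training loop (pseudocode):
# for epoch in range(num_epochs):
#     human_batch = sample_human_batch_for_epoch(vecenv.driver_env, config, device)
#     for minibatch in minibatches:
#         discrete_human_actions, continuous_human_actions, human_observations = human_batch
#         ...  # compute human_log_prob as in the snippet above
```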
Goal
Given a dataset of human "expert" trajectories $(o_t, a_t)$, we aim to compute the likelihood of the expert's actions under the current policy and use this as an auxiliary learning signal.
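Concretely, for a sampled batch $\mathcal{B}$ of expert pairs, the auxiliary term is the mean log likelihood of the expert actions under the current policy $\pi_\theta$, and it is subtracted from the usual PPO loss with weight $c_{\text{human}}$ (corresponding to `human_ll_coef`). The sign conventions below follow the loss line in the sequence diagram and may differ slightly from the exact code:

$$\mathcal{L}_{\text{human}}(\theta) = \frac{1}{|\mathcal{B}|} \sum_{(o_t, a_t) \in \mathcal{B}} \log \pi_\theta(a_t \mid o_t)$$

$$\mathcal{L}(\theta) = \mathcal{L}_{\text{pg}} + \mathcal{L}_{\text{vf}} - \mathcal{L}_{\text{ent}} - c_{\text{human}} \, \mathcal{L}_{\text{human}}(\theta)$$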
Implementation
Dataset
Use the inverse bicycle model to infer the human actions (continuous acceleration and steering) taken at every step of the logged trajectories.
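As a rough illustration of what such an inversion can look like (the actual model lives in `drive.h`; the rear-axle kinematic bicycle form and the `wheelbase` and `dt` constants used here are assumptions, not taken from the repo):

```python
import numpy as np

def inverse_bicycle_actions(yaw, speed, wheelbase=2.8, dt=0.1):
    """Infer (acceleration, steering) from a logged trajectory.

    Assumes a simple rear-axle kinematic bicycle model:
        yaw_rate = v / L * tan(delta),   dv/dt = a
    The model in drive.h may use a different formulation.
    """
    yaw = np.asarray(yaw, dtype=np.float64)
    speed = np.asarray(speed, dtype=np.float64)

    # Acceleration from finite-differenced speed.
    accel = (speed[1:] - speed[:-1]) / dt

    # Heading change wrapped to (-pi, pi] so crossings of +/-pi don't explode.
    dyaw = np.arctan2(np.sin(yaw[1:] - yaw[:-1]), np.cos(yaw[1:] - yaw[:-1]))
    yaw_rate = dyaw / dt

    # Invert yaw_rate = v / L * tan(delta); guard against near-zero speeds.
    v = np.clip(speed[:-1], 1e-3, None)
    steering = np.arctan(yaw_rate * wheelbase / v)

    return accel, steering
```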
At initialization
The method `_save_expert_data()` saves `max_expert_samples` per env to disk. We stack actions based on `bptt_horizon` so that the resulting tuples are of shape `(max_expert_samples, bptt_horizon, action/obs_dim)`.

Under the hood, this method uses `c_collect_expert_data()` in `drive.h` to step the agents with their continuous inferred human actions and collect the corresponding observations. We return the discretized actions if needed.
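A sketch of the stacking and saving step, assuming per-step tensors have already been collected; the file names and shapes follow the sequence diagram above, while the windowing logic and function signature are assumptions (the real `_save_expert_data()` in `drive.py` also handles multiple envs via `vec_collect_expert_data()`):

```python
import torch

def save_expert_sequences(discrete, continuous, obs,
                          bptt_horizon, max_expert_samples, save_dir="."):
    """Stack per-step expert data into bptt_horizon-length sequences and save them."""
    def to_sequences(x):
        # Trim to a multiple of bptt_horizon, reshape to
        # (num_sequences, bptt_horizon, feature_dim), and cap the count.
        n = (x.shape[0] // bptt_horizon) * bptt_horizon
        x = x[:n].reshape(-1, bptt_horizon, *x.shape[1:])
        return x[:max_expert_samples]

    torch.save(to_sequences(discrete), f"{save_dir}/expert_actions_discrete.pt")
    torch.save(to_sequences(continuous), f"{save_dir}/expert_actions_continuous.pt")
    torch.save(to_sequences(obs), f"{save_dir}/expert_observations.pt")
```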
Training inner loop

Do:

- Sample `human_samples` tuples of human actions and observations from the expert dataset.
- Compute the log-likelihood of these human actions under the current policy, given the corresponding observations.
- Subtract the weighted term (`human_ll_coef * human_loss`) from the total loss, so that minimizing the loss maximizes the human log likelihood (see the sketch after this list).
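A minimal sketch of how the term enters the total loss; `pg_loss`, `vf_loss`, `ent_loss`, and `human_ll_coef` are the names used in the sequence diagram, `human_log_prob` comes from the `sample_logits` call in the snippet above, and the exact wiring in `pufferl.py` may differ:

```python
# Mean log likelihood of the sampled human actions under the current policy.
human_loss = human_log_prob.mean()

# Subtracting the weighted term means gradient descent *increases* the
# log-likelihood of human actions under the policy.
loss = pg_loss + vf_loss - ent_loss - human_ll_coef * human_loss
```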
Wandb integration

We log a `realism/` section with the human action distributions. This is done mostly as a sanity check; the section can be extended with other realism-related metrics.