
Commit fecb769

added reward/entropy/discount conditioning

1 parent 595734b commit fecb769

File tree

10 files changed, +385 -32 lines changed

pufferlib/config/ocean/drive.ini

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ init_steps = 0 # Determines which step of the trajectory to initialize the agent
 control_all_agents = False # this should be set to false unless you specifically want to override and control expert-marked vehicles
 num_policy_controlled_agents = -1 # note: if you set this you likely need to set num_agents to a smaller number
 deterministic_agent_selection = False # if this is true it overrides vehicles marked as expert to be policy controlled
+condition_type = "none" # Options: "none", "reward", "entropy", "discount", "all"

 [train]
 total_timesteps = 2_000_000_000
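The new condition_type key is the single user-facing switch over the three mechanisms. The config layer that translates it into the use_rc/use_ec/use_dc kwargs unpacked in binding.c is not part of this diff; a minimal sketch of the assumed mapping (helper name and placement are hypothetical):

    #include <cstring>

    // Hypothetical mapping from the ini string to the three per-mechanism
    // flags; "all" enables reward, entropy, and discount conditioning at once.
    struct ConditionFlags { bool use_rc; bool use_ec; bool use_dc; };

    static ConditionFlags parse_condition_type(const char* s) {
        ConditionFlags f = {false, false, false};
        bool all = std::strcmp(s, "all") == 0;
        f.use_rc = all || std::strcmp(s, "reward") == 0;
        f.use_ec = all || std::strcmp(s, "entropy") == 0;
        f.use_dc = all || std::strcmp(s, "discount") == 0;
        return f;
    }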

pufferlib/extensions/cuda/pufferlib.cu

Lines changed: 12 additions & 6 deletions
@@ -20,8 +20,8 @@ __host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards,
 }

 void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
-        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        int num_steps, int horizon) {
+        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
+        torch::Tensor gammas, int num_steps, int horizon) {

     // Validate input tensors
     torch::Device device = values.device();
@@ -35,24 +35,30 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    // Validate gammas tensor
+    TORCH_CHECK(gammas.dim() == 1, "Gammas must be 1D");
+    TORCH_CHECK(gammas.size(0) == num_steps, "Gammas size must match num_steps");
+    TORCH_CHECK(gammas.dtype() == torch::kFloat32, "Gammas must be float32");
+    TORCH_CHECK(gammas.is_cuda(), "Gammas must be on GPU");
+    TORCH_CHECK(gammas.is_contiguous(), "Gammas must be contiguous");
 }

 // [num_steps, horizon]
 __global__ void puff_advantage_kernel(float* values, float* rewards,
-        float* dones, float* importance, float* advantages, float gamma,
+        float* dones, float* importance, float* advantages, float* gammas,
         float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
         return;
     }
     int offset = row*horizon;
     puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
-        importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
+        importance + offset, advantages + offset, gammas[row], lambda, rho_clip, c_clip, horizon);
 }

 void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
     torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-    double gamma, double lambda, double rho_clip, double c_clip) {
+    torch::Tensor gammas, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
-    vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
+    vtrace_check_cuda(values, rewards, dones, importance, advantages, gammas, num_steps, horizon);
@@ -67,7 +73,7 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
         dones.data_ptr<float>(),
         importance.data_ptr<float>(),
         advantages.data_ptr<float>(),
-        gamma,
+        gammas.data_ptr<float>(),
         lambda,
         rho_clip,
         c_clip,
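Taken together, the op now consumes one discount per trajectory row instead of a single scalar gamma. A minimal caller sketch, assuming libtorch and linkage against this extension; buffer shapes follow the validation checks above, and the sampled gamma range is purely illustrative:

    #include <torch/torch.h>

    // Declared in this extension; see the definition above.
    void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
        torch::Tensor gammas, double lambda, double rho_clip, double c_clip);

    void example_cuda_call() {
        int64_t num_steps = 8, horizon = 128;
        auto opts = torch::dtype(torch::kFloat32).device(torch::kCUDA);
        auto values     = torch::zeros({num_steps, horizon}, opts);
        auto rewards    = torch::zeros({num_steps, horizon}, opts);
        auto dones      = torch::zeros({num_steps, horizon}, opts);
        auto importance = torch::ones({num_steps, horizon}, opts);
        auto advantages = torch::zeros({num_steps, horizon}, opts);
        // One gamma per row: float32, contiguous, on GPU -- exactly what
        // vtrace_check_cuda enforces. The 0.9..0.999 range is illustrative.
        auto gammas = torch::empty({num_steps}, opts).uniform_(0.9, 0.999);
        compute_puff_advantage_cuda(values, rewards, dones, importance, advantages,
            gammas, /*lambda=*/0.95, /*rho_clip=*/1.0, /*c_clip=*/1.0);
    }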

pufferlib/extensions/pufferlib.cpp

Lines changed: 14 additions & 8 deletions
@@ -42,7 +42,7 @@ void puff_advantage_row(float* values, float* rewards, float* dones,

 void vtrace_check(torch::Tensor values, torch::Tensor rewards,
     torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-    int num_steps, int horizon) {
+    torch::Tensor gammas, int num_steps, int horizon) {

     // Validate input tensors
     torch::Device device = values.device();
@@ -56,36 +56,42 @@ void vtrace_check(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    // Validate gammas tensor
+    TORCH_CHECK(gammas.dim() == 1, "Gammas must be 1D");
+    TORCH_CHECK(gammas.size(0) == num_steps, "Gammas size must match num_steps");
+    TORCH_CHECK(gammas.dtype() == torch::kFloat32, "Gammas must be float32");
+    TORCH_CHECK(gammas.is_contiguous(), "Gammas must be contiguous");
 }

 // [num_steps, horizon]
 void puff_advantage(float* values, float* rewards, float* dones, float* importance,
-    float* advantages, float gamma, float lambda, float rho_clip, float c_clip,
+    float* advantages, float* gammas, float lambda, float rho_clip, float c_clip,
     int num_steps, const int horizon){
-    for (int offset = 0; offset < num_steps*horizon; offset+=horizon) {
+    for (int row = 0; row < num_steps; row++) {
+        int offset = row * horizon;
         puff_advantage_row(values + offset, rewards + offset,
             dones + offset, importance + offset, advantages + offset,
-            gamma, lambda, rho_clip, c_clip, horizon
+            gammas[row], lambda, rho_clip, c_clip, horizon
         );
     }
 }

 void compute_puff_advantage_cpu(torch::Tensor values, torch::Tensor rewards,
     torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-    double gamma, double lambda, double rho_clip, double c_clip) {
+    torch::Tensor gammas, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
-    vtrace_check(values, rewards, dones, importance, advantages, num_steps, horizon);
+    vtrace_check(values, rewards, dones, importance, advantages, gammas, num_steps, horizon);
     puff_advantage(values.data_ptr<float>(), rewards.data_ptr<float>(),
         dones.data_ptr<float>(), importance.data_ptr<float>(), advantages.data_ptr<float>(),
-        gamma, lambda, rho_clip, c_clip, num_steps, horizon
+        gammas.data_ptr<float>(), lambda, rho_clip, c_clip, num_steps, horizon
     );
 }

 TORCH_LIBRARY(pufferlib, m) {
-    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
+    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, Tensor gammas, float lambda, float rho_clip, float c_clip) -> ()");
 }

 TORCH_LIBRARY_IMPL(pufferlib, CPU, m) {
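Because the registered schema changed, every caller must now pass the extra Tensor argument. A hedged sketch of invoking the new schema through the dispatcher from C++ (Python callers would go through torch.ops.pufferlib.compute_puff_advantage instead; the wrapper name below is hypothetical):

    #include <ATen/core/dispatch/Dispatcher.h>
    #include <torch/torch.h>

    // Hypothetical wrapper: resolves the re-registered schema once, then calls it.
    void call_compute_puff_advantage(torch::Tensor values, torch::Tensor rewards,
            torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
            torch::Tensor gammas, double lambda, double rho_clip, double c_clip) {
        static auto op = c10::Dispatcher::singleton()
            .findSchemaOrThrow("pufferlib::compute_puff_advantage", "")
            .typed<void(at::Tensor, at::Tensor, at::Tensor, at::Tensor,
                        at::Tensor, at::Tensor, double, double, double)>();
        op.call(values, rewards, dones, importance, advantages, gammas,
                lambda, rho_clip, c_clip);
    }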

pufferlib/ocean/drive/binding.c

Lines changed: 16 additions & 0 deletions
@@ -182,6 +182,22 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
     env->control_all_agents = unpack(kwargs, "control_all_agents");
     env->deterministic_agent_selection = unpack(kwargs, "deterministic_agent_selection");
     env->control_non_vehicles = (int)unpack(kwargs, "control_non_vehicles");
+
+    // Conditioning parameters
+    env->use_rc = (bool)unpack(kwargs, "use_rc");
+    env->use_ec = (bool)unpack(kwargs, "use_ec");
+    env->use_dc = (bool)unpack(kwargs, "use_dc");
+    env->collision_weight_lb = (float)unpack(kwargs, "collision_weight_lb");
+    env->collision_weight_ub = (float)unpack(kwargs, "collision_weight_ub");
+    env->offroad_weight_lb = (float)unpack(kwargs, "offroad_weight_lb");
+    env->offroad_weight_ub = (float)unpack(kwargs, "offroad_weight_ub");
+    env->goal_weight_lb = (float)unpack(kwargs, "goal_weight_lb");
+    env->goal_weight_ub = (float)unpack(kwargs, "goal_weight_ub");
+    env->entropy_weight_lb = (float)unpack(kwargs, "entropy_weight_lb");
+    env->entropy_weight_ub = (float)unpack(kwargs, "entropy_weight_ub");
+    env->discount_weight_lb = (float)unpack(kwargs, "discount_weight_lb");
+    env->discount_weight_ub = (float)unpack(kwargs, "discount_weight_ub");
+
     int map_id = unpack(kwargs, "map_id");
     int max_agents = unpack(kwargs, "max_agents");
     int init_steps = unpack(kwargs, "init_steps");
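Each conditioning mechanism contributes a lower/upper bound pair, which makes my_init() quite repetitive. A hypothetical helper (not in this commit) that factors out the pattern, assuming unpack() is the kwargs reader already used above:

    #include <Python.h>

    // Hypothetical: reads a lb/ub pair via the existing unpack() kwargs reader.
    static void unpack_bounds(PyObject* kwargs, const char* lb_key,
            const char* ub_key, float* lb, float* ub) {
        *lb = (float)unpack(kwargs, lb_key);
        *ub = (float)unpack(kwargs, ub_key);
    }

    // Usage:
    // unpack_bounds(kwargs, "collision_weight_lb", "collision_weight_ub",
    //     &env->collision_weight_lb, &env->collision_weight_ub);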

pufferlib/ocean/drive/drive.h

Lines changed: 116 additions & 5 deletions
@@ -119,6 +119,12 @@ struct Log {
     float active_agent_count;
     float expert_static_car_count;
     float static_car_count;
+    // Conditioning metrics
+    float avg_collision_weight;
+    float avg_offroad_weight;
+    float avg_goal_weight;
+    float avg_entropy_weight;
+    float avg_discount_weight;
 };

 typedef struct Entity Entity;
@@ -287,6 +293,27 @@ struct Drive {
     char* ini_file;
     int scenario_length;
     int control_non_vehicles;
+    // Reward conditioning
+    bool use_rc;
+    float collision_weight_lb;
+    float collision_weight_ub;
+    float offroad_weight_lb;
+    float offroad_weight_ub;
+    float goal_weight_lb;
+    float goal_weight_ub;
+    float* collision_weights;
+    float* offroad_weights;
+    float* goal_weights;
+    // Entropy conditioning
+    bool use_ec;
+    float entropy_weight_lb;
+    float entropy_weight_ub;
+    float* entropy_weights;
+    // Discount conditioning
+    bool use_dc;
+    float discount_weight_lb;
+    float discount_weight_ub;
+    float* discount_weights;
 };

 typedef struct {
@@ -1565,6 +1592,18 @@ void init(Drive* env){
     set_start_position(env);
     init_goal_positions(env);
     env->logs = (Log*)calloc(env->active_agent_count, sizeof(Log));
+
+    if (env->use_rc) {
+        env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+        env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+        env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
+    if (env->use_ec) {
+        env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
+    if (env->use_dc) {
+        env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
 }

 void c_close(Drive* env){
@@ -1594,6 +1633,18 @@ void c_close(Drive* env){
     freeTopologyGraph(env->topology_graph);
     // free(env->map_name);
     free(env->ini_file);
+
+    if (env->use_rc) {
+        free(env->collision_weights);
+        free(env->offroad_weights);
+        free(env->goal_weights);
+    }
+    if (env->use_ec) {
+        free(env->entropy_weights);
+    }
+    if (env->use_dc) {
+        free(env->discount_weights);
+    }
 }

 void allocate(Drive* env){
@@ -1606,13 +1657,38 @@ void allocate(Drive* env){
     env->actions = (float*)calloc(env->active_agent_count*2, sizeof(float));
     env->rewards = (float*)calloc(env->active_agent_count, sizeof(float));
     env->terminals = (unsigned char*)calloc(env->active_agent_count, sizeof(unsigned char));
+
+    if (env->use_rc) {
+        env->collision_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+        env->offroad_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+        env->goal_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
+    if (env->use_ec) {
+        env->entropy_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
+    if (env->use_dc) {
+        env->discount_weights = (float*)calloc(env->active_agent_count, sizeof(float));
+    }
 }

 void free_allocated(Drive* env){
     free(env->observations);
     free(env->actions);
     free(env->rewards);
     free(env->terminals);
+
+    if (env->use_rc) {
+        free(env->collision_weights);
+        free(env->offroad_weights);
+        free(env->goal_weights);
+    }
+    if (env->use_ec) {
+        free(env->entropy_weights);
+    }
+    if (env->use_dc) {
+        free(env->discount_weights);
+    }
+
     c_close(env);
 }

@@ -1704,10 +1780,6 @@ void compute_observations(Drive* env) {
         float* obs = &observations[i][0];
         Entity* ego_entity = &env->entities[env->active_agent_indices[i]];
         if(ego_entity->type > 3) break;
-        if(ego_entity->respawn_timestep != -1) {
-            obs[6] = 1;
-            //continue;
-        }
         float cos_heading = ego_entity->heading_x;
         float sin_heading = ego_entity->heading_y;
         float ego_speed = sqrtf(ego_entity->vx*ego_entity->vx + ego_entity->vy*ego_entity->vy);
@@ -1726,9 +1798,26 @@ void compute_observations(Drive* env) {
         obs[3] = ego_entity->width / MAX_VEH_WIDTH;
         obs[4] = ego_entity->length / MAX_VEH_LEN;
         obs[5] = (ego_entity->collision_state > 0) ? 1.0f : 0.0f;
+        if(ego_entity->respawn_timestep != -1) {
+            obs[6] = 1;
+            //continue;
+        }
+
+        // Add conditioning weights to observations
+        int obs_idx = 7;
+        if (env->use_rc) {
+            obs[obs_idx++] = env->collision_weights[i];
+            obs[obs_idx++] = env->offroad_weights[i];
+            obs[obs_idx++] = env->goal_weights[i];
+        }
+        if (env->use_ec) {
+            obs[obs_idx++] = env->entropy_weights[i];
+        }
+        if (env->use_dc) {
+            obs[obs_idx++] = env->discount_weights[i];
+        }

         // Relative Pos of other cars
-        int obs_idx = 7; // Start after goal distances
         int cars_seen = 0;
         for(int j = 0; j < MAX_AGENTS; j++) {
             int index = -1;
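Since the ego block now ends with a variable number of conditioning scalars, the observation width depends on which flags are set, and the relative-position features simply start at whatever obs_idx the conditioning block left off. A hypothetical sizing helper (the actual observation-space arithmetic lives in the bindings, not shown in this diff) makes the offsets explicit:

    // Hypothetical: indices 0-6 are the fixed ego features written above;
    // reward conditioning appends 3 scalars, entropy and discount 1 each.
    static int ego_obs_size(bool use_rc, bool use_ec, bool use_dc) {
        int n = 7;
        if (use_rc) n += 3;   // collision/offroad/goal weights
        if (use_ec) n += 1;   // entropy weight
        if (use_dc) n += 1;   // discount weight
        return n;
    }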
@@ -1969,6 +2058,28 @@ void compute_new_goal(Drive* env, int agent_idx) {
 void c_reset(Drive* env){
     env->timestep = env->init_steps;
     set_start_position(env);
+
+    // Initialize conditioning weights
+    if (env->use_rc) {
+        for(int i = 0; i < env->active_agent_count; i++) {
+            env->collision_weights[i] = ((float)rand() / RAND_MAX) * (env->collision_weight_ub - env->collision_weight_lb) + env->collision_weight_lb;
+            env->offroad_weights[i] = ((float)rand() / RAND_MAX) * (env->offroad_weight_ub - env->offroad_weight_lb) + env->offroad_weight_lb;
+            env->goal_weights[i] = ((float)rand() / RAND_MAX) * (env->goal_weight_ub - env->goal_weight_lb) + env->goal_weight_lb;
+        }
+    }
+
+    if (env->use_ec) {
+        for(int i = 0; i < env->active_agent_count; i++) {
+            env->entropy_weights[i] = ((float)rand() / RAND_MAX) * (env->entropy_weight_ub - env->entropy_weight_lb) + env->entropy_weight_lb;
+        }
+    }
+
+    if (env->use_dc) {
+        for(int i = 0; i < env->active_agent_count; i++) {
+            env->discount_weights[i] = ((float)rand() / RAND_MAX) * (env->discount_weight_ub - env->discount_weight_lb) + env->discount_weight_lb;
+        }
+    }
+
     for(int x = 0; x < env->active_agent_count; x++){
         env->logs[x] = (Log){0};
         int agent_idx = env->active_agent_indices[x];
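All three loops draw each weight uniformly from its [lb, ub] interval with the same rand()-based expression; a one-line helper (hypothetical, not part of this commit) states the intent once:

    #include <stdlib.h>

    // Uniform draw from [lb, ub], matching the rand()-based scheme in c_reset().
    static inline float rand_uniform(float lb, float ub) {
        return ((float)rand() / (float)RAND_MAX) * (ub - lb) + lb;
    }

    // e.g. env->collision_weights[i] =
    //     rand_uniform(env->collision_weight_lb, env->collision_weight_ub);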
