
Commit dfd7d49

added reward/entropy/discount conditioning

1 parent c4d58dd

File tree: 11 files changed, +407 −41 lines changed


pufferlib/config/ocean/drive.ini

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ init_steps = 0 # Determines which step of the trajectory to initialize the agent
 control_all_agents = False # this should be set to false unless you specifically want to override and control expert marked vehicles
 num_policy_controlled_agents = -1 # note: if you add this you likely need to set num_agents to a smaller number
 deterministic_agent_selection = False # if this is true it overrides vehicles marked as expert to be policy controlled
+condition_type = "none" # Options: "none", "reward", "entropy", "discount", "all"
 
 [train]
 total_timesteps = 2_000_000_000
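
The string is presumably translated by the Python config layer into the three boolean flags (use_rc, use_ec, use_dc) that binding.c unpacks below; that translation is not part of this commit. A minimal C sketch of the assumed mapping, with parse_condition_type as a hypothetical helper:

#include <stdbool.h>
#include <string.h>

// Hypothetical: maps the ini-level condition_type string onto the three
// boolean flags unpacked in binding.c. "all" enables every conditioning type.
typedef struct {
    bool use_rc; // reward conditioning
    bool use_ec; // entropy conditioning
    bool use_dc; // discount conditioning
} Conditioning;

Conditioning parse_condition_type(const char* s) {
    Conditioning c = {false, false, false};
    if (strcmp(s, "all") == 0) {
        c.use_rc = c.use_ec = c.use_dc = true;
    } else if (strcmp(s, "reward") == 0) {
        c.use_rc = true;
    } else if (strcmp(s, "entropy") == 0) {
        c.use_ec = true;
    } else if (strcmp(s, "discount") == 0) {
        c.use_dc = true;
    } // "none" leaves all flags false
    return c;
}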

pufferlib/extensions/cuda/pufferlib.cu

Lines changed: 12 additions & 6 deletions
@@ -20,8 +20,8 @@ __host__ __device__ void puff_advantage_row_cuda(float* values, float* rewards,
 }
 
 void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
-        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        int num_steps, int horizon) {
+        torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
+        torch::Tensor gammas, int num_steps, int horizon) {
 
     // Validate input tensors
     torch::Device device = values.device();
@@ -35,24 +35,30 @@ void vtrace_check_cuda(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    // Validate gammas tensor
+    TORCH_CHECK(gammas.dim() == 1, "Gammas must be 1D");
+    TORCH_CHECK(gammas.size(0) == num_steps, "Gammas size must match num_steps");
+    TORCH_CHECK(gammas.dtype() == torch::kFloat32, "Gammas must be float32");
+    TORCH_CHECK(gammas.is_cuda(), "Gammas must be on GPU");
+    TORCH_CHECK(gammas.is_contiguous(), "Gammas must be contiguous");
 }
 
 // [num_steps, horizon]
 __global__ void puff_advantage_kernel(float* values, float* rewards,
-        float* dones, float* importance, float* advantages, float gamma,
+        float* dones, float* importance, float* advantages, float* gammas,
         float lambda, float rho_clip, float c_clip, int num_steps, int horizon) {
     int row = blockIdx.x*blockDim.x + threadIdx.x;
     if (row >= num_steps) {
         return;
     }
     int offset = row*horizon;
     puff_advantage_row_cuda(values + offset, rewards + offset, dones + offset,
-        importance + offset, advantages + offset, gamma, lambda, rho_clip, c_clip, horizon);
+        importance + offset, advantages + offset, gammas[row], lambda, rho_clip, c_clip, horizon);
 }
 
 void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        double gamma, double lambda, double rho_clip, double c_clip) {
+        torch::Tensor gammas, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
     vtrace_check_cuda(values, rewards, dones, importance, advantages, num_steps, horizon);
@@ -67,7 +73,7 @@ void compute_puff_advantage_cuda(torch::Tensor values, torch::Tensor rewards,
         dones.data_ptr<float>(),
         importance.data_ptr<float>(),
         advantages.data_ptr<float>(),
-        gamma,
+        gammas.data_ptr<float>(),
         lambda,
         rho_clip,
         c_clip,
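
The body of puff_advantage_row_cuda is unchanged by this commit and not shown above; only the discount it receives now varies per row. For orientation, here is a minimal GAE-style sketch of a per-row backward recursion with a single discount, in plain C. It is an illustration only: the real routine also applies vtrace importance clipping via rho_clip and c_clip, which this sketch omits.

// Illustrative only: a GAE-style backward pass over one row (trajectory
// segment) of length `horizon`. The per-row discount `gamma` is the value
// the updated kernel now reads from gammas[row]. Vtrace clipping
// (rho_clip/c_clip) from the real puff_advantage_row_cuda is omitted.
void advantage_row_sketch(const float* values, const float* rewards,
        const float* dones, float* advantages,
        float gamma, float lambda, int horizon) {
    float adv = 0.0f;
    for (int t = horizon - 2; t >= 0; t--) {
        float mask = 1.0f - dones[t];  // stop bootstrapping across episode ends
        float delta = rewards[t] + gamma*mask*values[t + 1] - values[t];
        adv = delta + gamma*lambda*mask*adv;
        advantages[t] = adv;
    }
}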

pufferlib/extensions/pufferlib.cpp

Lines changed: 14 additions & 8 deletions
@@ -42,7 +42,7 @@ void puff_advantage_row(float* values, float* rewards, float* dones,
 
 void vtrace_check(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        int num_steps, int horizon) {
+        torch::Tensor gammas, int num_steps, int horizon) {
 
     // Validate input tensors
     torch::Device device = values.device();
@@ -56,36 +56,42 @@ void vtrace_check(torch::Tensor values, torch::Tensor rewards,
             t.contiguous();
         }
     }
+    // Validate gammas tensor
+    TORCH_CHECK(gammas.dim() == 1, "Gammas must be 1D");
+    TORCH_CHECK(gammas.size(0) == num_steps, "Gammas size must match num_steps");
+    TORCH_CHECK(gammas.dtype() == torch::kFloat32, "Gammas must be float32");
+    TORCH_CHECK(gammas.is_contiguous(), "Gammas must be contiguous");
 }
 
 
 // [num_steps, horizon]
 void puff_advantage(float* values, float* rewards, float* dones, float* importance,
-        float* advantages, float gamma, float lambda, float rho_clip, float c_clip,
+        float* advantages, float* gammas, float lambda, float rho_clip, float c_clip,
         int num_steps, const int horizon){
-    for (int offset = 0; offset < num_steps*horizon; offset+=horizon) {
+    for (int row = 0; row < num_steps; row++) {
+        int offset = row * horizon;
         puff_advantage_row(values + offset, rewards + offset,
             dones + offset, importance + offset, advantages + offset,
-            gamma, lambda, rho_clip, c_clip, horizon
+            gammas[row], lambda, rho_clip, c_clip, horizon
         );
     }
 }
 
 
 void compute_puff_advantage_cpu(torch::Tensor values, torch::Tensor rewards,
         torch::Tensor dones, torch::Tensor importance, torch::Tensor advantages,
-        double gamma, double lambda, double rho_clip, double c_clip) {
+        torch::Tensor gammas, double lambda, double rho_clip, double c_clip) {
     int num_steps = values.size(0);
     int horizon = values.size(1);
-    vtrace_check(values, rewards, dones, importance, advantages, num_steps, horizon);
+    vtrace_check(values, rewards, dones, importance, advantages, gammas, num_steps, horizon);
     puff_advantage(values.data_ptr<float>(), rewards.data_ptr<float>(),
         dones.data_ptr<float>(), importance.data_ptr<float>(), advantages.data_ptr<float>(),
-        gamma, lambda, rho_clip, c_clip, num_steps, horizon
+        gammas.data_ptr<float>(), lambda, rho_clip, c_clip, num_steps, horizon
     );
 }
 
 TORCH_LIBRARY(pufferlib, m) {
-    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, float gamma, float lambda, float rho_clip, float c_clip) -> ()");
+    m.def("compute_puff_advantage(Tensor(a!) values, Tensor(b!) rewards, Tensor(c!) dones, Tensor(d!) importance, Tensor(e!) advantages, Tensor gammas, float lambda, float rho_clip, float c_clip) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(pufferlib, CPU, m) {
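
With the schema change, callers must supply a 1-D float32 gammas tensor of length num_steps in place of the old scalar gamma, so each row of the [num_steps, horizon] batch can carry its own discount; that per-row discount is what makes discount conditioning possible. In practice the registered op is reached from Python as torch.ops.pufferlib.compute_puff_advantage; the hypothetical driver below instead exercises the raw CPU kernel directly, assuming C linkage against this file purely for illustration:

#include <stdlib.h>

// Declared in pufferlib.cpp (updated signature with per-row gammas);
// assumed visible here with C linkage for the sake of the sketch.
void puff_advantage(float* values, float* rewards, float* dones,
        float* importance, float* advantages, float* gammas,
        float lambda, float rho_clip, float c_clip,
        int num_steps, const int horizon);

int main(void) {
    enum { NUM_STEPS = 4, HORIZON = 32 };
    float* values     = calloc(NUM_STEPS*HORIZON, sizeof(float));
    float* rewards    = calloc(NUM_STEPS*HORIZON, sizeof(float));
    float* dones      = calloc(NUM_STEPS*HORIZON, sizeof(float));
    float* importance = calloc(NUM_STEPS*HORIZON, sizeof(float));
    float* advantages = calloc(NUM_STEPS*HORIZON, sizeof(float));

    // One discount per row: e.g. sampled per agent when discount
    // conditioning is enabled, a constant like 0.99 otherwise.
    float gammas[NUM_STEPS] = {0.95f, 0.99f, 0.99f, 0.997f};

    puff_advantage(values, rewards, dones, importance, advantages,
        gammas, /*lambda=*/0.95f, /*rho_clip=*/1.0f, /*c_clip=*/1.0f,
        NUM_STEPS, HORIZON);

    free(values); free(rewards); free(dones);
    free(importance); free(advantages);
    return 0;
}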

pufferlib/ocean/drive/binding.c

Lines changed: 16 additions & 0 deletions
@@ -174,6 +174,22 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
     env->control_all_agents = unpack(kwargs, "control_all_agents");
     env->deterministic_agent_selection = unpack(kwargs, "deterministic_agent_selection");
     env->control_non_vehicles = (int)unpack(kwargs, "control_non_vehicles");
+
+    // Conditioning parameters
+    env->use_rc = (bool)unpack(kwargs, "use_rc");
+    env->use_ec = (bool)unpack(kwargs, "use_ec");
+    env->use_dc = (bool)unpack(kwargs, "use_dc");
+    env->collision_weight_lb = (float)unpack(kwargs, "collision_weight_lb");
+    env->collision_weight_ub = (float)unpack(kwargs, "collision_weight_ub");
+    env->offroad_weight_lb = (float)unpack(kwargs, "offroad_weight_lb");
+    env->offroad_weight_ub = (float)unpack(kwargs, "offroad_weight_ub");
+    env->goal_weight_lb = (float)unpack(kwargs, "goal_weight_lb");
+    env->goal_weight_ub = (float)unpack(kwargs, "goal_weight_ub");
+    env->entropy_weight_lb = (float)unpack(kwargs, "entropy_weight_lb");
+    env->entropy_weight_ub = (float)unpack(kwargs, "entropy_weight_ub");
+    env->discount_weight_lb = (float)unpack(kwargs, "discount_weight_lb");
+    env->discount_weight_ub = (float)unpack(kwargs, "discount_weight_ub");
+
     int map_id = unpack(kwargs, "map_id");
     int max_agents = unpack(kwargs, "max_agents");
     int init_steps = unpack(kwargs, "init_steps");
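
The binding only threads the lower/upper bounds through to the env struct; the sampling site is not part of this diff. A plausible sketch, assuming each conditioning weight is drawn uniformly from its [lb, ub] interval at episode reset (rand_uniform and the agent-side fields are hypothetical):

#include <stdlib.h>

// Hypothetical helper: uniform draw from [lb, ub].
static float rand_uniform(float lb, float ub) {
    return lb + (ub - lb)*((float)rand() / (float)RAND_MAX);
}

// Sketch of per-agent resampling at episode reset. The agent-side field
// names are assumptions; only the env-side bounds appear in this commit.
// agent->collision_weight = rand_uniform(env->collision_weight_lb, env->collision_weight_ub);
// agent->offroad_weight   = rand_uniform(env->offroad_weight_lb,   env->offroad_weight_ub);
// agent->goal_weight      = rand_uniform(env->goal_weight_lb,      env->goal_weight_ub);
// agent->entropy_weight   = rand_uniform(env->entropy_weight_lb,   env->entropy_weight_ub);
// agent->discount         = rand_uniform(env->discount_weight_lb,  env->discount_weight_ub);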

pufferlib/ocean/drive/drive.c

Lines changed: 20 additions & 10 deletions
@@ -11,6 +11,7 @@
 typedef struct DriveNet DriveNet;
 struct DriveNet {
     int num_agents;
+    int conditioning_dims;
     float* obs_self;
     float* obs_partner;
     float* obs_road;
@@ -42,13 +43,20 @@ struct DriveNet {
     Multidiscrete* multidiscrete;
 };
 
-DriveNet* init_drivenet(Weights* weights, int num_agents) {
+DriveNet* init_drivenet(Weights* weights, int num_agents, bool use_rc, bool use_ec, bool use_dc) {
     DriveNet* net = calloc(1, sizeof(DriveNet));
     int hidden_size = 256;
     int input_size = 64;
 
     net->num_agents = num_agents;
-    net->obs_self = calloc(num_agents*7, sizeof(float)); // 7 features
+    net->conditioning_dims = (use_rc ? 3 : 0) + (use_ec ? 1 : 0) + (use_dc ? 1 : 0);
+
+    int ego_obs_size = 7; // base features
+    if (use_rc) ego_obs_size += 3; // reward conditioning
+    if (use_ec) ego_obs_size += 1; // entropy conditioning
+    if (use_dc) ego_obs_size += 1; // discount conditioning
+
+    net->obs_self = calloc(num_agents*ego_obs_size, sizeof(float));
     net->obs_partner = calloc(num_agents*63*7, sizeof(float)); // 63 objects, 7 features
     net->obs_road = calloc(num_agents*200*13, sizeof(float)); // 200 objects, 13 features
     net->partner_linear_output = calloc(num_agents*63*input_size, sizeof(float));
@@ -57,7 +65,7 @@ DriveNet* init_drivenet(Weights* weights, int num_agents) {
     net->road_linear_output_two = calloc(num_agents*200*input_size, sizeof(float));
     net->partner_layernorm_output = calloc(num_agents*63*input_size, sizeof(float));
     net->road_layernorm_output = calloc(num_agents*200*input_size, sizeof(float));
-    net->ego_encoder = make_linear(weights, num_agents, 7, input_size);
+    net->ego_encoder = make_linear(weights, num_agents, ego_obs_size, input_size);
     net->ego_layernorm = make_layernorm(weights, num_agents, input_size);
     net->ego_encoder_two = make_linear(weights, num_agents, input_size, input_size);
     net->road_encoder = make_linear(weights, num_agents, 13, input_size);
@@ -117,23 +125,25 @@ void free_drivenet(DriveNet* net) {
 }
 
 void forward(DriveNet* net, float* observations, int* actions) {
+    int ego_obs_size = 7 + net->conditioning_dims;
+
     // Clear previous observations
-    memset(net->obs_self, 0, net->num_agents * 7 * sizeof(float));
+    memset(net->obs_self, 0, net->num_agents * ego_obs_size * sizeof(float));
     memset(net->obs_partner, 0, net->num_agents * 63 * 7 * sizeof(float));
     memset(net->obs_road, 0, net->num_agents * 200 * 13 * sizeof(float));
 
     // Reshape observations into 2D boards and additional features
-    float (*obs_self)[7] = (float (*)[7])net->obs_self;
+    float* obs_self = net->obs_self;
     float (*obs_partner)[63][7] = (float (*)[63][7])net->obs_partner;
     float (*obs_road)[200][13] = (float (*)[200][13])net->obs_road;
 
     for (int b = 0; b < net->num_agents; b++) {
-        int b_offset = b * (7 + 63*7 + 200*7); // offset for each batch
-        int partner_offset = b_offset + 7;
-        int road_offset = b_offset + 7 + 63*7;
+        int b_offset = b * (ego_obs_size + 63*7 + 200*7); // offset for each batch
+        int partner_offset = b_offset + ego_obs_size;
+        int road_offset = b_offset + ego_obs_size + 63*7;
         // Process self observation
-        for(int i = 0; i < 7; i++) {
-            obs_self[b][i] = observations[b_offset + i];
+        for(int i = 0; i < ego_obs_size; i++) {
+            obs_self[b*ego_obs_size + i] = observations[b_offset + i];
        }
 
         // Process partner observation
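
A worked check of the new layout arithmetic, assuming reward and discount conditioning are enabled (use_rc and use_dc true, use_ec false):

#include <assert.h>
#include <stdbool.h>

int main(void) {
    bool use_rc = true, use_ec = false, use_dc = true;
    int ego_obs_size = 7 + (use_rc ? 3 : 0) + (use_ec ? 1 : 0) + (use_dc ? 1 : 0);
    assert(ego_obs_size == 11);

    // Per-agent stride matches b_offset in forward(): ego + partner + road.
    int stride = ego_obs_size + 63*7 + 200*7;  // 11 + 441 + 1400
    assert(stride == 1852);

    // Agent b's slices of the flat observation buffer:
    int b = 2;
    int b_offset       = b * stride;               // ego features start here
    int partner_offset = b_offset + ego_obs_size;  // 63 partners x 7 features
    int road_offset    = partner_offset + 63*7;    // road features follow
    assert(road_offset - b_offset == 452);
    return 0;
}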
