@@ -119,6 +119,12 @@ struct Log {
119119 float active_agent_count ;
120120 float expert_static_car_count ;
121121 float static_car_count ;
122+ // Conditioning metrics
123+ float avg_collision_weight ;
124+ float avg_offroad_weight ;
125+ float avg_goal_weight ;
126+ float avg_entropy_weight ;
127+ float avg_discount_weight ;
122128};
123129
124130typedef struct Entity Entity ;
@@ -287,6 +293,27 @@ struct Drive {
287293 char * ini_file ;
288294 int scenario_length ;
289295 int control_non_vehicles ;
296+ // Reward conditioning
297+ bool use_rc ;
298+ float collision_weight_lb ;
299+ float collision_weight_ub ;
300+ float offroad_weight_lb ;
301+ float offroad_weight_ub ;
302+ float goal_weight_lb ;
303+ float goal_weight_ub ;
304+ float * collision_weights ;
305+ float * offroad_weights ;
306+ float * goal_weights ;
307+ // Entropy conditioning
308+ bool use_ec ;
309+ float entropy_weight_lb ;
310+ float entropy_weight_ub ;
311+ float * entropy_weights ;
312+ // Discount conditioning
313+ bool use_dc ;
314+ float discount_weight_lb ;
315+ float discount_weight_ub ;
316+ float * discount_weights ;
290317};
291318
292319typedef struct {
@@ -1565,6 +1592,18 @@ void init(Drive* env){
15651592 set_start_position (env );
15661593 init_goal_positions (env );
15671594 env -> logs = (Log * )calloc (env -> active_agent_count , sizeof (Log ));
1595+
1596+ if (env -> use_rc ) {
1597+ env -> collision_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1598+ env -> offroad_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1599+ env -> goal_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1600+ }
1601+ if (env -> use_ec ) {
1602+ env -> entropy_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1603+ }
1604+ if (env -> use_dc ) {
1605+ env -> discount_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1606+ }
15681607}
15691608
15701609void c_close (Drive * env ){
@@ -1594,6 +1633,18 @@ void c_close(Drive* env){
15941633 freeTopologyGraph (env -> topology_graph );
15951634 // free(env->map_name);
15961635 free (env -> ini_file );
1636+
1637+ if (env -> use_rc ) {
1638+ free (env -> collision_weights );
1639+ free (env -> offroad_weights );
1640+ free (env -> goal_weights );
1641+ }
1642+ if (env -> use_ec ) {
1643+ free (env -> entropy_weights );
1644+ }
1645+ if (env -> use_dc ) {
1646+ free (env -> discount_weights );
1647+ }
15971648}
15981649
15991650void allocate (Drive * env ){
@@ -1606,13 +1657,38 @@ void allocate(Drive* env){
16061657 env -> actions = (float * )calloc (env -> active_agent_count * 2 , sizeof (float ));
16071658 env -> rewards = (float * )calloc (env -> active_agent_count , sizeof (float ));
16081659 env -> terminals = (unsigned char * )calloc (env -> active_agent_count , sizeof (unsigned char ));
1660+
1661+ if (env -> use_rc ) {
1662+ env -> collision_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1663+ env -> offroad_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1664+ env -> goal_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1665+ }
1666+ if (env -> use_ec ) {
1667+ env -> entropy_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1668+ }
1669+ if (env -> use_dc ) {
1670+ env -> discount_weights = (float * )calloc (env -> active_agent_count , sizeof (float ));
1671+ }
16091672}
16101673
16111674void free_allocated (Drive * env ){
16121675 free (env -> observations );
16131676 free (env -> actions );
16141677 free (env -> rewards );
16151678 free (env -> terminals );
1679+
1680+ if (env -> use_rc ) {
1681+ free (env -> collision_weights );
1682+ free (env -> offroad_weights );
1683+ free (env -> goal_weights );
1684+ }
1685+ if (env -> use_ec ) {
1686+ free (env -> entropy_weights );
1687+ }
1688+ if (env -> use_dc ) {
1689+ free (env -> discount_weights );
1690+ }
1691+
16161692 c_close (env );
16171693}
16181694
@@ -1704,10 +1780,6 @@ void compute_observations(Drive* env) {
17041780 float * obs = & observations [i ][0 ];
17051781 Entity * ego_entity = & env -> entities [env -> active_agent_indices [i ]];
17061782 if (ego_entity -> type > 3 ) break ;
1707- if (ego_entity -> respawn_timestep != -1 ) {
1708- obs [6 ] = 1 ;
1709- //continue;
1710- }
17111783 float cos_heading = ego_entity -> heading_x ;
17121784 float sin_heading = ego_entity -> heading_y ;
17131785 float ego_speed = sqrtf (ego_entity -> vx * ego_entity -> vx + ego_entity -> vy * ego_entity -> vy );
@@ -1726,9 +1798,26 @@ void compute_observations(Drive* env) {
17261798 obs [3 ] = ego_entity -> width / MAX_VEH_WIDTH ;
17271799 obs [4 ] = ego_entity -> length / MAX_VEH_LEN ;
17281800 obs [5 ] = (ego_entity -> collision_state > 0 ) ? 1.0f : 0.0f ;
1801+ if (ego_entity -> respawn_timestep != -1 ) {
1802+ obs [6 ] = 1 ;
1803+ //continue;
1804+ }
1805+
1806+ // Add conditioning weights to observations
1807+ int obs_idx = 7 ;
1808+ if (env -> use_rc ) {
1809+ obs [obs_idx ++ ] = env -> collision_weights [i ];
1810+ obs [obs_idx ++ ] = env -> offroad_weights [i ];
1811+ obs [obs_idx ++ ] = env -> goal_weights [i ];
1812+ }
1813+ if (env -> use_ec ) {
1814+ obs [obs_idx ++ ] = env -> entropy_weights [i ];
1815+ }
1816+ if (env -> use_dc ) {
1817+ obs [obs_idx ++ ] = env -> discount_weights [i ];
1818+ }
17291819
17301820 // Relative Pos of other cars
1731- int obs_idx = 7 ; // Start after goal distances
17321821 int cars_seen = 0 ;
17331822 for (int j = 0 ; j < MAX_AGENTS ; j ++ ) {
17341823 int index = -1 ;
@@ -1969,6 +2058,28 @@ void compute_new_goal(Drive* env, int agent_idx) {
19692058void c_reset (Drive * env ){
19702059 env -> timestep = env -> init_steps ;
19712060 set_start_position (env );
2061+
2062+ // Initialize conditioning weights
2063+ if (env -> use_rc ) {
2064+ for (int i = 0 ; i < env -> active_agent_count ; i ++ ) {
2065+ env -> collision_weights [i ] = ((float )rand () / RAND_MAX ) * (env -> collision_weight_ub - env -> collision_weight_lb ) + env -> collision_weight_lb ;
2066+ env -> offroad_weights [i ] = ((float )rand () / RAND_MAX ) * (env -> offroad_weight_ub - env -> offroad_weight_lb ) + env -> offroad_weight_lb ;
2067+ env -> goal_weights [i ] = ((float )rand () / RAND_MAX ) * (env -> goal_weight_ub - env -> goal_weight_lb ) + env -> goal_weight_lb ;
2068+ }
2069+ }
2070+
2071+ if (env -> use_ec ) {
2072+ for (int i = 0 ; i < env -> active_agent_count ; i ++ ) {
2073+ env -> entropy_weights [i ] = ((float )rand () / RAND_MAX ) * (env -> entropy_weight_ub - env -> entropy_weight_lb ) + env -> entropy_weight_lb ;
2074+ }
2075+ }
2076+
2077+ if (env -> use_dc ) {
2078+ for (int i = 0 ; i < env -> active_agent_count ; i ++ ) {
2079+ env -> discount_weights [i ] = ((float )rand () / RAND_MAX ) * (env -> discount_weight_ub - env -> discount_weight_lb ) + env -> discount_weight_lb ;
2080+ }
2081+ }
2082+
19722083 for (int x = 0 ;x < env -> active_agent_count ; x ++ ){
19732084 env -> logs [x ] = (Log ){0 };
19742085 int agent_idx = env -> active_agent_indices [x ];
0 commit comments