@@ -6,23 +6,42 @@ use qtable::strategy;
6
6
/// 元の Python コードベースを尊重して離散的な AcrobotState を用いている
7
7
fn get_reward ( task : & AcrobotBalanceTask , state : & AcrobotState , action : & AcrobotAction ) -> f64 {
8
8
if task. should_finish_episode ( state) {
9
- return -2000.0 ; // Penalty for finishing the episode
9
+ return -2000.0 ;
10
+ }
11
+
12
+ let pend_pos_center = task. n_pendulum_digitization as f64 / 2.0 ;
13
+ let arm_pos_center = task. n_arm_digitization as f64 / 2.0 ;
14
+ let pend_vel_center = task. n_pendulum_digitization as f64 / 2.0 ;
15
+ let arm_vel_center = task. n_arm_digitization as f64 / 2.0 ;
16
+
17
+ if ( state. n_pendulum_rad as f64 - pend_pos_center) . abs ( ) < 1.0 &&
18
+ ( state. n_arm_rad as f64 - arm_pos_center) . abs ( ) < 1.0 &&
19
+ ( state. n_pendulum_vel as f64 - pend_vel_center) . abs ( ) < 1.0 &&
20
+ ( state. n_arm_vel as f64 - arm_vel_center) . abs ( ) < 1.0
21
+ {
22
+ return 500.0 ;
10
23
}
11
24
12
25
let position_reward = {
13
26
const MAX_REWARD : f64 = 10.0 ;
14
- let pendulum_pos_error = ( task . n_pendulum_digitization as f64 / 2.0 - state. n_pendulum_rad as f64 ) . abs ( ) . floor ( ) . powi ( 2 ) ;
15
- let arm_pos_error = ( task . n_arm_digitization as f64 / 2.0 - state. n_arm_rad as f64 ) . abs ( ) . floor ( ) . powi ( 2 ) ;
27
+ let pendulum_pos_error = ( pend_pos_center - state. n_pendulum_rad as f64 ) . powi ( 2 ) ;
28
+ let arm_pos_error = ( arm_pos_center - state. n_arm_rad as f64 ) . powi ( 2 ) ;
16
29
MAX_REWARD - ( 0.2 * pendulum_pos_error + 0.1 * arm_pos_error)
17
30
} ;
18
31
19
32
let velocity_penalty = {
20
- let pendulum_vel_error = ( task . n_pendulum_digitization as f64 / 2.0 - state. n_pendulum_vel as f64 ) . abs ( ) . floor ( ) . powi ( 2 ) ;
21
- let arm_vel_error = ( task . n_pendulum_digitization as f64 / 2.0 - state. n_arm_vel as f64 ) . abs ( ) . floor ( ) . powi ( 2 ) ;
33
+ let pendulum_vel_error = ( pend_vel_center - state. n_pendulum_vel as f64 ) . powi ( 2 ) ;
34
+ let arm_vel_error = ( arm_vel_center - state. n_arm_vel as f64 ) . powi ( 2 ) ;
22
35
0.001 * pendulum_vel_error + 0.002 * arm_vel_error
23
36
} ;
24
37
25
- let action_cost = { 0.02 * ( task. action_size as f64 / 2.0 - action. digitization_index as f64 ) . abs ( ) . floor ( ) } ;
38
+ // この場合 action は「何もしない」真ん中が設定されていることを想定している。
39
+ // よって task.action_size は奇数を前提として、真ん中の index のときに
40
+ // action_cost が厳密に 0 になるようにしている
41
+ let action_cost = {
42
+ let center_action_index = ( task. action_size - 1 ) as f64 / 2.0 ;
43
+ 0.01 * ( center_action_index - action. digitization_index as f64 ) . abs ( )
44
+ } ;
26
45
27
46
position_reward - velocity_penalty - action_cost
28
47
}
0 commit comments