Skip to content

Commit 029bd35

Browse files
committed
update reward function
1 parent 136d8bc commit 029bd35

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

examples/acrobot-qtable/src/bin/train.rs

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,42 @@ use qtable::strategy;
66
/// 元の Python コードベースを尊重して離散的な AcrobotState を用いている
77
fn get_reward(task: &AcrobotBalanceTask, state: &AcrobotState, action: &AcrobotAction) -> f64 {
88
if task.should_finish_episode(state) {
9-
return -2000.0; // Penalty for finishing the episode
9+
return -2000.0;
10+
}
11+
12+
let pend_pos_center = task.n_pendulum_digitization as f64 / 2.0;
13+
let arm_pos_center = task.n_arm_digitization as f64 / 2.0;
14+
let pend_vel_center = task.n_pendulum_digitization as f64 / 2.0;
15+
let arm_vel_center = task.n_arm_digitization as f64 / 2.0;
16+
17+
if (state.n_pendulum_rad as f64 - pend_pos_center).abs() < 1.0 &&
18+
(state.n_arm_rad as f64 - arm_pos_center).abs() < 1.0 &&
19+
(state.n_pendulum_vel as f64 - pend_vel_center).abs() < 1.0 &&
20+
(state.n_arm_vel as f64 - arm_vel_center).abs() < 1.0
21+
{
22+
return 500.0;
1023
}
1124

1225
let position_reward = {
1326
const MAX_REWARD: f64 = 10.0;
14-
let pendulum_pos_error = (task.n_pendulum_digitization as f64 / 2.0 - state.n_pendulum_rad as f64).abs().floor().powi(2);
15-
let arm_pos_error = (task.n_arm_digitization as f64 / 2.0 - state.n_arm_rad as f64).abs().floor().powi(2);
27+
let pendulum_pos_error = (pend_pos_center - state.n_pendulum_rad as f64).powi(2);
28+
let arm_pos_error = (arm_pos_center - state.n_arm_rad as f64).powi(2);
1629
MAX_REWARD - (0.2 * pendulum_pos_error + 0.1 * arm_pos_error)
1730
};
1831

1932
let velocity_penalty = {
20-
let pendulum_vel_error = (task.n_pendulum_digitization as f64 / 2.0 - state.n_pendulum_vel as f64).abs().floor().powi(2);
21-
let arm_vel_error = (task.n_pendulum_digitization as f64 / 2.0 - state.n_arm_vel as f64).abs().floor().powi(2);
33+
let pendulum_vel_error = (pend_vel_center - state.n_pendulum_vel as f64).powi(2);
34+
let arm_vel_error = (arm_vel_center - state.n_arm_vel as f64).powi(2);
2235
0.001 * pendulum_vel_error + 0.002 * arm_vel_error
2336
};
2437

25-
let action_cost = { 0.02 * (task.action_size as f64 / 2.0 - action.digitization_index as f64).abs().floor() };
38+
// この場合 action は「何もしない」真ん中が設定されていることを想定している。
39+
// よって task.action_size は奇数を前提として、真ん中の index のときに
40+
// action_cost が厳密に 0 になるようにしている
41+
let action_cost = {
42+
let center_action_index = (task.action_size - 1) as f64 / 2.0;
43+
0.01 * (center_action_index - action.digitization_index as f64).abs()
44+
};
2645

2746
position_reward - velocity_penalty - action_cost
2847
}

rustfmt.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
max_width = 160

0 commit comments

Comments
 (0)