Skip to content

Commit 4af2ea8

Browse files
committed
refactor around trained agent
1 parent ae0e5a1 commit 4af2ea8

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

examples/acrobot-qtable/src/bin/simulate.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use oxide_control::TimeStep;
33
use oxide_control::physics::binding::{
44
mjr_makeContext, mjr_render, mjrContext, mjrRect, mjtCatBit, mjtFontScale, mjv_makeScene, mjv_updateScene, mjvCamera, mjvOption, mjvScene,
55
};
6-
use qtable::strategy;
76

87
fn main() {
98
let mut args = std::env::args().skip(1);
@@ -43,7 +42,7 @@ fn main() {
4342
let mut obs = env.reset();
4443
while !window.should_close() {
4544
while env.physics().time() < glfw.get_time() {
46-
match env.step(t.get_action::<strategy::MostQValue>(env.task().state(&obs))) {
45+
match env.step(t.get_action(env.task().state(&obs))) {
4746
TimeStep::Step { observation, .. } => {
4847
obs = observation;
4948
}

examples/acrobot-qtable/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ impl TrainedAgent {
295295
self.0.n_pendulum_digitization
296296
}
297297

298-
pub fn get_action<S: qtable::Strategy>(&self, state: AcrobotState) -> AcrobotAction {
299-
self.0.get_action::<S>(state)
298+
pub fn get_action(&self, state: AcrobotState) -> AcrobotAction {
299+
self.0.get_action::<qtable::strategy::MostQValue>(state)
300300
}
301301
}

0 commit comments

Comments
 (0)