|
1 | | -import sys |
2 | | -import time |
3 | 1 | from copy import deepcopy |
4 | 2 | from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union |
5 | 3 |
|
|
11 | 9 | from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm |
12 | 10 | from stable_baselines3.common.policies import BasePolicy |
13 | 11 | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule |
14 | | -from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean |
| 12 | +from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor |
15 | 13 | from stable_baselines3.common.vec_env import VecEnv |
16 | 14 |
|
17 | 15 | from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer |
@@ -260,7 +258,7 @@ def collect_rollouts( |
260 | 258 | if not callback.on_step(): |
261 | 259 | return False |
262 | 260 |
|
263 | | - self._update_info_buffer(infos) |
| 261 | + self._update_info_buffer(infos, dones) |
264 | 262 | n_steps += 1 |
265 | 263 |
|
266 | 264 | if isinstance(self.action_space, spaces.Discrete): |
@@ -453,42 +451,11 @@ def learn( |
453 | 451 | reset_num_timesteps: bool = True, |
454 | 452 | progress_bar: bool = False, |
455 | 453 | ) -> SelfRecurrentPPO: |
456 | | - iteration = 0 |
457 | | - |
458 | | - total_timesteps, callback = self._setup_learn( |
459 | | - total_timesteps, |
460 | | - callback, |
461 | | - reset_num_timesteps, |
462 | | - tb_log_name, |
463 | | - progress_bar, |
| 454 | + return super().learn( |
| 455 | + total_timesteps=total_timesteps, |
| 456 | + callback=callback, |
| 457 | + log_interval=log_interval, |
| 458 | + tb_log_name=tb_log_name, |
| 459 | + reset_num_timesteps=reset_num_timesteps, |
| 460 | + progress_bar=progress_bar, |
464 | 461 | ) |
465 | | - |
466 | | - callback.on_training_start(locals(), globals()) |
467 | | - |
468 | | - while self.num_timesteps < total_timesteps: |
469 | | - continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) |
470 | | - |
471 | | - if not continue_training: |
472 | | - break |
473 | | - |
474 | | - iteration += 1 |
475 | | - self._update_current_progress_remaining(self.num_timesteps, total_timesteps) |
476 | | - |
477 | | - # Display training infos |
478 | | - if log_interval is not None and iteration % log_interval == 0: |
479 | | - time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon) |
480 | | - fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed) |
481 | | - self.logger.record("time/iterations", iteration, exclude="tensorboard") |
482 | | - if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: |
483 | | - self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) |
484 | | - self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) |
485 | | - self.logger.record("time/fps", fps) |
486 | | - self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") |
487 | | - self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard") |
488 | | - self.logger.dump(step=self.num_timesteps) |
489 | | - |
490 | | - self.train() |
491 | | - |
492 | | - callback.on_training_end() |
493 | | - |
494 | | - return self |
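
Note: this refactor keeps the public training API unchanged; a minimal usage sketch (the environment id below is illustrative and not part of this diff):

    from sb3_contrib import RecurrentPPO

    # learn() now delegates the rollout/train loop and periodic logging
    # to OnPolicyAlgorithm.learn from stable-baselines3.
    model = RecurrentPPO("MlpLstmPolicy", "CartPole-v1", verbose=1)
    model.learn(total_timesteps=10_000, log_interval=1)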