Skip to content

Commit 13685ca

Browse files
authored
Log success rate for PPO variants (#235)
1 parent 667a789 commit 13685ca

File tree

4 files changed: 14 additions and 59 deletions.

sb3_contrib/ppo_mask/ppo_mask.py

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
import sys
2-
import time
31
from typing import Any, ClassVar, Dict, Optional, Tuple, Type, TypeVar, Union
42

53
import numpy as np
@@ -10,7 +8,7 @@
108
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
119
from stable_baselines3.common.policies import BasePolicy
1210
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
13-
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
11+
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
1412
from stable_baselines3.common.vec_env import VecEnv
1513
from torch.nn import functional as F
1614

@@ -241,7 +239,7 @@ def collect_rollouts(
241239
if not callback.on_step():
242240
return False
243241

244-
self._update_info_buffer(infos)
242+
self._update_info_buffer(infos, dones)
245243
n_steps += 1
246244

247245
if isinstance(self.action_space, spaces.Discrete):
@@ -463,17 +461,7 @@ def learn( # type: ignore[override]
463461

464462
# Display training infos
465463
if log_interval is not None and iteration % log_interval == 0:
466-
assert self.ep_info_buffer is not None
467-
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
468-
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
469-
self.logger.record("time/iterations", iteration, exclude="tensorboard")
470-
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
471-
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
472-
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
473-
self.logger.record("time/fps", fps)
474-
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
475-
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
476-
self.logger.dump(step=self.num_timesteps)
464+
self._dump_logs(iteration)
477465

478466
self.train()
479467

sb3_contrib/ppo_recurrent/ppo_recurrent.py

Lines changed: 9 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
import sys
2-
import time
31
from copy import deepcopy
42
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
53

@@ -11,7 +9,7 @@
119
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
1210
from stable_baselines3.common.policies import BasePolicy
1311
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
14-
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor, safe_mean
12+
from stable_baselines3.common.utils import explained_variance, get_schedule_fn, obs_as_tensor
1513
from stable_baselines3.common.vec_env import VecEnv
1614

1715
from sb3_contrib.common.recurrent.buffers import RecurrentDictRolloutBuffer, RecurrentRolloutBuffer
@@ -260,7 +258,7 @@ def collect_rollouts(
260258
if not callback.on_step():
261259
return False
262260

263-
self._update_info_buffer(infos)
261+
self._update_info_buffer(infos, dones)
264262
n_steps += 1
265263

266264
if isinstance(self.action_space, spaces.Discrete):
@@ -453,42 +451,11 @@ def learn(
453451
reset_num_timesteps: bool = True,
454452
progress_bar: bool = False,
455453
) -> SelfRecurrentPPO:
456-
iteration = 0
457-
458-
total_timesteps, callback = self._setup_learn(
459-
total_timesteps,
460-
callback,
461-
reset_num_timesteps,
462-
tb_log_name,
463-
progress_bar,
454+
return super().learn(
455+
total_timesteps=total_timesteps,
456+
callback=callback,
457+
log_interval=log_interval,
458+
tb_log_name=tb_log_name,
459+
reset_num_timesteps=reset_num_timesteps,
460+
progress_bar=progress_bar,
464461
)
465-
466-
callback.on_training_start(locals(), globals())
467-
468-
while self.num_timesteps < total_timesteps:
469-
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
470-
471-
if not continue_training:
472-
break
473-
474-
iteration += 1
475-
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
476-
477-
# Display training infos
478-
if log_interval is not None and iteration % log_interval == 0:
479-
time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)
480-
fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)
481-
self.logger.record("time/iterations", iteration, exclude="tensorboard")
482-
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
483-
self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))
484-
self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))
485-
self.logger.record("time/fps", fps)
486-
self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")
487-
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
488-
self.logger.dump(step=self.num_timesteps)
489-
490-
self.train()
491-
492-
callback.on_training_end()
493-
494-
return self

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
2.3.0a4
1+
2.3.0a5

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,7 @@
6565
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
6666
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
6767
install_requires=[
68-
"stable_baselines3>=2.3.0a4,<3.0",
68+
"stable_baselines3>=2.3.0a5,<3.0",
6969
],
7070
description="Contrib package of Stable Baselines3, experimental code.",
7171
author="Antonin Raffin",

0 commit comments

Comments
 (0)