From 324266f8307bdac648bdd9968452bbe6f61297a9 Mon Sep 17 00:00:00 2001 From: Ryan J <103715677+RJ-246@users.noreply.github.com> Date: Sun, 22 Oct 2023 18:18:02 -0600 Subject: [PATCH 1/2] Update run_baseline_parallel_fast.py Added a function that allows users to select a checkpoint for the training session. --- baselines/run_baseline_parallel_fast.py | 38 +++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/baselines/run_baseline_parallel_fast.py b/baselines/run_baseline_parallel_fast.py index cd19dbdf0..9bf9a934d 100644 --- a/baselines/run_baseline_parallel_fast.py +++ b/baselines/run_baseline_parallel_fast.py @@ -1,4 +1,5 @@ from os.path import exists +import os from pathlib import Path import uuid from red_gym_env import RedGymEnv @@ -23,8 +24,42 @@ def _init(): set_random_seed(seed) return _init +#Allows user to select checkpoint to use +def choose_checkpoint(): + session_list = [] + base_dir = Path(__file__).resolve().parent #__name__? + for x in os.listdir(base_dir): + if 'session' in x: + session_list.append({"session": x, "modified": os.path.getmtime(f"{base_dir}/{x}")}) + ordered_session_list = sorted(session_list, key=lambda x: x["modified"], reverse=True) + ordered_session_list.append({"session": "None", "modified": "0"}) + + print("Sessions (1 being the most recent):") + for index, each in enumerate(ordered_session_list): + print(f'{index + 1}: {each["session"]}') + + session_selection = int(input("Pick a session number to use (default is none): ") or 0) + + session = ordered_session_list[session_selection - 1]["session"] + + step_list = [] + for x in os.listdir(f"{base_dir}/{session}"): + if "_steps.zip" in x: + step_list.append(x) + step_list = sorted(step_list, key=lambda x:x[5:x.index("_steps")], reverse=True) + if step_list: + step = step_list[0][0:-4] + path_name = f"{session}/{step}" + else: + print("\n\n\nNo checkpoint found, starting training from scratch.\n\n\n") + path_name = "" + return(path_name) + + if __name__ == '__main__': + # put a checkpoint here you want to start from + file_name = choose_checkpoint() ep_length = 2048 * 10 sess_path = Path(f'session_{str(uuid.uuid4())[:8]}') @@ -47,8 +82,7 @@ def _init(): name_prefix='poke') #env_checker.check_env(env) learn_steps = 40 - # put a checkpoint here you want to start from - file_name = 'session_e41c9eff/poke_38207488_steps' + if exists(file_name + '.zip'): print('\nloading checkpoint') From 95b837f9dee87564a1dcd02f64d9458fdc0dcf2a Mon Sep 17 00:00:00 2001 From: Ryan J <103715677+RJ-246@users.noreply.github.com> Date: Sun, 22 Oct 2023 18:21:10 -0600 Subject: [PATCH 2/2] Update run_baseline_parallel.py Added option for user to select checkpoint --- baselines/run_baseline_parallel.py | 36 ++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/baselines/run_baseline_parallel.py b/baselines/run_baseline_parallel.py index f4423a3a5..705dc747e 100644 --- a/baselines/run_baseline_parallel.py +++ b/baselines/run_baseline_parallel.py @@ -1,4 +1,5 @@ from os.path import exists +import os from pathlib import Path import uuid from red_gym_env import RedGymEnv @@ -23,9 +24,40 @@ def _init(): set_random_seed(seed) return _init -if __name__ == '__main__': +def choose_checkpoint(): + session_list = [] + base_dir = Path(__file__).resolve().parent #__name__? + for x in os.listdir(base_dir): + if 'session' in x: + session_list.append({"session": x, "modified": os.path.getmtime(f"{base_dir}/{x}")}) + ordered_session_list = sorted(session_list, key=lambda x: x["modified"], reverse=True) + ordered_session_list.append({"session": "None", "modified": "0"}) + + print("Sessions (1 being the most recent):") + for index, each in enumerate(ordered_session_list): + print(f'{index + 1}: {each["session"]}') + + session_selection = int(input("Pick a session number to use (default is none): ") or 0) + + session = ordered_session_list[session_selection - 1]["session"] + + step_list = [] + for x in os.listdir(f"{base_dir}/{session}"): + if "_steps.zip" in x: + step_list.append(x) + step_list = sorted(step_list, key=lambda x:x[5:x.index("_steps")], reverse=True) + if step_list: + step = step_list[0][0:-4] + path_name = f"{session}/{step}" + else: + print("\n\n\nNo checkpoint found, starting training from scratch.\n\n\n") + path_name = "" + return(path_name) + +if __name__ == '__main__': + file_name = choose_checkpoint() ep_length = 2048 * 8 sess_path = Path(f'session_{str(uuid.uuid4())[:8]}') @@ -45,7 +77,7 @@ def _init(): name_prefix='poke') #env_checker.check_env(env) learn_steps = 40 - file_name = 'session_e41c9eff/poke_38207488_steps' #'session_e41c9eff/poke_250871808_steps' + #'session_e41c9eff/poke_250871808_steps' #'session_bfdca25a/poke_42532864_steps' #'session_d3033abb/poke_47579136_steps' #'session_a17cc1f5/poke_33546240_steps' #'session_e4bdca71/poke_8945664_steps' #'session_eb21989e/poke_40255488_steps' #'session_80f70ab4/poke_58982400_steps' if exists(file_name + '.zip'):