Skip to content

Commit e2a0e10

Browse files
authored
Add basic support for repeating a trial config. (#642)
Adds very rudimentary support for repeating configs across multiple trials. We fully expect to expand on more advanced support for this once we add a proper scheduler (#463). Additional tests forth coming with #633 and related PRs.
1 parent b531400 commit e2a0e10

File tree

5 files changed

+39
-14
lines changed

5 files changed

+39
-14
lines changed

mlos_bench/mlos_bench/config/schemas/cli/cli-schema.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@
6767
"$ref": "#/$defs/json_config_path"
6868
},
6969

70+
"trial_config_repeat_count": {
71+
"description": "Number of times to repeat a config.",
72+
"type": "integer",
73+
"minimum": 1,
74+
"examples": [3, 5]
75+
},
76+
7077
"storage": {
7178
"description": "Path to the json config describing the storage backend to use.",
7279
"$ref": "#/$defs/json_config_path"

mlos_bench/mlos_bench/launcher.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st
7575
else:
7676
config = {}
7777

78+
self.trial_config_repeat_count = args.trial_config_repeat_count or config.get("trial_config_repeat_count", 1)
79+
7880
log_level = args.log_level or config.get("log_level", _LOG_LEVEL)
7981
try:
8082
log_level = int(log_level)
@@ -195,6 +197,10 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T
195197
help='Path to the optimizer configuration file. If omitted, run' +
196198
' a single trial with default (or specified in --tunable_values).')
197199

200+
parser.add_argument(
201+
'--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, default=1,
202+
help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.')
203+
198204
parser.add_argument(
199205
'--storage', required=False,
200206
help='Path to the storage configuration file.' +

mlos_bench/mlos_bench/run.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def _main() -> None:
3636
storage=launcher.storage,
3737
root_env_config=launcher.root_env_config,
3838
global_config=launcher.global_config,
39-
do_teardown=launcher.teardown
39+
do_teardown=launcher.teardown,
40+
trial_config_repeat_count=launcher.trial_config_repeat_count,
4041
)
4142

4243
_LOG.info("Final result: %s", result)
@@ -48,7 +49,9 @@ def _optimize(*,
4849
storage: Storage,
4950
root_env_config: str,
5051
global_config: Dict[str, Any],
51-
do_teardown: bool) -> Tuple[Optional[float], Optional[TunableGroups]]:
52+
do_teardown: bool,
53+
trial_config_repeat_count: int = 1,
54+
) -> Tuple[Optional[float], Optional[TunableGroups]]:
5255
"""
5356
Main optimization loop.
5457
@@ -66,8 +69,13 @@ def _optimize(*,
6669
Global configuration parameters.
6770
do_teardown : bool
6871
If True, teardown the environment at the end of the experiment
72+
trial_config_repeat_count : int
73+
How many trials to repeat for the same configuration.
6974
"""
7075
# pylint: disable=too-many-locals
76+
if trial_config_repeat_count <= 0:
77+
raise ValueError(f"Invalid trial_config_repeat_count: {trial_config_repeat_count}")
78+
7179
if _LOG.isEnabledFor(logging.INFO):
7280
_LOG.info("Root Environment:\n%s", env.pprint())
7381

@@ -118,16 +126,18 @@ def _optimize(*,
118126
config_id, json.dumps(tunable_values, indent=2))
119127
config_id = -1
120128

121-
trial = exp.new_trial(tunables, config={
122-
# Add some additional metadata to track for the trial such as the
123-
# optimizer config used.
124-
# TODO: Improve for supporting multi-objective
125-
# (e.g., opt_target_1, opt_target_2, ... and opt_direction_1, opt_direction_2, ...)
126-
"optimizer": opt.name,
127-
"opt_target": opt.target,
128-
"opt_direction": opt.direction,
129-
})
130-
_run(env_context, opt_context, trial, global_config)
129+
for repeat_i in range(1, trial_config_repeat_count + 1):
130+
trial = exp.new_trial(tunables, config={
131+
# Add some additional metadata to track for the trial such as the
132+
# optimizer config used.
133+
# TODO: Improve for supporting multi-objective
134+
# (e.g., opt_target_1, opt_target_2, ... and opt_direction_1, opt_direction_2, ...)
135+
"optimizer": opt.name,
136+
"opt_target": opt.target,
137+
"opt_direction": opt.direction,
138+
"repeat_i": repeat_i,
139+
})
140+
_run(env_context, opt_context, trial, global_config)
131141

132142
if do_teardown:
133143
env_context.teardown()

mlos_bench/mlos_bench/tests/config/schemas/cli/test-cases/good/full/full-cli.jsonc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"optimizer": "optimizers/one_shot_opt.jsonc",
1616
"storage": "storage/sqlite.jsonc",
1717

18+
"trial_config_repeat_count": 3,
19+
1820
"random_init": true,
1921
"random_seed": 42,
2022

mlos_bench/mlos_bench/tests/launcher_run_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic
9393
"""
9494
_launch_main_app(
9595
root_path, local_exec_service,
96-
"--config mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc --max_iterations 3",
96+
"--config mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc --trial_config_repeat_count 3 --max_iterations 3",
9797
[
9898
# Iteration 1: Expect first value to be the baseline
9999
f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " +
@@ -106,6 +106,6 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic
106106
r"register DEBUG Score: \d+\.\d+ Dataframe:\s*$",
107107
# Final result: baseline is the optimum for the mock environment
108108
f"^{_RE_DATE} run\\.py:\\d+ " +
109-
r"_optimize INFO Env: Mock environment best score: 64\.88\d+\s*$",
109+
r"_optimize INFO Env: Mock environment best score: 64\.53\d+\s*$",
110110
]
111111
)

0 commit comments

Comments
 (0)