55"""
66# std imports
77from argparse import ArgumentParser
8+ import contextlib
89import json
910import logging
1011import os
11- import tempfile
1212from typing import Optional
1313
1414# tpl imports
@@ -30,8 +30,11 @@ def get_args():
3030 parser .add_argument ("input_json" , type = str , help = "Input JSON file containing the test cases." )
3131 parser .add_argument ("-o" , "--output" , type = str , help = "Output JSON file containing the results." )
3232 parser .add_argument ("--scratch-dir" , type = str , help = "If provided, put scratch files here." )
33+ parser .add_argument ("--driver-root" , type = str , help = "Where to look for the driver files, if not in cwd." )
3334 parser .add_argument ("--launch-configs" , type = str , default = "launch-configs.json" ,
3435 help = "config for how to run samples." )
36+ parser .add_argument ("--build-configs" , type = str , default = "build-configs.json" ,
37+ help = "config for how to build samples. If not provided, will use the default build settings for each model." )
3538 parser .add_argument ("--problem-sizes" , type = str , default = "problem-sizes.json" ,
3639 help = "config for how to run samples." )
3740 parser .add_argument ("--yes-to-all" , action = "store_true" , help = "If provided, automatically answer yes to all prompts." )
@@ -56,11 +59,19 @@ def get_args():
5659 parser .add_argument ("--log-runs" , action = "store_true" , help = "Display the stderr and stdout of runs." )
5760 return parser .parse_args ()
5861
59- def get_driver (prompt : dict , scratch_dir : Optional [os .PathLike ], launch_configs : dict , problem_sizes : dict , dry : bool , ** kwargs ) -> DriverWrapper :
62+ def get_driver (
63+ prompt : dict ,
64+ scratch_dir : Optional [os .PathLike ],
65+ launch_configs : dict ,
66+ build_configs : dict ,
67+ problem_sizes : dict ,
68+ dry : bool ,
69+ ** kwargs
70+ ) -> DriverWrapper :
6071 """ Get the language drive wrapper for this prompt """
6172 driver_cls = LANGUAGE_DRIVERS [prompt ["language" ]]
6273 return driver_cls (parallelism_model = prompt ["parallelism_model" ], launch_configs = launch_configs ,
63- problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
74+ build_configs = build_configs , problem_sizes = problem_sizes , scratch_dir = scratch_dir , dry = dry , ** kwargs )
6475
6576def already_has_results (prompt : dict ) -> bool :
6677 """ Check if a prompt already has results stored in it. """
@@ -102,10 +113,25 @@ def main():
102113 launch_configs = load_json (args .launch_configs )
103114 logging .info (f"Loaded launch configs from { args .launch_configs } ." )
104115
116+ # load build configs
117+ build_configs = load_json (args .build_configs )
118+ logging .info (f"Loaded build configs from { args .build_configs } ." )
119+
105120 # load problem sizes
106121 problem_sizes = load_json (args .problem_sizes )
107122 logging .info (f"Loaded problem sizes from { args .problem_sizes } ." )
108123
124+ # set driver root; If provided, use user argument. If it's not provided, then check if the PAREVAL_ROOT environment
125+ # variable is set, then use "${PAREVAL_ROOT}/drivers" as the root. If neither is set, then use the location of
126+ # this script as the root.
127+ if args .driver_root :
128+ DRIVER_ROOT = args .driver_root
129+ elif "PAREVAL_ROOT" in os .environ :
130+ DRIVER_ROOT = os .path .join (os .environ ["PAREVAL_ROOT" ], "drivers" )
131+ else :
132+ DRIVER_ROOT = os .path .dirname (os .path .abspath (__file__ ))
133+ logging .info (f"Using driver root: { DRIVER_ROOT } " )
134+
109135 # gather the list of parallelism models to test
110136 models_to_test = args .include_models if args .include_models else ["serial" , "omp" , "mpi" , "mpi+omp" , "kokkos" , "cuda" , "hip" ]
111137 if args .exclude_models :
@@ -139,15 +165,18 @@ def main():
139165 prompt ,
140166 args .scratch_dir ,
141167 launch_configs ,
168+ build_configs ,
142169 problem_sizes ,
143170 args .dry ,
144171 display_build_errors = args .log_build_errors ,
145172 display_runs = args .log_runs ,
146173 early_exit_runs = args .early_exit_runs ,
147174 build_timeout = args .build_timeout ,
148- run_timeout = args .run_timeout
175+ run_timeout = args .run_timeout ,
149176 )
150- driver .test_all_outputs_in_prompt (prompt )
177+
178+ with contextlib .chdir (DRIVER_ROOT ):
179+ driver .test_all_outputs_in_prompt (prompt )
151180
152181 # go ahead and write out outputs now
153182 if args .output and args .output != '-' :
0 commit comments