Commit b7e1998

shell scheduler: --use-srun option
1 parent 9c97feb commit b7e1998

2 files changed: 26 additions & 15 deletions

pytest_parallel/plugin.py

Lines changed: 8 additions & 1 deletion
@@ -35,9 +35,12 @@ def pytest_addoption(parser):
 
     if sys.version_info >= (3,9):
         parser.addoption('--slurm-export-env', dest='slurm_export_env', action=argparse.BooleanOptionalAction, default=True)
+        parser.addoption('--use-srun', dest='use_srun', action=argparse.BooleanOptionalAction, default=None, help='Launch MPI processes through srun (only possible when `--scheduler=shell`)')
     else:
         parser.addoption('--slurm-export-env', dest='slurm_export_env', default=False, action='store_true')
         parser.addoption('--no-slurm-export-env', dest='slurm_export_env', action='store_false')
+        parser.addoption('--use-srun', dest='use_srun', default=None, action='store_true')
+        parser.addoption('--no-use-srun', dest='use_srun', action='store_false')
 
     parser.addoption('--detach', dest='detach', action='store_true', help='Detach SLURM jobs: do not send reports to the scheduling process (useful to launch slurm job.sh separately)')
 
@@ -104,7 +107,11 @@ def pytest_configure(config):
     is_worker = config.getoption('_worker')
     slurm_file = config.getoption('slurm_file')
     slurm_export_env = config.getoption('slurm_export_env')
+    use_srun = config.getoption('use_srun')
     detach = config.getoption('detach')
+    if scheduler != 'shell':
+        if use_srun is not None:
+            raise PytestParallelError('Option `--use-srun` only available when `--scheduler=shell`')
     if not scheduler in ['slurm', 'shell']:
         assert not is_worker, f'Internal pytest_parallel error: `--_worker` not available with `--scheduler={scheduler}`'
         assert not n_workers, f'pytest_parallel error: `--n-workers` not available with `--scheduler={scheduler}`. Launch with `mpirun -np {n_workers}` to run in parallel'
@@ -175,7 +182,7 @@ def pytest_configure(config):
         main_invoke_params = _invoke_params(config.invocation_params.args)
         for file_or_dir in config.option.file_or_dir:
             main_invoke_params = main_invoke_params.replace(file_or_dir, '')
-        plugin = ShellStaticScheduler(main_invoke_params, n_workers, detach)
+        plugin = ShellStaticScheduler(main_invoke_params, n_workers, detach, use_srun)
     else:
         from mpi4py import MPI
         from .mpi_reporter import SequentialScheduler, StaticScheduler, DynamicScheduler
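
Note on the parsing above: the later check in pytest_configure relies on `use_srun is not None`, so the option defaults to None in order to distinguish "not specified" from an explicit choice. A minimal stand-alone sketch of the three possible states, using plain argparse on Python >= 3.9 (the parser object here is illustrative, not part of the commit):

    import argparse

    # Hypothetical parser mirroring the pytest_addoption call added above
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-srun', dest='use_srun',
                        action=argparse.BooleanOptionalAction, default=None)

    print(parser.parse_args([]).use_srun)                 # None  -> option left unspecified
    print(parser.parse_args(['--use-srun']).use_srun)     # True  -> launch through srun
    print(parser.parse_args(['--no-use-srun']).use_srun)  # False -> fall back to mpiexec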

pytest_parallel/shell_static_scheduler.py

Lines changed: 18 additions & 14 deletions
@@ -13,19 +13,22 @@
 from .utils.file import remove_exotic_chars, create_folders
 from .static_scheduler_utils import group_items_by_parallel_steps
 
-def mpi_command(current_proc, n_proc):
-    mpi_vendor = MPI.get_vendor()[0]
-    if mpi_vendor == 'Intel MPI':
-        cmd = f'I_MPI_PIN_PROCESSOR_LIST={current_proc}-{current_proc+n_proc-1}; '
-        cmd += f'mpiexec -np {n_proc}'
-        return cmd
-    elif mpi_vendor == 'Open MPI':
-        cores = ','.join([str(i) for i in range(current_proc,current_proc+n_proc)])
-        return f'mpiexec --cpu-list {cores} -np {n_proc}'
+def mpi_command(current_proc, n_proc, use_srun):
+    if use_srun:
+        return f'srun --exact --ntasks={n_proc}'
     else:
-        assert 0, f'Unknown MPI implementation "{mpi_vendor}"'
+        mpi_vendor = MPI.get_vendor()[0]
+        if mpi_vendor == 'Intel MPI':
+            cmd = f'I_MPI_PIN_PROCESSOR_LIST={current_proc}-{current_proc+n_proc-1}; '
+            cmd += f'mpiexec -np {n_proc}'
+            return cmd
+        elif mpi_vendor == 'Open MPI':
+            cores = ','.join([str(i) for i in range(current_proc,current_proc+n_proc)])
+            return f'mpiexec --cpu-list {cores} -np {n_proc}'
+        else:
+            assert 0, f'Unknown MPI implementation "{mpi_vendor}"'
 
-def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_invoke_params, i_step, n_step):
+def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_invoke_params, use_srun, i_step, n_step):
     # sort items by comm size to launch the bigger ones first (Note: in case SLURM prioritizes first-received items)
     items = sorted(items_to_run, key=lambda item: item.n_proc, reverse=True)
 
@@ -40,7 +43,7 @@ def submit_items(items_to_run, SCHEDULER_IP_ADDRESS, port, session_folder, main_invoke_params, i_step, n_step):
         test_idx = item.original_index
         test_out_file = f'.pytest_parallel/{session_folder}/{remove_exotic_chars(item.nodeid)}'
         cmd = '('
-        cmd += mpi_command(current_proc, item.n_proc)
+        cmd += mpi_command(current_proc, item.n_proc, use_srun)
         cmd += f' python3 -u -m pytest -s --_worker {socket_flags} {main_invoke_params} --_test_idx={test_idx} {item.config.rootpath}/{item.nodeid}'
         cmd += f' > {test_out_file} 2>&1'
         cmd += f' ; python3 -m pytest_parallel.send_report {socket_flags} --_test_idx={test_idx} --_test_name={test_out_file}'
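
For illustration, a hedged sketch of the two launcher prefixes that mpi_command can now contribute to the command assembled above, assuming the function as defined in this diff (process and core numbers are made up):

    # Illustrative only: exercising both branches of mpi_command shown above
    from pytest_parallel.shell_static_scheduler import mpi_command

    # srun branch: process placement is delegated to SLURM, the MPI vendor is never queried
    assert mpi_command(current_proc=4, n_proc=2, use_srun=True) == 'srun --exact --ntasks=2'

    # non-srun branch is vendor-dependent; e.g. with Open MPI it would pin cores 4 and 5:
    #   mpi_command(4, 2, use_srun=False) == 'mpiexec --cpu-list 4,5 -np 2'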
@@ -99,10 +102,11 @@ def receive_items(items, session, socket, n_item_to_recv):
         n_item_to_recv -= 1
 
 class ShellStaticScheduler:
-    def __init__(self, main_invoke_params, ntasks, detach):
+    def __init__(self, main_invoke_params, ntasks, detach, use_srun):
         self.main_invoke_params = main_invoke_params
         self.ntasks = ntasks
         self.detach = detach
+        self.use_srun = use_srun
 
         self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # TODO close at the end
 
@@ -148,7 +152,7 @@ def pytest_runtestloop(self, session) -> bool:
         n_step = len(items_by_steps)
         for i_step,items in enumerate(items_by_steps):
             n_item_to_receive = len(items)
-            sub_process = submit_items(items, SCHEDULER_IP_ADDRESS, port, session_folder, self.main_invoke_params, i_step, n_step)
+            sub_process = submit_items(items, SCHEDULER_IP_ADDRESS, port, session_folder, self.main_invoke_params, self.use_srun, i_step, n_step)
             if not self.detach: # The job steps are supposed to send their reports
                 receive_items(session.items, session, self.socket, n_item_to_receive)
             returncode = sub_process.wait() # at this point, the sub-process should be done since items have been received
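
End to end, a hedged usage sketch (the worker count and test path are illustrative; the option names come from this commit): running the shell scheduler with srun-launched MPI processes would look like

    pytest --scheduler=shell --n-workers=8 --use-srun tests/

whereas leaving --use-srun out (or passing --no-use-srun) keeps the vendor-specific mpiexec launch.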
