
Commit 30dc29f

fix bug
1 parent 898f203 commit 30dc29f


3 files changed: +25 −32 lines

src/fairchem/core/_cli.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def main(
         # the hands all responsibility the user, ie they must initialize ray
         runner: Runner = hydra.utils.instantiate(cfg.runner, _recursive_=False)
         runner.run()
-    elif scheduler_cfg.ranks_per_node > 1:
+    else:
        from fairchem.core.launchers.slurm_launch import local_launch
 
         # else launch locally using torch elastic or local mode
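
This one-line change appears to be the heart of the fix: with the old `elif`, a non-Ray run with `ranks_per_node == 1` matched no branch here, while the plain `else` now routes every non-Ray run into `local_launch`, which makes the elastic-vs-local decision itself (see the next file). A minimal sketch of the resulting dispatch; `dispatch`, `use_ray`, and the `log_dir` plumbing are assumed names, since the surrounding code falls outside this hunk:

import hydra  # hydra-core; used only on the Ray path


def dispatch(cfg, use_ray: bool, log_dir: str) -> None:
    if use_ray:
        # Ray path: all responsibility is handed to the user-provided runner.
        runner = hydra.utils.instantiate(cfg.runner, _recursive_=False)
        runner.run()
    else:
        # Everything else goes through local_launch, which itself chooses
        # between torch elastic (>1 ranks) and plain single-process mode.
        from fairchem.core.launchers.slurm_launch import local_launch

        local_launch(cfg, log_dir)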

src/fairchem/core/launchers/slurm_launch.py

Lines changed: 15 additions & 18 deletions
@@ -269,25 +269,22 @@ def local_launch(cfg: DictConfig, log_dir: str):
     Launch locally with torch elastic (for >1 workers) or just single process
     """
     scheduler_cfg = cfg.job.scheduler
-    from torch.distributed.launcher.api import LaunchConfig, elastic_launch
-
-    from fairchem.core.launchers import slurm_launch
-
-    launch_config = LaunchConfig(
-        min_nodes=1,
-        max_nodes=1,
-        nproc_per_node=scheduler_cfg.ranks_per_node,
-        rdzv_backend="c10d",
-        max_restarts=0,
-    )
-    elastic_launch(launch_config, slurm_launch.runner_wrapper)(cfg)
-    if "reducer" in cfg:
-        elastic_launch(launch_config, slurm_launch.runner_wrapper)(cfg, RunType.REDUCE)
+    if scheduler_cfg.ranks_per_node > 1:
+        from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+
+        launch_config = LaunchConfig(
+            min_nodes=1,
+            max_nodes=1,
+            nproc_per_node=scheduler_cfg.ranks_per_node,
+            rdzv_backend="c10d",
+            max_restarts=0,
+        )
+        elastic_launch(launch_config, runner_wrapper)(cfg)
+        if "reducer" in cfg:
+            elastic_launch(launch_config, runner_wrapper)(cfg, RunType.REDUCE)
     else:
         logging.info("Running in local mode without elastic launch")
-        from fairchem.core.launchers import slurm_launch
-
         distutils.setup_env_local()
-        slurm_launch.runner_wrapper(cfg)
+        runner_wrapper(cfg)
         if "reducer" in cfg:
-            slurm_launch.runner_wrapper(cfg, RunType.REDUCE)
+            runner_wrapper(cfg, RunType.REDUCE)
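
For reference, a self-contained sketch of the pattern `local_launch` now encapsulates: torch elastic for multi-rank runs, a direct call for a single rank. The `worker` function below is a stand-in for fairchem's `runner_wrapper`; the `LaunchConfig` fields mirror the ones in the diff:

import os

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def worker(tag: str) -> int:
    # torch elastic sets RANK/LOCAL_RANK/WORLD_SIZE for each spawned worker;
    # in the single-process fallback these are simply absent.
    rank = int(os.environ.get("RANK", "0"))
    print(f"{tag}: rank {rank}")
    return rank


def launch(ranks_per_node: int) -> None:
    if ranks_per_node > 1:
        # Single-node elastic launch, mirroring the new branch in local_launch.
        config = LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=ranks_per_node,
            rdzv_backend="c10d",
            max_restarts=0,
        )
        elastic_launch(config, worker)("elastic")
    else:
        # One rank: skip the elastic machinery and call the worker directly.
        worker("local")


if __name__ == "__main__":
    launch(2)

Note that in the diff the optional reducer run reuses the same `launch_config`, so the reduce phase gets the same world size as the main run.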

tests/core/conftest.py

Lines changed: 9 additions & 13 deletions
@@ -6,21 +6,10 @@
 """
 
 from __future__ import annotations
-from fairchem.core.units.mlip_unit.mlip_unit import (
-    UNIT_INFERENCE_CHECKPOINT,
-    UNIT_RESUME_CONFIG,
-)
-
 
-from tests.core.units.mlip_unit.create_fake_dataset import (
-    create_fake_uma_dataset,
-)
 import os
 import tempfile
-
-from tests.core.testing_utils import launch_main
 from itertools import product
-import logging
 from random import choice
 from typing import TYPE_CHECKING
 

@@ -33,6 +22,14 @@
 from syrupy.extensions.amber import AmberSnapshotExtension
 
 from fairchem.core.datasets import AseDBDataset
+from fairchem.core.units.mlip_unit.mlip_unit import (
+    UNIT_INFERENCE_CHECKPOINT,
+    UNIT_RESUME_CONFIG,
+)
+from tests.core.testing_utils import launch_main
+from tests.core.units.mlip_unit.create_fake_dataset import (
+    create_fake_uma_dataset,
+)
 
 if TYPE_CHECKING:
     from syrupy.types import SerializableData

@@ -222,7 +219,7 @@ def dummy_binary_dataset(dummy_binary_dataset_path):
 def run_around_tests():
     # If debugging GPU memory issues, uncomment this print statement
     # to get full GPU memory allocations before each test runs
-    #print(torch.cuda.memory_summary())
+    # print(torch.cuda.memory_summary())
     yield
     torch.cuda.empty_cache()
 

@@ -343,7 +340,6 @@ def conserving_mole_checkpoint(fake_uma_dataset):
     return inference_checkpoint_pt, checkpoint_state_yaml
 
 
-
 @pytest.fixture(scope="session")
 def fake_uma_dataset():
     with tempfile.TemporaryDirectory() as tempdirname:
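
The conftest.py change is purely organizational: project imports move below the stdlib block (standard isort grouping), an apparently unused `import logging` is dropped, and type-only imports stay behind the `TYPE_CHECKING` guard so they cost nothing at runtime. A generic sketch of that guard pattern, with a hypothetical `snapshot_label` helper for illustration:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only during static type checking, never at runtime.
    from syrupy.types import SerializableData


def snapshot_label(data: SerializableData) -> str:
    # With `from __future__ import annotations`, the annotation above stays a
    # string at runtime, so the guarded import is never actually executed.
    return type(data).__name__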
