diff --git a/configs/uma/training_release/uma_ray_demo.yaml b/configs/uma/training_release/uma_ray_demo.yaml new file mode 100644 index 0000000000..52666af74e --- /dev/null +++ b/configs/uma/training_release/uma_ray_demo.yaml @@ -0,0 +1,186 @@ +# example yaml of using the Ray on slurm launcher to run the UMA training job +# Logger and checkpointing is current not enabled + +defaults: + - cluster: h100 + - backbone: K4L2 + - dataset: uma + - element_refs: uma_v1_hof_lin_refs + - tasks: uma_direct + - _self_ + +job: + device_type: ${cluster.device} + scheduler: + use_ray: true + mode: ${cluster.mode} + ranks_per_node: ${cluster.ranks_per_node} + num_nodes: 1 + slurm: + account: ${cluster.account} + qos: ${cluster.qos} + mem_gb: ${cluster.mem_gb} + cpus_per_task: ${cluster.cpus_per_task} + debug: ${cluster.debug} + run_dir: ${cluster.run_dir} + run_name: uma_sm_direct + # logger: + # _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb + # _partial_: true + # entity: fairchem + # project: uma + +moe_layer_type: pytorch +num_moe_experts: 32 +max_neighbors: 30 +cutoff_radius: 6 +epochs: null +steps: 1680000 # 140B atoms, 128 ranks, max atoms 700 (mean atoms 650) +max_atoms: 700 +bf16: True +cpu_graph: True +otf_graph: False +normalizer_rmsd: 1.423 +direct_forces_coef: 30 +omc_energy_coef: 10 +omol_energy_coef: 30 +odac_energy_coef: 10 +oc20_energy_coef: 10 +omat_energy_coef: 10 + +regress_stress: False +direct_forces: True + +oc20_forces_key: forces +omat_forces_key: forces +omol_forces_key: forces +odac_forces_key: forces +omc_forces_key: forces + +dataset_list: ["oc20", "omol", "omat", "odac", "omc"] + +exclude_keys: [ + "id", # only oc20,oc22 have this + "fid", # only oc20,oc22 have this + "absolute_idx", # only ani has this + "target_pos", # only ani has this + "ref_energy", # only ani/geom have this + "pbc", # only ani/transition1x have this + "nads", # oc22 + "oc22", # oc22 + "formation_energy", # spice + "total_charge", # spice +] + +train_dataset: + _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset + dataset_configs: + omc: ${dataset.omc_train} + omol: ${dataset.omol_train} + odac: ${dataset.odac_train} + omat: ${dataset.omat_train} + oc20: ${dataset.oc20_train} + combined_dataset_config: + sampling: + type: explicit + ratios: + omol.train: 4.0 + oc20.train: 1.0 + omc.train: 2.0 + odac.train: 1.0 + omat.train: 2.0 + +val_dataset: + _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset + dataset_configs: + omc: ${dataset.omc_val} + omol: ${dataset.omol_val} + odac: ${dataset.odac_val} + omat: ${dataset.omat_val} + oc20: ${dataset.oc20_val} + combined_dataset_config: { sampling: {type: temperature, temperature: 1.0} } + +train_dataloader: + _target_: fairchem.core.components.common.dataloader_builder.get_dataloader + dataset: ${train_dataset} + batch_sampler_fn: + _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler + _partial_: True + max_atoms: ${max_atoms} + shuffle: True + seed: 0 + num_workers: ${cluster.dataloader_workers} + collate_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter + tasks: ${tasks} + exclude_keys: ${exclude_keys} + +eval_dataloader: + _target_: fairchem.core.components.common.dataloader_builder.get_dataloader + dataset: ${val_dataset} + batch_sampler_fn: + _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler + _partial_: True + max_atoms: ${max_atoms} + shuffle: False + seed: 0 + 
num_workers: ${cluster.dataloader_workers} + collate_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter + tasks: ${tasks} + exclude_keys: ${exclude_keys} + +heads: + oc20_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + omat_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + omc_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + omol_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + odac_energy: + module: fairchem.core.models.uma.escn_md.MLP_Energy_Head + forces: + module: fairchem.core.models.uma.escn_md.Linear_Force_Head + +runner: + _target_: fairchem.core.launchers.ray_on_slurm_launch.SPMDController + job_config: ${job} + runner_config: + _target_: fairchem.core.components.train.train_runner.TrainEvalRunner + train_dataloader: ${train_dataloader} + eval_dataloader: ${eval_dataloader} + train_eval_unit: + _target_: fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit + job_config: ${job} + tasks: ${tasks} + model: + _target_: fairchem.core.models.base.HydraModel + backbone: ${backbone} + heads: ${heads} + optimizer_fn: + _target_: torch.optim.AdamW + _partial_: true + lr: 8e-4 + weight_decay: 1e-3 + cosine_lr_scheduler_fn: + _target_: fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler + _partial_: true + warmup_factor: 0.2 + warmup_epochs: 0.01 + lr_min_factor: 0.01 + epochs: ${epochs} + steps: ${steps} + print_every: 10 + clip_grad_norm: 100 + bf16: ${bf16} + max_epochs: ${epochs} + max_steps: ${steps} + evaluate_every_n_steps: 10000 + callbacks: + - _target_: fairchem.core.common.profiler_utils.ProfilerCallback + job_config: ${job} + # - _target_: fairchem.core.components.train.train_runner.TrainCheckpointCallback + # checkpoint_every_n_steps: 5000 + # max_saved_checkpoints: 5 diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index f46a8650d1..4253b266c6 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"] docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets"] adsorbml = ["dscribe","x3dase","scikit-image"] -extras = ["ray", "pymatgen", "quacc[phonons]>=0.15.3","pandas"] +extras = ["ray[default]", "pymatgen", "quacc[phonons]>=0.15.3", "pandas"] [project.scripts] fairchem = "fairchem.core._cli:main" diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index 68cdb4103d..b229134fd8 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -10,355 +10,28 @@ import argparse import logging import os -import random -import tempfile -import uuid -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING -import clusterscope import hydra -import numpy as np -import torch from omegaconf import OmegaConf from omegaconf.errors import InterpolationKeyError -from fairchem.core.common import gp_utils +from fairchem.core.launchers.api import ( + ALLOWED_TOP_LEVEL_KEYS, + JobConfig, + SchedulerType, +) if TYPE_CHECKING: from omegaconf import DictConfig - from fairchem.core.components.reducer import Reducer from fairchem.core.components.runner import Runner -from submitit import AutoExecutor -from submitit.core.utils import JobPaths, cloudpickle_dump -from submitit.helpers import Checkpointable, DelayedSubmission -from 
submitit.slurm.slurm import SlurmJobEnvironment - -from fairchem.core.common import distutils -from fairchem.core.common.logger import WandBSingletonLogger -from fairchem.core.common.utils import ( - StrEnum, - get_commit_hash, - get_timestamp_uid, - setup_env_vars, - setup_logging, -) # this effects the cli only since the actual job will be run in subprocesses or remoe logging.basicConfig(level=logging.INFO) -ALLOWED_TOP_LEVEL_KEYS = {"job", "runner", "reducer"} - -LOG_DIR_NAME = "logs" -CHECKPOINT_DIR_NAME = "checkpoints" -RESULTS_DIR = "results" -CONFIG_FILE_NAME = "canonical_config.yaml" -PREEMPTION_STATE_DIR_NAME = "preemption_state" - - -class SchedulerType(StrEnum): - LOCAL = "local" - SLURM = "slurm" - - -class DeviceType(StrEnum): - CPU = "cpu" - CUDA = "cuda" - - -class RunType(StrEnum): - RUN = "run" - REDUCE = "reduce" - - -class DistributedInitMethod(StrEnum): - TCP = "tcp" - FILE = "file" - - -@dataclass -class SlurmConfig: - mem_gb: int = 80 - timeout_hr: int = 168 - cpus_per_task: int = 8 - partition: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - qos: Optional[str] = None # omegaconf in python 3.9 does not backport annotations - account: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - additional_parameters: Optional[dict] = None - - -@dataclass -class SchedulerConfig: - mode: SchedulerType = SchedulerType.LOCAL - distributed_init_method: DistributedInitMethod = DistributedInitMethod.TCP - ranks_per_node: int = 1 - num_nodes: int = 1 - num_array_jobs: int = 1 - slurm: SlurmConfig = field(default_factory=lambda: SlurmConfig) - - -@dataclass -class SlurmEnv: - # reflects the job_id given by submitit (slurm id with array job id and array task id if they exist) - job_id: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - # reflects SLURM_JOB_ID only - raw_job_id: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - # SLURM_ARRAY_JOB_ID - array_job_id: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - # SLURM_ARRAY_TASK_ID - array_task_id: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - # reflects SLURM_RESTART_COUNT env variable - restart_count: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - - -@dataclass -class Metadata: - # read-only metadata about the job, not user inputs - commit: str - log_dir: str - checkpoint_dir: str - results_dir: str - config_path: str - preemption_checkpoint_dir: str - cluster_name: str - array_job_num: int = 0 - slurm_env: SlurmEnv = field(default_factory=lambda: SlurmEnv()) - - -@dataclass -class JobConfig: - run_name: str = field( - default_factory=lambda: get_timestamp_uid() + uuid.uuid4().hex.upper()[0:4] - ) - timestamp_id: str = field(default_factory=lambda: get_timestamp_uid()) - run_dir: str = field(default_factory=lambda: tempfile.TemporaryDirectory().name) - device_type: DeviceType = DeviceType.CUDA - debug: bool = False - scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig) - logger: Optional[dict] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - seed: int = 0 - deterministic: bool = False - runner_state_path: Optional[str] = ( - None # omegaconf in python 3.9 does not backport annotations - ) - # read-only metadata about the job, not user inputs - metadata: Optional[Metadata] = ( - None # omegaconf in python 3.9 does not backport 
annotations - ) - graph_parallel_group_size: Optional[int] = None - - def __post_init__(self) -> None: - self.run_dir = os.path.abspath(self.run_dir) - try: - cluster = clusterscope.cluster() - except RuntimeError: - cluster = "" - self.metadata = Metadata( - commit=get_commit_hash(), - log_dir=os.path.join(self.run_dir, self.timestamp_id, LOG_DIR_NAME), - checkpoint_dir=os.path.join( - self.run_dir, self.timestamp_id, CHECKPOINT_DIR_NAME - ), - results_dir=os.path.join(self.run_dir, self.timestamp_id, RESULTS_DIR), - config_path=os.path.join(self.run_dir, self.timestamp_id, CONFIG_FILE_NAME), - preemption_checkpoint_dir=os.path.join( - self.run_dir, - self.timestamp_id, - CHECKPOINT_DIR_NAME, - PREEMPTION_STATE_DIR_NAME, - ), - cluster_name=cluster, - ) - - -def _set_seeds(seed: int) -> None: - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def _set_deterministic_mode() -> None: - # this is required for full cuda deterministic mode - logging.info("Setting deterministic mode!") - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.use_deterministic_algorithms(True) - - -def _get_slurm_env() -> SlurmEnv: - slurm_job_env = SlurmJobEnvironment() - try: - slurm_env = SlurmEnv( - job_id=slurm_job_env.job_id, - raw_job_id=slurm_job_env.raw_job_id, - array_job_id=slurm_job_env.array_job_id, - array_task_id=slurm_job_env.array_task_id, - restart_count=os.environ.get("SLURM_RESTART_COUNT"), - ) - except KeyError: - # slurm environment variables are undefined, running locally - slurm_env = SlurmEnv() - - return slurm_env - - -def remove_runner_state_from_submission(log_folder: str, job_id: str) -> None: - # (HACK) Decouple the job from the runner state by manually modifying it - # this ensures the saved runner state is not re-submitted in the event of a node failure - # ie: if the job was started at state t=T, a requeue during node failure would resubmit the job - # starting at state t=T again without calling the checkpoint callback, losing all progress in between. 
- job_path = JobPaths(folder=log_folder, job_id=job_id) - if os.path.isfile(job_path.submitted_pickle): - submission_obj = DelayedSubmission.load(job_path.submitted_pickle) - submission_obj.args[0].job.runner_state_path = None - cloudpickle_dump(submission_obj, job_path.submitted_pickle) - - -class Submitit(Checkpointable): - def __init__(self) -> None: - self.config = None - self.runner = None - self.reducer = None - - def __call__( - self, dict_config: DictConfig, run_type: RunType = RunType.RUN - ) -> None: - self.config = dict_config - self.run_type = run_type - # modify the config metadata to add slurm info if they exist - self.config.job.metadata.slurm_env = _get_slurm_env() - - setup_env_vars() - setup_logging() - - dist_config = map_job_config_to_dist_config(self.config.job) - logging.info("Setting up distributed backend...") - distutils.setup(dist_config) - distutils.synchronize() - if ( - distutils.is_master() - and self.config.job.scheduler.mode == SchedulerType.SLURM - ): - # this pickle file is shared across all processes so can only modify this on the main rank - remove_runner_state_from_submission( - dict_config.job.metadata.log_dir, - self.config.job.metadata.slurm_env.job_id, - ) - - if self.config.job.graph_parallel_group_size is not None: - logging.info("Setting up graph parallel...") - gp_utils.setup_graph_parallel_groups( - self.config.job.graph_parallel_group_size, - dist_config["distributed_backend"], - ) - - self._init_logger() - - _set_seeds(self.config.job.seed) - if self.config.job.deterministic: - _set_deterministic_mode() - - if run_type == RunType.RUN: - logging.info("Calling runner.run() ...") - self.runner: Runner = hydra.utils.instantiate(self.config.runner) - self.runner.job_config = self.config.job - # must call resume state AFTER the runner has been initialized - self.runner.load_state(self.config.job.runner_state_path) - self.runner.run() - elif run_type == RunType.REDUCE: - logging.info("Calling reducer.reduce() ...") - self.reducer: Reducer = hydra.utils.instantiate(self.config.reducer) - self.reducer.job_config = self.config.job - self.reducer.runner_config = self.config.runner - # must call resume state AFTER the runner has been initialized - self.reducer.load_state(self.config.job.runner_state_path) - self.reducer.reduce() - else: - raise ValueError(f"run type {run_type} is not recognized!") - - distutils.cleanup() - - def _init_logger(self) -> None: - if ( - self.config.job.logger - and distutils.is_master() - and not self.config.job.debug - and self.config.job.metadata.array_job_num == 0 - ): - # get a partial function from the config and instantiate wandb with it - # currently code assumes that we only use the WandBSingletonLogger - logger_initializer = hydra.utils.instantiate(self.config.job.logger) - simple_config = OmegaConf.to_container( - self.config, resolve=True, throw_on_missing=True - ) - logger_initializer( - config=simple_config, - run_id=self.config.job.timestamp_id, - run_name=self.config.job.run_name, - log_dir=self.config.job.metadata.log_dir, - ) - - def checkpoint(self, *args, **kwargs) -> DelayedSubmission: - logging.error("Submitit checkpointing callback is triggered") - save_path = self.config.job.metadata.preemption_checkpoint_dir - cfg_copy = self.config.copy() - # only assign if the save was successful - cfg_copy.job.runner_state_path = None - - if ( - self.run_type == RunType.RUN - and self.runner.save_state(save_path, is_preemption=True) - ) or ( - self.run_type == RunType.REDUCE - and self.reducer.save_state(save_path, 
is_preemption=True) - ): - cfg_copy.job.runner_state_path = save_path - - if WandBSingletonLogger.initialized(): - WandBSingletonLogger.get_instance().mark_preempting() - logging.info( - f"Submitit checkpointing callback is completed, resuming with use the following state: {save_path}" - ) - return DelayedSubmission(Submitit(), cfg_copy) - - -def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict: - scheduler_config = job_cfg.scheduler - return { - "world_size": scheduler_config.num_nodes * scheduler_config.ranks_per_node, - "distributed_backend": ( - "gloo" if job_cfg.device_type == DeviceType.CPU else "nccl" - ), - "submit": scheduler_config.mode == SchedulerType.SLURM, - "cpu": job_cfg.device_type == DeviceType.CPU, - "init_method": scheduler_config.distributed_init_method, - # for distributed shared file initialization - "shared_file_dir": os.path.join(job_cfg.run_dir, job_cfg.timestamp_id), - "array_job_num": job_cfg.metadata.array_job_num, - } - - def get_canonical_config(config: DictConfig) -> DictConfig: # manually initialize metadata, because OmegaConf currently doesn't call __post_init__ on dataclasses job = OmegaConf.to_object(config.job) @@ -405,12 +78,6 @@ def get_hydra_config_from_yaml( return get_canonical_config(cfg) -def _runner_wrapper(config: DictConfig, run_type: RunType = RunType.RUN): - # This is needed when using elastic_launch for local runs since it looks for - # the __name__ attribute of the function, Submitit.__call__ does not have one - Submitit()(config, run_type) - - def main( args: argparse.Namespace | None = None, override_args: list[str] | None = None ): @@ -433,77 +100,37 @@ def main( assert ( os.getenv("SLURM_SUBMIT_HOST") is None ), "SLURM DID NOT SUBMIT JOB!! Please do not submit jobs from an active slurm job (srun or otherwise)" - executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) - executor.update_parameters( - name=cfg.job.run_name, - mem_gb=scheduler_cfg.slurm.mem_gb, - timeout_min=scheduler_cfg.slurm.timeout_hr * 60, - slurm_partition=scheduler_cfg.slurm.partition, - gpus_per_node=scheduler_cfg.ranks_per_node, - cpus_per_task=scheduler_cfg.slurm.cpus_per_task, - tasks_per_node=scheduler_cfg.ranks_per_node, - nodes=scheduler_cfg.num_nodes, - slurm_qos=scheduler_cfg.slurm.qos, - slurm_account=scheduler_cfg.slurm.account, - slurm_additional_parameters=scheduler_cfg.slurm.additional_parameters, - ) - if scheduler_cfg.num_array_jobs == 1: - job = executor.submit(Submitit(), cfg) - logging.info( - f"Submitted job id: {cfg.job.timestamp_id}, slurm id: {job.job_id}, logs: {cfg.job.metadata.log_dir}" - ) - jobs = [job] - elif scheduler_cfg.num_array_jobs > 1: - executor.update_parameters( - slurm_array_parallelism=scheduler_cfg.num_array_jobs, - ) - jobs = [] - with executor.batch(): - for job_number in range(scheduler_cfg.num_array_jobs): - _cfg = cfg.copy() - _cfg.job.metadata.array_job_num = job_number - job = executor.submit(Submitit(), _cfg) - jobs.append(job) - logging.info(f"Submitted {len(jobs)} jobs: {jobs[0].job_id.split('_')[0]}") + if scheduler_cfg.use_ray: + logging.info("Lauching job on Ray + Slurm cluster") + from fairchem.core.launchers import ray_on_slurm_launch - if "reducer" in cfg: - job_id = jobs[0].job_id.split("_")[0] - executor.update_parameters( - name=f"{cfg.job.run_name}_reduce", - # set a single node, or do we want the same config as the Runner or a separate JobConfig - nodes=1, - slurm_dependency=f"afterok:{job_id}", - slurm_additional_parameters={ - "kill-on-invalid-dep": "yes" - }, # kill the reducer if 
run fails - ) - executor.submit(Submitit(), cfg, RunType.REDUCE) - else: - from torch.distributed.launcher.api import LaunchConfig, elastic_launch + ray_on_slurm_launch.ray_on_slurm_launch(cfg, log_dir) + else: + logging.info("Lauching job on directly on Slurm cluster") + from fairchem.core.launchers import slurm_launch + slurm_launch.slurm_launch(cfg, log_dir) + elif scheduler_cfg.mode == SchedulerType.LOCAL: # Run locally if scheduler_cfg.num_nodes > 1: cfg.job.scheduler.num_nodes = 1 logging.warning( f"You cannot use more than one node (scheduler_cfg.num_nodes={scheduler_cfg.num_nodes}) in LOCAL mode, over-riding to 1 node" ) - if scheduler_cfg.ranks_per_node > 1: + # if using ray, then launch ray cluster locally + if scheduler_cfg.use_ray: + logging.info("Running in local mode with local ray cluster") + # don't recursively instantiate the runner here to allow lazy instantiations in the runner + # the hands all responsibility the user, ie they must initialize ray + runner: Runner = hydra.utils.instantiate(cfg.runner, _recursive_=False) + runner.run() + else: + from fairchem.core.launchers.slurm_launch import local_launch + + # else launch locally using torch elastic or local mode logging.info( f"Running in local mode with {scheduler_cfg.ranks_per_node} ranks using device_type:{cfg.job.device_type}" ) - launch_config = LaunchConfig( - min_nodes=1, - max_nodes=1, - nproc_per_node=scheduler_cfg.ranks_per_node, - rdzv_backend="c10d", - max_restarts=0, - ) - elastic_launch(launch_config, _runner_wrapper)(cfg) - if "reducer" in cfg: - elastic_launch(launch_config, _runner_wrapper)(cfg, RunType.REDUCE) - else: - logging.info("Running in local mode without elastic launch") - distutils.setup_env_local() - Submitit()(cfg) - if "reducer" in cfg: - Submitit()(cfg, RunType.REDUCE) + local_launch(cfg, log_dir) + else: + raise ValueError(f"Unknown scheduler mode {scheduler_cfg.mode}") diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 8b20ccdf07..f497897cac 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -272,8 +272,8 @@ def setup_env_local(): os.environ["MASTER_PORT"] = str(get_free_port()) -def setup_env_local_multi_gpu(rank: int, port: int): - os.environ["MASTER_ADDR"] = "localhost" +def setup_env_local_multi_gpu(rank: int, port: int, address: str = "localhost"): + os.environ["MASTER_ADDR"] = address os.environ["LOCAL_RANK"] = str(rank) os.environ["RANK"] = str(rank) os.environ["MASTER_PORT"] = str(port) diff --git a/src/fairchem/core/launchers/api.py b/src/fairchem/core/launchers/api.py new file mode 100644 index 0000000000..c60169ff29 --- /dev/null +++ b/src/fairchem/core/launchers/api.py @@ -0,0 +1,169 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. 
+""" + +from __future__ import annotations + +import os +import tempfile +import uuid +from dataclasses import dataclass, field +from typing import Optional + +import clusterscope + +from fairchem.core.common.utils import ( + StrEnum, + get_commit_hash, + get_timestamp_uid, +) + +ALLOWED_TOP_LEVEL_KEYS = {"job", "runner", "reducer"} + +LOG_DIR_NAME = "logs" +CHECKPOINT_DIR_NAME = "checkpoints" +RESULTS_DIR = "results" +CONFIG_FILE_NAME = "canonical_config.yaml" +PREEMPTION_STATE_DIR_NAME = "preemption_state" + + +class SchedulerType(StrEnum): + LOCAL = "local" + SLURM = "slurm" + + +class DeviceType(StrEnum): + CPU = "cpu" + CUDA = "cuda" + + +class RunType(StrEnum): + RUN = "run" + REDUCE = "reduce" + + +class DistributedInitMethod(StrEnum): + TCP = "tcp" + FILE = "file" + + +@dataclass +class SlurmConfig: + mem_gb: int = 80 + timeout_hr: int = 168 + cpus_per_task: int = 8 + partition: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + qos: Optional[str] = None # omegaconf in python 3.9 does not backport annotations + account: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + additional_parameters: Optional[dict] = None + + +@dataclass +class RayClusterConfig: + head_gpus: int = 0 + + +@dataclass +class SchedulerConfig: + mode: SchedulerType = SchedulerType.LOCAL + distributed_init_method: DistributedInitMethod = DistributedInitMethod.TCP + ranks_per_node: int = 1 + num_nodes: int = 1 + num_array_jobs: int = 1 + slurm: SlurmConfig = field(default_factory=lambda: SlurmConfig()) + # if not None, will launch a ray cluster on slurm instead of using submitit directly to launch the job + use_ray: bool = False + ray_cluster: RayClusterConfig = field(default_factory=lambda: RayClusterConfig()) + + +@dataclass +class SlurmEnv: + # reflects the job_id given by submitit (slurm id with array job id and array task id if they exist) + job_id: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + # reflects SLURM_JOB_ID only + raw_job_id: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + # SLURM_ARRAY_JOB_ID + array_job_id: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + # SLURM_ARRAY_TASK_ID + array_task_id: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + # reflects SLURM_RESTART_COUNT env variable + restart_count: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + + +@dataclass +class Metadata: + # read-only metadata about the job, not user inputs + commit: str + log_dir: str + checkpoint_dir: str + results_dir: str + config_path: str + preemption_checkpoint_dir: str + cluster_name: str + array_job_num: int = 0 + slurm_env: SlurmEnv = field(default_factory=lambda: SlurmEnv()) + + +@dataclass +class JobConfig: + run_name: str = field( + default_factory=lambda: get_timestamp_uid() + uuid.uuid4().hex.upper()[0:4] + ) + timestamp_id: str = field(default_factory=lambda: get_timestamp_uid()) + run_dir: str = field(default_factory=lambda: tempfile.TemporaryDirectory().name) + device_type: DeviceType = DeviceType.CUDA + debug: bool = False + scheduler: SchedulerConfig = field(default_factory=lambda: SchedulerConfig) + logger: Optional[dict] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + seed: int = 0 + deterministic: bool = False + runner_state_path: Optional[str] = ( + None # omegaconf in python 3.9 does not backport annotations + 
) + # read-only metadata about the job, not user inputs + metadata: Optional[Metadata] = ( + None # omegaconf in python 3.9 does not backport annotations + ) + graph_parallel_group_size: Optional[int] = None + + def __post_init__(self) -> None: + self.run_dir = os.path.abspath(self.run_dir) + try: + cluster = clusterscope.cluster() + except RuntimeError: + cluster = "" + self.metadata = Metadata( + commit=get_commit_hash(), + log_dir=os.path.join(self.run_dir, self.timestamp_id, LOG_DIR_NAME), + checkpoint_dir=os.path.join( + self.run_dir, self.timestamp_id, CHECKPOINT_DIR_NAME + ), + results_dir=os.path.join(self.run_dir, self.timestamp_id, RESULTS_DIR), + config_path=os.path.join(self.run_dir, self.timestamp_id, CONFIG_FILE_NAME), + preemption_checkpoint_dir=os.path.join( + self.run_dir, + self.timestamp_id, + CHECKPOINT_DIR_NAME, + PREEMPTION_STATE_DIR_NAME, + ), + cluster_name=cluster, + ) diff --git a/src/fairchem/core/launchers/cluster/ray_cluster.py b/src/fairchem/core/launchers/cluster/ray_cluster.py new file mode 100644 index 0000000000..99a78b11b0 --- /dev/null +++ b/src/fairchem/core/launchers/cluster/ray_cluster.py @@ -0,0 +1,413 @@ +# ruff: noqa +from __future__ import annotations + +import dataclasses +import json +import os +import shutil +import socket +import subprocess +import tempfile +import time +from typing import Callable, Optional, TypeVar +import uuid +from contextlib import closing +from pathlib import Path + +import psutil +import submitit + + +def kill_proc_tree(pid, including_parent=True): + parent = psutil.Process(pid) + children = parent.children(recursive=True) + for child in children: + child.kill() + psutil.wait_procs(children, timeout=5) + if including_parent: + parent.kill() + parent.wait(5) + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def scancel(job_ids: list[str]): + """ + Cancel the SLURM jobs with the given job IDs. + + This function takes a list of job IDs. + + Args: + job_ids (List[str]): A list of job IDs to cancel. + """ + root_ids = list(set([i.split("_", maxsplit=2)[0] for i in job_ids])) + subprocess.check_call(["scancel"] + root_ids) + + +start_ip_pattern = r"ray start --address='([0-9\.]+):([0-9]+)'" + +PayloadReturnT = TypeVar("PayloadReturnT") + + +def mk_symlinks(target_dir: Path, job_type: str, paths: submitit.core.utils.JobPaths): + """Create symlinks for the job's stdout and stderr in the target directory with a nicer name.""" + (target_dir / f"{job_type}.err").symlink_to(paths.stderr) + (target_dir / f"{job_type}.out").symlink_to(paths.stdout) + + +@dataclasses.dataclass +class HeadInfo: + """ + information about the head node that we can share to workers + """ + + hostname: Optional[str] = None + port: Optional[int] = None + temp_dir: Optional[str] = None + + +class RayClusterState: + """ + This class is responsible for managing the state of the Ray cluster. It is useful to keep track + of the head node and the workers, and to make sure they are all ready before starting the payload. + + It relies on storing info in a rendezvous directory so they can be shared async between jobs. + + Args: + rdv_dir (Path): The directory where the rendezvous information will be stored. Defaults to ~/.fairray. + cluster_id (str): A unique identifier for the cluster. Defaults to a random UUID. You only want to set this if you want to connect to an existing cluster. 
+ """ + + def __init__( + self, + rdv_dir: Optional[Path] = None, + cluster_id: Optional[str] = None, + ): + self.rendezvous_rootdir = ( + rdv_dir if rdv_dir is not None else (Path.home() / ".fairray") + ) + self._cluster_id = ( + uuid.uuid4().hex if cluster_id is None else cluster_id + ) # maybe use something more readable + self.jobs_dir.mkdir(parents=True, exist_ok=True) + + @property + def cluster_id(self) -> str: + """Returns the unique identifier for the cluster.""" + return self._cluster_id + + @property + def rendezvous_dir(self) -> Path: + """Returns the path to the directory where the rendezvous information is stored.""" + return self.rendezvous_rootdir / self.cluster_id + + @property + def jobs_dir(self) -> Path: + """Returns the path to the directory where job information is stored.""" + return self.rendezvous_dir / "jobs" + + @property + def _head_json(self) -> Path: + """Returns the path to the JSON file containing head node information.""" + return self.rendezvous_dir / "head.json" + + def is_head_ready(self) -> bool: + """Checks if the head node information is available and ready.""" + return self._head_json.exists() + + def head_info(self) -> Optional[HeadInfo]: + """ + Retrieves the head node information from the stored JSON file. + + Returns: + Optional[HeadInfo]: The head node information if available, otherwise None. + """ + try: + with self._head_json.open("r") as f: + return HeadInfo(**json.load(f)) + except Exception as ex: + print(f"failed to load head info: {ex}. Maybe it's not ready yet?") + return None + + def save_head_info(self, head_info: HeadInfo): + """ + Saves the head node information to a JSON file. + + Args: + head_info (HeadInfo): The head node information to save. + """ + with self._head_json.open("w") as f: + json.dump(dataclasses.asdict(head_info), f) + + def clean(self): + """Removes the rendezvous directory and all its contents.""" + shutil.rmtree(self.rendezvous_dir) + + def add_job(self, job: submitit.Job): + """ + Adds a job to the jobs directory by creating a JSON file with the job's information. + + Args: + job (submitit.Job): The job to add. + """ + with (self.jobs_dir / f"{job.job_id}.json").open("w") as f: + json.dump( + { + "job_id": job.job_id, + }, + fp=f, + ) + + def list_job_ids(self) -> list[str]: + """Lists all job IDs stored in the jobs directory.""" + return [f.stem for f in self.jobs_dir.iterdir()] + + +def _ray_head_script( + cluster_state: RayClusterState, + worker_wait_timeout_seconds: int, + payload: Optional[Callable[..., PayloadReturnT]] = None, + **kwargs, +): + """Start the head node of the Ray cluster on slurm.""" + hostname = socket.gethostname() + head_env = os.environ.copy() + num_cpus = os.environ.get("SLURM_CPUS_ON_NODE", 1) + num_gpus = os.environ.get("SLURM_GPUS_ON_NODE", 0) + # using 0 as the port for the head will make ray search for an open port instead of + # always using the same one. + port = find_free_port() + head_env["RAY_ADDRESS"] = f"{hostname}:{port}" + head_env["RAY_gcs_server_request_timeout_seconds"] = str( + worker_wait_timeout_seconds + ) + print(f"host {hostname}:{port}") + with tempfile.TemporaryDirectory(dir="/tmp") as temp_dir: + # ray workers have the same tempdir name (even on a different host) + # as the head. This is a problem when we use /scratch/slurm_tmpdir/JOBID as + # the tempdir of the head job will not be accessible/visible from other workers if they + # are scheduled on the same host. 
We are forced to use a different tempdir than /scratch + # TODO ideally, we would still have a /scratch dir that everyone can share. + process = subprocess.Popen( + [ + "ray", + "start", + "--head", + f"--port={port}", + f"--temp-dir={temp_dir}", + "--num-cpus", + f"{num_cpus}", + "--num-gpus", + f"{num_gpus}", + "--dashboard-host=0.0.0.0", + ], + env=head_env, + stdout=subprocess.PIPE, + text=True, + ) + started = False + for line in process.stdout: + if "ray start --address=" in line: + # this is a bit flaky, we search the stdout of the head job to + # find this specific message and extract the address, it might be + # better to not rely on ray printing this as it might change outside of our control. + # Search for the pattern + started = True + assert ( + started + ), "couldn't find head address in stdout. Check head.err for details" + print(f"Head started, ip: {hostname}:{port} ({cluster_state.cluster_id})") + info = HeadInfo(hostname=hostname, port=int(port), temp_dir=temp_dir) + cluster_state.save_head_info(info) + os.environ.update(head_env) + if payload is not None: + payload(**kwargs) + else: + while True: + # practically, we should wait from driver signal to die here + time.sleep(60) + + +def worker_script( + cluster_state: RayClusterState, + worker_wait_timeout_seconds: int, + start_wait_time_seconds: int = 60, # TODO pass this around properly +): + """start an array of worker nodes for the Ray cluster on slurm. Waiting on the head node first.""" + print(f"Waiting for head node. {cluster_state.cluster_id}") + while not cluster_state.is_head_ready(): + # wait for head to have started + time.sleep(5) + print("Head node found.") + head_info = cluster_state.head_info() + assert head_info is not None, "something went wrong getting head information." + worker_env = os.environ.copy() + worker_env["RAY_ADDRESS"] = f"{head_info.hostname}:{head_info.port}" + worker_env["RAY_gcs_server_request_timeout_seconds"] = str( + worker_wait_timeout_seconds + ) + worker_env["RAY_raylet_start_wait_time_s"] = str(start_wait_time_seconds) + num_cpus = os.environ.get("SLURM_CPUS_ON_NODE", 1) + num_gpus = os.environ.get("SLURM_GPUS_ON_NODE", 0) + + try: + subprocess.run( + [ + "ray", + "start", + "--address", + "auto", + "--block", + "--num-cpus", + f"{num_cpus}", + "--num-gpus", + f"{num_gpus}", + ], + env=worker_env, + check=False, + ) + finally: + if head_info.temp_dir: + shutil.rmtree(Path(head_info.temp_dir)) + + +# TODO deal with ports better: https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html#slurm-networking-caveats +# TODO: reqs are just dicts, maybe we want to be more specific (in particular for qos/partition) +# TODO: need better naming too +# TODO: better log messages +# TODO checkpointing to recover worker nodes after timeout/preemption https://github.com/facebookincubator/submitit/blob/main/docs/checkpointing.md +# TODO have a ray autoscaler nodeprovider based on this, e.g. https://github.com/TingkaiLiu/Ray-SLURM-autoscaler/blob/main/slurm/node_provider.py +class RayCluster: + """ + A RayCluster offers tools to start a Ray cluster (head and wokers) on slurm with the correct settings. + + args: + + log_dir: Path to the directory where logs will be stored. Defaults to "raycluster_logs" in the working directory. All slurm logs will go there, + and it also creates symlinks to the stdout/stderr of each jobs with nicer name (head, worker_0, worker_1, ..., driver_0, etc). There interesting + logs will be in the driver_N.err file, you should tail that. 
+ rdv_dir: Path to the directory where the rendezvous information will be stored. Defaults to ~/.fairray. Useful if you are trying to recover an existing cluster. + cluster_id: A unique identifier for the cluster. Defaults to a random UUID. You only want to set this if you want to connect to an existing cluster. + worker_wait_timeout_seconds (int): The number of seconds ray will wait for a worker to be ready before giving up. Defaults to 60 seconds. If you are scheduling + workers in a queue that takes time for allocation, you might want to increase this otherwise your ray payload will fail, not finding resources. + + """ + + log_dir: Path + state: RayClusterState + + jobs: list[submitit.Job] = [] + is_shutdown = False + num_worker_groups = 0 + num_drivers = 0 + head_started = False + + # keeping this in a separate object so it's easy to serialize and pass to jobs + + def __init__( + self, + log_dir: Path = Path("raycluster_logs"), + rdv_dir: Optional[Path] = None, + cluster_id: Optional[str] = None, + worker_wait_timeout_seconds: int = 60, + ): + self.state = RayClusterState(rdv_dir, cluster_id) + print(f"cluster {self.state.cluster_id}") + self.log_dir = Path(log_dir) / self.state.cluster_id + self.state.rendezvous_dir.mkdir(parents=True, exist_ok=True) + self.worker_wait_timeout_seconds = worker_wait_timeout_seconds + print(f"logs will be in {self.log_dir.resolve()}") + + def start_head( + self, + requirements: dict[str, int | str], + executor: str = "slurm", + payload: Optional[Callable[..., PayloadReturnT]] = None, + **kwargs, + ) -> str: + """ + Start the head node of the Ray cluster on slurm. You should do this first. Interesting requirements: qos, partition, time, gpus, cpus-per-task, mem-per-gpu, etc. + """ + assert not self.head_started, "head already started" + # start the head node + self.head_started = True + s_executor = submitit.AutoExecutor( + folder=str(self.log_dir), + cluster=executor, + ) + s_executor.update_parameters( + name=f"ray_head_{self.state.cluster_id}", # TODO name should probably include more details (cluster_id) + **requirements, + ) + head_job = s_executor.submit( + _ray_head_script, + self.state, + self.worker_wait_timeout_seconds, + payload, + **kwargs, + ) + self.state.add_job(head_job) + mk_symlinks(self.log_dir, "head", head_job.paths) + print("head slurm job id:", head_job.job_id) + return head_job.job_id + + def start_workers( + self, + num_workers: int, + requirements: dict[str, int | str], + executor: str = "slurm", + ) -> list[str]: + """ + Start an array of worker nodes of the Ray cluster on slurm. You should do this after starting a head. + Interesting requirements: qos, partition, time, gpus, cpus-per-task, mem-per-gpu, etc. + You can call this multiple times to start an heterogeneous cluster. 
+ """ + # start the workers + s_executor = submitit.AutoExecutor(folder=str(self.log_dir), cluster=executor) + s_executor.update_parameters( + name=f"ray_worker_{self.num_worker_groups}_{self.state.cluster_id}", # TODO name should probably include more details (cluster_id) + **requirements, + ) + + jobs = [] + with s_executor.batch(): # TODO set slurm array max parallelism here, because we really want all jobs to be scheduled at the same time + for i in range(num_workers): + jobs.append( + s_executor.submit( + worker_script, + self.state, + self.worker_wait_timeout_seconds, + ) + ) + + for idx, j in enumerate(jobs): + mk_symlinks(self.log_dir, f"worker_{self.num_worker_groups}_{idx}", j.paths) + print("workers slurm job ids:", [job.job_id for job in jobs]) + for j in jobs: + self.state.add_job(j) + self.num_worker_groups += 1 + return [job.job_id for job in jobs] + + def shutdown(self): + """ + Cancel all slurms jobs and get rid of rdv directory. + """ + self.is_shutdown = True + scancel(self.state.list_job_ids()) + kill_proc_tree( + os.getpid(), including_parent=False + ) # kill local job started by submitit as subprocess TODO that's not going to work when this is not the main process (e.g. recovering on cli) + self.state.clean() + print(f"cluster {self.state.cluster_id} shutdown") + + def __enter__(self): + # only use as a context if you have something blocking waiting on the driver + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.shutdown() diff --git a/src/fairchem/core/launchers/ray_on_slurm_launch.py b/src/fairchem/core/launchers/ray_on_slurm_launch.py new file mode 100644 index 0000000000..f0e39a3b65 --- /dev/null +++ b/src/fairchem/core/launchers/ray_on_slurm_launch.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING + +import hydra +import ray +import torch.distributed as dist +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.distributed.elastic.utils.distributed import get_free_port + +from fairchem.core.common import gp_utils +from fairchem.core.common.distutils import ( + assign_device_for_local_rank, + setup_env_local_multi_gpu, +) +from fairchem.core.common.utils import setup_env_vars +from fairchem.core.components.runner import Runner +from fairchem.core.launchers.api import DeviceType +from fairchem.core.launchers.cluster.ray_cluster import RayCluster + +if TYPE_CHECKING: + from omegaconf import DictConfig + + from fairchem.core.launchers.api import SchedulerConfig, SlurmConfig + + +@ray.remote +class SPMDWorker: + def __init__( + self, + job_config: DictConfig, + runner_config: DictConfig, + worker_id: int, + world_size: int, + device: str, + gp_size: int | None = None, + master_addr: str | None = None, + master_port: int | None = None, + ): + self.runner_config = runner_config + # master address and port is not passed in, initialize it here + self.master_address = ( + ray.util.get_node_ip_address() if master_addr is None else master_addr + ) + self.master_port = get_free_port() if master_port is None else master_port + self.worker_id = worker_id + self.device = device + self.gp_size = gp_size + self.world_size = world_size + self.job_config = job_config + setup_env_vars() + self.distributed_setup = False + + def _distributed_setup( + self, + worker_id: int, + world_size: int, + master_address: str, + master_port: int, + device: str, + gp_size: int | None, + ): + setup_env_local_multi_gpu(worker_id, master_port, 
master_address) + assign_device_for_local_rank(device == "cpu", 0) + backend = "gloo" if device == "cpu" else "nccl" + dist.init_process_group( + backend=backend, + rank=worker_id, + world_size=world_size, + ) + if gp_size is not None: + gp_utils.setup_graph_parallel_groups(gp_size, backend) + + def get_master_address_and_port(self): + return (self.master_address, self.master_port) + + def run(self): + if not self.distributed_setup: + # initialize distributed environment + self._distributed_setup( + worker_id=self.worker_id, + world_size=self.world_size, + master_address=self.master_address, + master_port=self.master_port, + device=self.device, + gp_size=self.gp_size, + ) + self.runner: Runner = hydra.utils.instantiate(self.runner_config) + self.runner.job_config = self.job_config + self.distributed_setup = True + self.runner.run() + + +class SPMDController(Runner): + # this is equivalent to the fairchem SlurmSPMDProgram routine that runs the runner on every worker + def __init__(self, job_config: DictConfig, runner_config: DictConfig): + self.job_config = job_config + self.runner_config = runner_config + self.device = job_config.device_type.value + self.world_size = ( + job_config.scheduler.num_nodes * job_config.scheduler.ranks_per_node + ) + self.gp_group_size = job_config.graph_parallel_group_size + self.ranks_per_node = job_config.scheduler.ranks_per_node + self.num_nodes = job_config.scheduler.num_nodes + num_gpus_per_group = ( + self.ranks_per_node if job_config.device_type == DeviceType.CUDA else 0 + ) + bundle_gpus = { + "GPU": num_gpus_per_group, + "CPU": self.ranks_per_node, + } + placement_groups = [] + # first create one placement group for each node + for _ in range(self.num_nodes): + pg = ray.util.placement_group([bundle_gpus], strategy="STRICT_PACK") + placement_groups.append(pg) + ray.get(pg.ready()) # Wait for each placement group to be scheduled + + logging.info(f"{len(placement_groups)} placement groups are ready") + rank0_worker = SPMDWorker.options( + num_gpus=1 if num_gpus_per_group > 0 else 0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=placement_groups[0], + placement_group_bundle_index=0, # Use the first (and only) bundle in the PG + placement_group_capture_child_tasks=True, # Ensure child tasks also run in this PG + ), + ).remote( + self.job_config, + self.runner_config, + 0, + self.world_size, + self.device, + self.gp_group_size, + None, + None, + ) + master_addr, master_port = ray.get( + rank0_worker.get_master_address_and_port.remote() + ) + logging.info(f"Started rank0 on {master_addr}:{master_port}") + self.workers = [rank0_worker] + + # next place all ranks in order and pack them on placement groups + # ie: rank0-7 -> placement group 0, 8->15 -> placement group 1 etc. 
+ for pg_idx, pg in enumerate(placement_groups): + print(f"Launching workers for placement group {pg_idx} (Node {pg_idx})") + + for gpu_rank_on_node in range(self.ranks_per_node): + if pg_idx == 0 and gpu_rank_on_node == 0: + continue + # Each actor requests 1 GPU and uses the specific placement group + actor = SPMDWorker.options( + num_gpus=1 if num_gpus_per_group > 0 else 0, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, # Use the first (and only) bundle in the PG + placement_group_capture_child_tasks=True, # Ensure child tasks also run in this PG + ), + ).remote( + self.job_config, + self.runner_config, + pg_idx * self.ranks_per_node + gpu_rank_on_node, + self.world_size, + self.device, + self.gp_group_size, + master_addr, + master_port, + ) + self.workers.append(actor) + + def run(self): + logging.info("Running SPMDWrapper payload ...") + futures = [w.run.remote() for w in self.workers] + ray.get(futures) + + def save_state(self, checkpoint_location: str, is_preemption: bool = False) -> bool: + pass + + def load_state(self, checkpoint_location: str | None) -> None: + pass + + +def ray_entrypoint(runner_config: DictConfig): + runner = hydra.utils.instantiate(runner_config, _recursive_=False) + runner.run() + + +def ray_on_slurm_launch(config: DictConfig, log_dir: str): + scheduler_config: SchedulerConfig = config.job.scheduler + slurm_config: SlurmConfig = scheduler_config.slurm + cluster = RayCluster(log_dir=Path(log_dir)) + cluster_reqs = { + "slurm_account": slurm_config.account, + "slurm_qos": slurm_config.qos, + "timeout_min": slurm_config.timeout_hr * 60, + "mem_gb": slurm_config.mem_gb, + } + worker_nodes = scheduler_config.num_nodes - 1 + + all_job_ids = [] + head_job_id = cluster.start_head( + requirements=cluster_reqs + | { + "nodes": 1, + "cpus_per_task": slurm_config.cpus_per_task + * scheduler_config.ranks_per_node, + "gpus_per_task": scheduler_config.ranks_per_node, + "tasks_per_node": 1, + }, + executor="slurm", + payload=ray_entrypoint, + runner_config=config.runner, + ) + all_job_ids.append(head_job_id) + logging.info("Ray head started") + + if worker_nodes > 0: + worker_ids = cluster.start_workers( + 1, + requirements=cluster_reqs + | { + "nodes": worker_nodes, + "gpus_per_task": scheduler_config.ranks_per_node, + "cpus_per_task": slurm_config.cpus_per_task + * scheduler_config.ranks_per_node, + "tasks_per_node": 1, + }, + ) + all_job_ids.extend(worker_ids) + logging.info("Ray workers started") + + logging.info(f"To cancel: scancel {' '.join(all_job_ids)}") diff --git a/src/fairchem/core/launchers/slurm_launch.py b/src/fairchem/core/launchers/slurm_launch.py new file mode 100644 index 0000000000..a3f195df43 --- /dev/null +++ b/src/fairchem/core/launchers/slurm_launch.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import logging +import os +import random +from typing import TYPE_CHECKING + +import hydra +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf +from submitit import AutoExecutor +from submitit.core.utils import JobPaths, cloudpickle_dump +from submitit.helpers import Checkpointable, DelayedSubmission +from submitit.slurm.slurm import SlurmJobEnvironment + +from fairchem.core.common import distutils +from fairchem.core.common.gp_utils import setup_graph_parallel_groups +from fairchem.core.common.logger import WandBSingletonLogger +from fairchem.core.common.utils import ( + setup_env_vars, + setup_logging, +) +from fairchem.core.launchers.api import ( + 
DeviceType, + JobConfig, + RunType, + SchedulerType, + SlurmEnv, +) + +if TYPE_CHECKING: + from fairchem.core.components.reducer import Reducer + from fairchem.core.components.runner import Runner + + +def _get_slurm_env() -> SlurmEnv: + slurm_job_env = SlurmJobEnvironment() + try: + slurm_env = SlurmEnv( + job_id=slurm_job_env.job_id, + raw_job_id=slurm_job_env.raw_job_id, + array_job_id=slurm_job_env.array_job_id, + array_task_id=slurm_job_env.array_task_id, + restart_count=os.environ.get("SLURM_RESTART_COUNT"), + ) + except KeyError: + # slurm environment variables are undefined, running locally + slurm_env = SlurmEnv() + + return slurm_env + + +def map_job_config_to_dist_config(job_cfg: JobConfig) -> dict: + scheduler_config = job_cfg.scheduler + return { + "world_size": scheduler_config.num_nodes * scheduler_config.ranks_per_node, + "distributed_backend": ( + "gloo" if job_cfg.device_type == DeviceType.CPU else "nccl" + ), + "submit": scheduler_config.mode == SchedulerType.SLURM, + "cpu": job_cfg.device_type == DeviceType.CPU, + "init_method": scheduler_config.distributed_init_method, + # for distributed shared file initialization + "shared_file_dir": os.path.join(job_cfg.run_dir, job_cfg.timestamp_id), + "array_job_num": job_cfg.metadata.array_job_num, + } + + +def remove_runner_state_from_submission(log_folder: str, job_id: str) -> None: + # (HACK) Decouple the job from the runner state by manually modifying it + # this ensures the saved runner state is not re-submitted in the event of a node failure + # ie: if the job was started at state t=T, a requeue during node failure would resubmit the job + # starting at state t=T again without calling the checkpoint callback, losing all progress in between. + job_path = JobPaths(folder=log_folder, job_id=job_id) + if os.path.isfile(job_path.submitted_pickle): + submission_obj = DelayedSubmission.load(job_path.submitted_pickle) + submission_obj.args[0].job.runner_state_path = None + cloudpickle_dump(submission_obj, job_path.submitted_pickle) + + +def runner_wrapper(config: DictConfig, run_type: RunType = RunType.RUN): + # This is needed when using elastic_launch for local runs since it looks for + # the __name__ attribute of the function, Submitit.__call__ does not have one + SlurmSPMDProgram()(config, run_type) + + +def _set_seeds(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _set_deterministic_mode() -> None: + # this is required for full cuda deterministic mode + logging.info("Setting deterministic mode!") + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + +class SlurmSPMDProgram(Checkpointable): + """ + Entrypoint for a SPMD program launched via submitit on slurm. 
+ This assumes all ranks run the identical copy of this code + """ + + def __init__(self) -> None: + self.config = None + self.runner = None + self.reducer = None + + def __call__( + self, dict_config: DictConfig, run_type: RunType = RunType.RUN + ) -> None: + self.config = dict_config + self.run_type = run_type + # modify the config metadata to add slurm info if they exist + self.config.job.metadata.slurm_env = _get_slurm_env() + + setup_env_vars() + setup_logging() + + dist_config = map_job_config_to_dist_config(self.config.job) + logging.info("Setting up distributed backend...") + distutils.setup(dist_config) + distutils.synchronize() + if ( + distutils.is_master() + and self.config.job.scheduler.mode == SchedulerType.SLURM + ): + # this pickle file is shared across all processes so can only modify this on the main rank + remove_runner_state_from_submission( + dict_config.job.metadata.log_dir, + self.config.job.metadata.slurm_env.job_id, + ) + + if self.config.job.graph_parallel_group_size is not None: + logging.info("Setting up graph parallel...") + setup_graph_parallel_groups( + self.config.job.graph_parallel_group_size, + dist_config["distributed_backend"], + ) + + self._init_logger() + + _set_seeds(self.config.job.seed) + if self.config.job.deterministic: + _set_deterministic_mode() + + if run_type == RunType.RUN: + logging.info("Calling runner.run() ...") + self.runner: Runner = hydra.utils.instantiate(self.config.runner) + self.runner.job_config = self.config.job + # must call resume state AFTER the runner has been initialized + self.runner.load_state(self.config.job.runner_state_path) + self.runner.run() + elif run_type == RunType.REDUCE: + logging.info("Calling reducer.reduce() ...") + self.reducer: Reducer = hydra.utils.instantiate(self.config.reducer) + self.reducer.job_config = self.config.job + self.reducer.runner_config = self.config.runner + # must call resume state AFTER the runner has been initialized + self.reducer.load_state(self.config.job.runner_state_path) + self.reducer.reduce() + else: + raise ValueError(f"run type {run_type} is not recognized!") + + distutils.cleanup() + + def _init_logger(self) -> None: + if ( + self.config.job.logger + and distutils.is_master() + and not self.config.job.debug + and self.config.job.metadata.array_job_num == 0 + ): + # get a partial function from the config and instantiate wandb with it + # currently code assumes that we only use the WandBSingletonLogger + logger_initializer = hydra.utils.instantiate(self.config.job.logger) + simple_config = OmegaConf.to_container( + self.config, resolve=True, throw_on_missing=True + ) + logger_initializer( + config=simple_config, + run_id=self.config.job.timestamp_id, + run_name=self.config.job.run_name, + log_dir=self.config.job.metadata.log_dir, + ) + + def checkpoint(self, *args, **kwargs) -> DelayedSubmission: + logging.error("Submitit checkpointing callback is triggered") + save_path = self.config.job.metadata.preemption_checkpoint_dir + cfg_copy = self.config.copy() + # only assign if the save was successful + cfg_copy.job.runner_state_path = None + + if ( + self.run_type == RunType.RUN + and self.runner.save_state(save_path, is_preemption=True) + ) or ( + self.run_type == RunType.REDUCE + and self.reducer.save_state(save_path, is_preemption=True) + ): + cfg_copy.job.runner_state_path = save_path + + if WandBSingletonLogger.initialized(): + WandBSingletonLogger.get_instance().mark_preempting() + logging.info( + f"Submitit checkpointing callback is completed, resuming with use the following 
state: {save_path}" + ) + return DelayedSubmission(SlurmSPMDProgram(), cfg_copy) + + +def slurm_launch(cfg: DictConfig, log_dir: str) -> None: + scheduler_cfg = cfg.job.scheduler + executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) + executor.update_parameters( + name=cfg.job.run_name, + mem_gb=scheduler_cfg.slurm.mem_gb, + timeout_min=scheduler_cfg.slurm.timeout_hr * 60, + slurm_partition=scheduler_cfg.slurm.partition, + gpus_per_node=scheduler_cfg.ranks_per_node, + cpus_per_task=scheduler_cfg.slurm.cpus_per_task, + tasks_per_node=scheduler_cfg.ranks_per_node, + nodes=scheduler_cfg.num_nodes, + slurm_qos=scheduler_cfg.slurm.qos, + slurm_account=scheduler_cfg.slurm.account, + slurm_additional_parameters=scheduler_cfg.slurm.additional_parameters, + ) + if scheduler_cfg.num_array_jobs == 1: + job = executor.submit(SlurmSPMDProgram(), cfg) + logging.info( + f"Submitted job id: {cfg.job.timestamp_id}, slurm id: {job.job_id}, logs: {cfg.job.metadata.log_dir}" + ) + jobs = [job] + elif scheduler_cfg.num_array_jobs > 1: + executor.update_parameters( + slurm_array_parallelism=scheduler_cfg.num_array_jobs, + ) + + jobs = [] + with executor.batch(): + for job_number in range(scheduler_cfg.num_array_jobs): + _cfg = cfg.copy() + _cfg.job.metadata.array_job_num = job_number + job = executor.submit(SlurmSPMDProgram(), _cfg) + jobs.append(job) + logging.info(f"Submitted {len(jobs)} jobs: {jobs[0].job_id.split('_')[0]}") + + if "reducer" in cfg: + job_id = jobs[0].job_id.split("_")[0] + executor.update_parameters( + name=f"{cfg.job.run_name}_reduce", + # set a single node, or do we want the same config as the Runner or a separate JobConfig + nodes=1, + slurm_dependency=f"afterok:{job_id}", + slurm_additional_parameters={ + "kill-on-invalid-dep": "yes" + }, # kill the reducer if run fails + ) + executor.submit(SlurmSPMDProgram(), cfg, RunType.REDUCE) + + +def local_launch(cfg: DictConfig, log_dir: str): + """ + Launch locally with torch elastic (for >1 workers) or just single process + """ + scheduler_cfg = cfg.job.scheduler + if scheduler_cfg.ranks_per_node > 1: + from torch.distributed.launcher.api import LaunchConfig, elastic_launch + + launch_config = LaunchConfig( + min_nodes=1, + max_nodes=1, + nproc_per_node=scheduler_cfg.ranks_per_node, + rdzv_backend="c10d", + max_restarts=0, + ) + elastic_launch(launch_config, runner_wrapper)(cfg) + if "reducer" in cfg: + elastic_launch(launch_config, runner_wrapper)(cfg, RunType.REDUCE) + else: + logging.info("Running in local mode without elastic launch") + distutils.setup_env_local() + runner_wrapper(cfg) + if "reducer" in cfg: + runner_wrapper(cfg, RunType.REDUCE) diff --git a/tests/core/conftest.py b/tests/core/conftest.py index 91f39dc9ae..6c5f294e87 100644 --- a/tests/core/conftest.py +++ b/tests/core/conftest.py @@ -6,21 +6,10 @@ """ from __future__ import annotations -from fairchem.core.units.mlip_unit.mlip_unit import ( - UNIT_INFERENCE_CHECKPOINT, - UNIT_RESUME_CONFIG, -) - -from tests.core.units.mlip_unit.create_fake_dataset import ( - create_fake_uma_dataset, -) import os import tempfile - -from tests.core.testing_utils import launch_main from itertools import product -import logging from random import choice from typing import TYPE_CHECKING @@ -33,6 +22,14 @@ from syrupy.extensions.amber import AmberSnapshotExtension from fairchem.core.datasets import AseDBDataset +from fairchem.core.units.mlip_unit.mlip_unit import ( + UNIT_INFERENCE_CHECKPOINT, + UNIT_RESUME_CONFIG, +) +from tests.core.testing_utils import launch_main +from 
tests.core.units.mlip_unit.create_fake_dataset import ( + create_fake_uma_dataset, +) if TYPE_CHECKING: from syrupy.types import SerializableData @@ -222,7 +219,7 @@ def dummy_binary_dataset(dummy_binary_dataset_path): def run_around_tests(): # If debugging GPU memory issues, uncomment this print statement # to get full GPU memory allocations before each test runs - #print(torch.cuda.memory_summary()) + # print(torch.cuda.memory_summary()) yield torch.cuda.empty_cache() @@ -343,7 +340,6 @@ def conserving_mole_checkpoint(fake_uma_dataset): return inference_checkpoint_pt, checkpoint_state_yaml - @pytest.fixture(scope="session") def fake_uma_dataset(): with tempfile.TemporaryDirectory() as tempdirname: diff --git a/tests/core/test_cli.py b/tests/core/test_cli.py index d23e94be22..80d267a797 100644 --- a/tests/core/test_cli.py +++ b/tests/core/test_cli.py @@ -90,3 +90,16 @@ def get_cfg_from_yaml(): assert cfg.job.run_name is not None assert cfg.job.seed is not None assert cfg.keys() == ALLOWED_TOP_LEVEL_KEYS + + +@pytest.mark.parametrize("num_ranks", [1, 4]) +def test_cli_ray(num_ranks): + distutils.cleanup() + hydra.core.global_hydra.GlobalHydra.instance().clear() + sys_args = [ + "--config", + "tests/core/test_ray_runner.yml", + f"job.scheduler.ranks_per_node={num_ranks}", + ] + sys.argv[1:] = sys_args + main() diff --git a/tests/core/test_ray_runner.yml b/tests/core/test_ray_runner.yml new file mode 100644 index 0000000000..2cabf70bef --- /dev/null +++ b/tests/core/test_ray_runner.yml @@ -0,0 +1,17 @@ +job: + device_type: CPU + scheduler: + use_ray: true + mode: LOCAL + ranks_per_node: 1 + num_nodes: 1 + run_name: test_ray_runner + +runner: + _target_: fairchem.core.launchers.ray_on_slurm_launch.SPMDController + job_config: ${job} + runner_config: + _target_: fairchem.core.components.runner.MockRunner + x: 10 + y: 23 + z: 1
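
The quickest way to exercise the new `use_ray` code path is the local smoke test added in `tests/core/test_cli.py`; the sketch below mirrors that test and only assumes `fairchem-core` is installed with the `extras` (Ray) dependencies and is run from the repository root. The override value of 2 ranks is illustrative.

```python
# Local Ray smoke test, mirroring tests/core/test_cli.py::test_cli_ray.
# tests/core/test_ray_runner.yml sets use_ray: true and mode: LOCAL, so main()
# takes the local-Ray branch and instantiates the SPMDController runner directly.
import sys

from fairchem.core._cli import main

sys.argv[1:] = [
    "--config",
    "tests/core/test_ray_runner.yml",
    "job.scheduler.ranks_per_node=2",  # illustrative; the test parametrizes 1 and 4
]
main()
```

For a Slurm-backed run, the same CLI would instead be pointed at `configs/uma/training_release/uma_ray_demo.yaml`, whose `cluster` defaults supply the account, qos, and ranks-per-node settings and whose scheduler `mode` is taken from the selected cluster config, routing the job through `ray_on_slurm_launch`.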
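
For completeness, a condensed, hand-driven sketch of the `RayCluster` API that `ray_on_slurm_launch` builds on; the account, qos, and resource numbers are placeholders for a site-specific Slurm setup, not recommended values.

```python
# Hand-driven RayCluster sketch (placeholder Slurm requirements).
from pathlib import Path

from fairchem.core.launchers.cluster.ray_cluster import RayCluster


def driver() -> None:
    # Runs on the head node once `ray start --head` is up; RAY_ADDRESS is already
    # exported by _ray_head_script, so connecting with address="auto" works.
    import ray

    ray.init(address="auto")
    print(ray.cluster_resources())


reqs = {
    "slurm_account": "my_account",  # placeholder
    "slurm_qos": "my_qos",          # placeholder
    "timeout_min": 60,
    "mem_gb": 80,
    "cpus_per_task": 8,
    "gpus_per_task": 1,
    "tasks_per_node": 1,
    "nodes": 1,
}

cluster = RayCluster(log_dir=Path("raycluster_logs"))
cluster.start_head(requirements=reqs, payload=driver)
cluster.start_workers(1, requirements=reqs)
# cluster.shutdown() cancels all Slurm jobs and removes the rendezvous directory.
```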