186 changes: 186 additions & 0 deletions configs/uma/training_release/uma_ray_demo.yaml
@@ -0,0 +1,186 @@
# Example YAML for running the UMA training job with the Ray-on-SLURM launcher.
# Logging and checkpointing are currently not enabled.
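# The config is composed with Hydra (defaults list plus ${...} interpolation). It is presumably
# launched through the `fairchem` console script defined in fairchem-core's pyproject.toml, e.g.
# something like `fairchem -c configs/uma/training_release/uma_ray_demo.yaml`; the exact CLI
# flag is an assumption, not taken from this diff.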

defaults:
  - cluster: h100
  - backbone: K4L2
  - dataset: uma
  - element_refs: uma_v1_hof_lin_refs
  - tasks: uma_direct
  - _self_
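# Each entry above selects a config group (cluster hardware, backbone size, datasets, element
# references, task definitions); listing `_self_` last means values defined in this file
# override those group defaults.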

job:
  device_type: ${cluster.device}
  scheduler:
    use_ray: true
    mode: ${cluster.mode}
    ranks_per_node: ${cluster.ranks_per_node}
    num_nodes: 1
    slurm:
      account: ${cluster.account}
      qos: ${cluster.qos}
      mem_gb: ${cluster.mem_gb}
      cpus_per_task: ${cluster.cpus_per_task}
  debug: ${cluster.debug}
  run_dir: ${cluster.run_dir}
  run_name: uma_sm_direct
  # logger:
  #   _target_: fairchem.core.common.logger.WandBSingletonLogger.init_wandb
  #   _partial_: true
  #   entity: fairchem
  #   project: uma
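# With use_ray: true the job is submitted through the Ray-on-SLURM launcher; the queue and
# resource settings (account, qos, memory, CPUs, ranks per node) are interpolated from the
# selected cluster config group. Uncomment the logger block to enable Weights & Biases logging.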

moe_layer_type: pytorch
num_moe_experts: 32
max_neighbors: 30
cutoff_radius: 6
epochs: null
steps: 1680000 # 140B atoms, 128 ranks, max atoms 700 (mean atoms 650)
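# sanity check: 1,680,000 steps × 128 ranks × ~650 mean atoms per batch ≈ 1.4e11 ≈ 140B atoms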
max_atoms: 700
bf16: True
cpu_graph: True
otf_graph: False
normalizer_rmsd: 1.423
direct_forces_coef: 30
omc_energy_coef: 10
omol_energy_coef: 30
odac_energy_coef: 10
oc20_energy_coef: 10
omat_energy_coef: 10

regress_stress: False
direct_forces: True

oc20_forces_key: forces
omat_forces_key: forces
omol_forces_key: forces
odac_forces_key: forces
omc_forces_key: forces

dataset_list: ["oc20", "omol", "omat", "odac", "omc"]

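# Keys that exist in only a subset of the datasets are dropped at collation time so that
# samples from different datasets can be batched together (see mt_collater_adapter below).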
exclude_keys: [
  "id", # only oc20,oc22 have this
  "fid", # only oc20,oc22 have this
  "absolute_idx", # only ani has this
  "target_pos", # only ani has this
  "ref_energy", # only ani/geom have this
  "pbc", # only ani/transition1x have this
  "nads", # oc22
  "oc22", # oc22
  "formation_energy", # spice
  "total_charge", # spice
]

train_dataset:
  _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
  dataset_configs:
    omc: ${dataset.omc_train}
    omol: ${dataset.omol_train}
    odac: ${dataset.odac_train}
    omat: ${dataset.omat_train}
    oc20: ${dataset.oc20_train}
  combined_dataset_config:
    sampling:
      type: explicit
      ratios:
        omol.train: 4.0
        oc20.train: 1.0
        omc.train: 2.0
        odac.train: 1.0
        omat.train: 2.0

val_dataset:
  _target_: fairchem.core.datasets.mt_concat_dataset.create_concat_dataset
  dataset_configs:
    omc: ${dataset.omc_val}
    omol: ${dataset.omol_val}
    odac: ${dataset.odac_val}
    omat: ${dataset.omat_val}
    oc20: ${dataset.oc20_val}
  combined_dataset_config: { sampling: { type: temperature, temperature: 1.0 } }
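# Training samples with explicit ratios (omol 4x, omc/omat 2x relative to oc20/odac), while
# validation uses temperature sampling at temperature 1.0, which presumably draws from each
# dataset roughly in proportion to its size.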

train_dataloader:
  _target_: fairchem.core.components.common.dataloader_builder.get_dataloader
  dataset: ${train_dataset}
  batch_sampler_fn:
    _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler
    _partial_: True
    max_atoms: ${max_atoms}
    shuffle: True
    seed: 0
  num_workers: ${cluster.dataloader_workers}
  collate_fn:
    _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
    tasks: ${tasks}
    exclude_keys: ${exclude_keys}

eval_dataloader:
  _target_: fairchem.core.components.common.dataloader_builder.get_dataloader
  dataset: ${val_dataset}
  batch_sampler_fn:
    _target_: fairchem.core.datasets.samplers.max_atom_distributed_sampler.MaxAtomDistributedBatchSampler
    _partial_: True
    max_atoms: ${max_atoms}
    shuffle: False
    seed: 0
  num_workers: ${cluster.dataloader_workers}
  collate_fn:
    _target_: fairchem.core.units.mlip_unit.mlip_unit.mt_collater_adapter
    tasks: ${tasks}
    exclude_keys: ${exclude_keys}
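# Both dataloaders use the max-atom distributed batch sampler: rather than a fixed batch size,
# structures are packed into each per-rank batch until the max_atoms budget (700) is reached.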

heads:
  oc20_energy:
    module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
  omat_energy:
    module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
  omc_energy:
    module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
  omol_energy:
    module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
  odac_energy:
    module: fairchem.core.models.uma.escn_md.MLP_Energy_Head
  forces:
    module: fairchem.core.models.uma.escn_md.Linear_Force_Head
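# One MLP energy head per dataset plus a single linear force head shared across all datasets;
# every head reads features from the shared backbone selected in the defaults list (K4L2).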

runner:
  _target_: fairchem.core.launchers.ray_on_slurm_launch.SPMDController
  job_config: ${job}
  runner_config:
    _target_: fairchem.core.components.train.train_runner.TrainEvalRunner
    train_dataloader: ${train_dataloader}
    eval_dataloader: ${eval_dataloader}
    train_eval_unit:
      _target_: fairchem.core.units.mlip_unit.mlip_unit.MLIPTrainEvalUnit
      job_config: ${job}
      tasks: ${tasks}
      model:
        _target_: fairchem.core.models.base.HydraModel
        backbone: ${backbone}
        heads: ${heads}
      optimizer_fn:
        _target_: torch.optim.AdamW
        _partial_: true
        lr: 8e-4
        weight_decay: 1e-3
      cosine_lr_scheduler_fn:
        _target_: fairchem.core.units.mlip_unit.mlip_unit._get_consine_lr_scheduler
        _partial_: true
        warmup_factor: 0.2
        warmup_epochs: 0.01
        lr_min_factor: 0.01
        epochs: ${epochs}
        steps: ${steps}
      print_every: 10
      clip_grad_norm: 100
      bf16: ${bf16}
    max_epochs: ${epochs}
    max_steps: ${steps}
    evaluate_every_n_steps: 10000
    callbacks:
      - _target_: fairchem.core.common.profiler_utils.ProfilerCallback
        job_config: ${job}
      # - _target_: fairchem.core.components.train.train_runner.TrainCheckpointCallback
      #   checkpoint_every_n_steps: 5000
      #   max_saved_checkpoints: 5
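# The SPMDController presumably replicates the TrainEvalRunner across ranks (single program,
# multiple data); the runner drives the MLIPTrainEvalUnit (model, optimizer, LR schedule) and
# runs evaluation every 10,000 steps. Re-enable the commented TrainCheckpointCallback above to
# save checkpoints during training.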
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets"]
adsorbml = ["dscribe","x3dase","scikit-image"]
extras = ["ray", "pymatgen", "quacc[phonons]>=0.15.3","pandas"]
extras = ["ray[default]", "pymatgen", "quacc[phonons]>=0.15.3", "pandas"]

[project.scripts]
fairchem = "fairchem.core._cli:main"