Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 74 additions & 38 deletions src/fairchem/lammps/lammps_fc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import torch
from ase.data import atomic_masses, chemical_symbols
from ase.geometry import wrap_positions

from fairchem.core.datasets.atomic_data import AtomicData
from lammps import lammps
Expand All @@ -32,16 +33,23 @@ def check_input_script(input_script: str):

def check_atom_id_match_masses(types_arr, masses):
for atom_id in types_arr:
assert np.allclose(
masses[atom_id], atomic_masses[atom_id], atol=1e-1
), f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}."
assert np.allclose(masses[atom_id], atomic_masses[atom_id], atol=1e-1), (
f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}."
)


def atomic_data_from_lammps_data(
x, atomic_numbers, nlocal, cell, periodicity, task_name
x: np.ndarray | torch.Tensor,
atomic_numbers,
nlocal,
cell,
periodicity,
task_name,
charge: int = 0,
spin: int = 0,
):
# TODO: do we need to take of care of wrapping atoms that are outside the cell?
pos = torch.tensor(x, dtype=torch.float32)
pos = torch.as_tensor(x, dtype=torch.float32)
pbc = torch.tensor(periodicity, dtype=torch.bool).unsqueeze(0)
edge_index = torch.empty((2, 0), dtype=torch.long)
cell_offsets = torch.empty((0, 3), dtype=torch.float32)
Expand All @@ -58,8 +66,8 @@ def atomic_data_from_lammps_data(
edge_index=edge_index,
cell_offsets=cell_offsets,
nedges=nedges,
charge=torch.LongTensor([0]),
spin=torch.LongTensor([0]),
charge=torch.LongTensor([charge]),
spin=torch.LongTensor([spin]),
fixed=fixed,
tags=tags,
batch=batch,
Expand Down Expand Up @@ -116,7 +124,7 @@ def lookup_atomic_number_by_mass(mass_arr: np.ndarray | float) -> np.ndarray | i
return atomic_numbers


def separate_run_commands(input_script: str) -> str:
def separate_run_commands(input_script: str) -> tuple[list[str], list[str]]:
lines = input_script.splitlines()
run_cmds = []
script = []
Expand Down Expand Up @@ -145,52 +153,74 @@ def cell_from_lammps_box(boxlo, boxhi, xy, yz, xz):
return unit_cell_matrix.unsqueeze(0)


def fix_external_call_back(lmp, ntimestep, nlocal, tag, x, f):
# force copy here, otherwise we can accident modify the original array in lammps
# TODO: only need to get atomic numbers once and cache it?
# is there a way to check atom types are mapped correctly?
atom_type_np = lmp.numpy.extract_atom("type")
masses = lmp.numpy.extract_atom("mass")
atomic_mass_arr = masses[atom_type_np]
atomic_numbers = lookup_atomic_number_by_mass(atomic_mass_arr)
boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box()
cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz)
atomic_data = atomic_data_from_lammps_data(
x, atomic_numbers, nlocal, cell, periodicity, lmp._task_name
)
results = lmp._predictor.predict(atomic_data)
assert "forces" in results, "forces must be in results"
f[:] = results["forces"].cpu().numpy()[:]
lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item())

# during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
if box_change:
# stress is defined as virial/volume in lammps
assert "stress" in results, "stress must be in results to compute virial"
volume = torch.det(cell).abs().item()
v = (results["stress"].cpu() * volume)[0]
# virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
virial_arr = [v[0], v[4], v[8], v[1], v[2], v[5]]
lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr)
class FixExternalCallback:
def __init__(self, charge: int = 0, spin: int = 0):
self.charge = charge
self.spin = spin

def __call__(self, lmp, ntimestep, nlocal, tag, x, f):
# force copy here, otherwise we can accident modify the original array in lammps
# TODO: only need to get atomic numbers once and cache it?
# is there a way to check atom types are mapped correctly?
atom_type_np = lmp.numpy.extract_atom("type")
masses = lmp.numpy.extract_atom("mass")
atomic_mass_arr = masses[atom_type_np]
atomic_numbers = lookup_atomic_number_by_mass(atomic_mass_arr)
boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box()
cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz)

x_wrapped = wrap_positions(
x, cell=cell.squeeze().numpy(), pbc=periodicity, eps=0
)

atomic_data = atomic_data_from_lammps_data(
x_wrapped,
atomic_numbers,
nlocal,
cell,
periodicity,
lmp._task_name,
charge=self.charge,
spin=self.spin,
)
results = lmp._predictor.predict(atomic_data)
assert "forces" in results, "forces must be in results"
f[:] = results["forces"].cpu().numpy()[:]
lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item())

# during NPT for example, box_change should be set to 1 by lammps to allow the cell to change
if box_change:
# stress is defined as -virial/volume in lammps
assert "stress" in results, "stress must be in results to compute virial"
volume = torch.det(cell).abs().item()
v = (-results["stress"].detach().cpu() * volume)[0].tolist()
# virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd
virial_arr = [v[0], v[4], v[8], v[1], v[2], v[5]]
lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr)


def run_lammps_with_fairchem(
predictor: MLIPPredictUnitProtocol, lammps_input_path: str, task_name: str
predictor: MLIPPredictUnitProtocol,
lammps_input_path: str,
task_name: str,
charge: int = 0,
spin: int = 0,
):
machine = None
if "LAMMPS_MACHINE_NAME" in os.environ:
machine = os.environ["LAMMPS_MACHINE_NAME"]
lmp = lammps(name=machine, cmdargs=["-nocite", "-log", "none", "-echo", "screen"])
lmp._predictor = predictor
lmp._task_name = task_name
run_cmds = []
# run_cmds = []
with open(lammps_input_path) as f:
input_script = f.read()
check_input_script(input_script)
script, run_cmds = separate_run_commands(input_script)
logging.info(f"Running input script: {input_script}")
lmp.commands_list(script)
lmp.command(FIX_EXTERNAL_CMD)
fix_external_call_back = FixExternalCallback(charge=charge, spin=spin)
lmp.set_fix_external_callback(FIX_EXT_ID, fix_external_call_back, lmp)
lmp.commands_list(run_cmds)
return lmp
Expand All @@ -203,7 +233,13 @@ def run_lammps_with_fairchem(
)
def main(cfg: DictConfig):
predict_unit = hydra.utils.instantiate(cfg.predict_unit)
lmp = run_lammps_with_fairchem(predict_unit, cfg.lmp_in, cfg.task_name)
lmp = run_lammps_with_fairchem(
predict_unit,
cfg.lmp_in,
cfg.task_name,
cfg.charge,
cfg.spin,
)
# this is required to cleanup the predictor
del lmp._predictor

Expand Down
2 changes: 2 additions & 0 deletions src/fairchem/lammps/lammps_fc_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,5 @@ parallel_predict_unit:
predict_unit: ${local_predict_unit}
lmp_in: "lammps_in_example.file"
task_name: "omol"
charge: 0
spin: 0
19 changes: 19 additions & 0 deletions tests/lammps/lammps_npt.file
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
units metal # Use metal units (Angstroms, eV, ps)
atom_style atomic # Atoms have a single type and position
lattice fcc 3.567
boundary p p p

region simbox block 0 2 0 2 0 2
create_box 1 simbox
create_atoms 1 region simbox
mass 1 12.011

velocity all create 300.0 12345 dist gaussian # Set initial velocities at 300 K

timestep 0.001 # 1 fs
# Use NPT (isotropic) thermostat+barostat: target temp 300 K, target pressure 0 bar
# fix npt syntax: fix ID group-ID npt temp Tstart Tstop Tdamp iso Pstart Pstop Pdamp
fix 1 all npt temp 300.0 300.0 0.1 iso 0.0 0.0 1.0
thermo_style custom step temp pe ke etotal press vol
thermo 1
run 100
67 changes: 66 additions & 1 deletion tests/lammps/test_ase_vs_lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from ase import units
from ase.build import bulk
from ase.md.langevin import Langevin
from ase.md.nose_hoover_chain import IsotropicMTKNPT
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from fairchem.lammps.lammps_fc import run_lammps_with_fairchem

from fairchem.core import FAIRChemCalculator
from fairchem.core.calculate import pretrained_mlip
from fairchem.lammps.lammps_fc import run_lammps_with_fairchem


def run_ase_langevin():
Expand Down Expand Up @@ -74,6 +75,62 @@ def print_thermo(a=atoms):
return atoms.get_kinetic_energy(), atoms.get_potential_energy()


def run_ase_npt():
"""Run ASE NPT-like using a Berendsen barostat approximation via NPT wrapper.

ASE doesn't provide a direct NPT integrator in the core; here we mimic
an NPT run by coupling to a thermostat and using the `Parrinello-Rahman`
style barostat if available in user's setup. For portability in tests we
instead run VelocityVerlet with a simple rescaling of the cell using the
`ase.constraints` is out of scope — this is a lightweight smoke test to
exercise the predictor through an NPT LAMMPS run for comparison.
"""
atoms = bulk("C", "fcc", a=3.567, cubic=True)
atoms = atoms.repeat((2, 2, 2))
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda")
atoms.calc = FAIRChemCalculator(predictor, task_name="omat")
initial_temperature_K = 300.0
np.random.seed(12345)
MaxwellBoltzmannDistribution(atoms, temperature_K=initial_temperature_K)
# Use ASE's NPT integrator which couples Nose-Hoover thermostat and
# barostat (Parrinello-Rahman style) and updates the cell. We pick
# thermostat/barostat time constants that map to LAMMPS fix npt's
# Tdamp/Pdamp (units: ps here for LAMMPS). ASE's API expects time in
# fs via ase.units, so use 0.1 ps = 100 fs as the thermostat time constant.
tdamp = 0.1 # ps (thermostat damping time for LAMMPS mapping)
pdamp = 1.0 # ps (barostat damping time for LAMMPS mapping)

# Convert ps -> fs for ASE NPT ttime/pfactor which expect time in fs units
tdamp_fs = tdamp * 1000.0 * units.fs
pdamp_fs = pdamp * 1000.0 * units.fs

# ASE NPT takes timestep in ASE units (seconds via units.fs) and temperature_K
# externalstress is pressure in eV/Å^3 or a scalar (here 0 means 0 pressure)
dyn = IsotropicMTKNPT(
atoms,
timestep=1.0 * units.fs,
temperature_K=300,
pressure_au=0.0 * units.bar,
tdamp=tdamp_fs,
pdamp=pdamp_fs,
)

def print_thermo(a=atoms):
ekin = a.get_kinetic_energy()
epot = a.get_potential_energy()
etot = ekin + epot
temp = ekin / (1.5 * units.kB) / len(a)
vol = a.get_volume()
print(
f"Step: {dyn.get_number_of_steps()}, Temp: {temp:.2f} K, "
f"Ekin: {ekin:.4f} eV, Epot: {epot:.4f} eV, Etot: {etot:.4f} eV, Vol: {vol:.4f} Å^3"
)

dyn.attach(print_thermo, interval=1)
dyn.run(100)
return atoms.get_kinetic_energy(), atoms.get_potential_energy()


def run_lammps(input_file):
predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda")
lmp = run_lammps_with_fairchem(predictor, input_file, "omat")
Expand All @@ -88,6 +145,14 @@ def test_ase_vs_lammps_nve():
assert np.isclose(ase_pot, lammps_pot, rtol=0.1)


@pytest.mark.gpu()
def test_ase_vs_lammps_npt():
ase_kinetic, ase_pot = run_ase_npt()
lammps_kinetic, lammps_pot = run_lammps("tests/lammps/lammps_npt.file")
assert np.isclose(ase_kinetic, lammps_kinetic, rtol=0.5)
assert np.isclose(ase_pot, lammps_pot, rtol=0.5)


@pytest.mark.xfail(
reason="This is more demo purposes, need to configure the right parameters for ASE langevin to match lammps"
)
Expand Down