diff --git a/src/fairchem/lammps/lammps_fc.py b/src/fairchem/lammps/lammps_fc.py index 88f2690051..8cb74d7457 100644 --- a/src/fairchem/lammps/lammps_fc.py +++ b/src/fairchem/lammps/lammps_fc.py @@ -8,6 +8,7 @@ import numpy as np import torch from ase.data import atomic_masses, chemical_symbols +from ase.geometry import wrap_positions from fairchem.core.datasets.atomic_data import AtomicData from lammps import lammps @@ -32,16 +33,23 @@ def check_input_script(input_script: str): def check_atom_id_match_masses(types_arr, masses): for atom_id in types_arr: - assert np.allclose( - masses[atom_id], atomic_masses[atom_id], atol=1e-1 - ), f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}." + assert np.allclose(masses[atom_id], atomic_masses[atom_id], atol=1e-1), ( + f"Atom {chemical_symbols[atom_id]} (type {atom_id}) has mass {masses[atom_id]} but is expected to have mass {atomic_masses[atom_id]}." + ) def atomic_data_from_lammps_data( - x, atomic_numbers, nlocal, cell, periodicity, task_name + x: np.ndarray | torch.Tensor, + atomic_numbers, + nlocal, + cell, + periodicity, + task_name, + charge: int = 0, + spin: int = 0, ): # TODO: do we need to take of care of wrapping atoms that are outside the cell? 
- pos = torch.tensor(x, dtype=torch.float32) + pos = torch.as_tensor(x, dtype=torch.float32) pbc = torch.tensor(periodicity, dtype=torch.bool).unsqueeze(0) edge_index = torch.empty((2, 0), dtype=torch.long) cell_offsets = torch.empty((0, 3), dtype=torch.float32) @@ -58,8 +66,8 @@ def atomic_data_from_lammps_data( edge_index=edge_index, cell_offsets=cell_offsets, nedges=nedges, - charge=torch.LongTensor([0]), - spin=torch.LongTensor([0]), + charge=torch.LongTensor([charge]), + spin=torch.LongTensor([spin]), fixed=fixed, tags=tags, batch=batch, @@ -116,7 +124,7 @@ def lookup_atomic_number_by_mass(mass_arr: np.ndarray | float) -> np.ndarray | i return atomic_numbers -def separate_run_commands(input_script: str) -> str: +def separate_run_commands(input_script: str) -> tuple[list[str], list[str]]: lines = input_script.splitlines() run_cmds = [] script = [] @@ -145,37 +153,58 @@ def cell_from_lammps_box(boxlo, boxhi, xy, yz, xz): return unit_cell_matrix.unsqueeze(0) -def fix_external_call_back(lmp, ntimestep, nlocal, tag, x, f): - # force copy here, otherwise we can accident modify the original array in lammps - # TODO: only need to get atomic numbers once and cache it? - # is there a way to check atom types are mapped correctly? 
- atom_type_np = lmp.numpy.extract_atom("type") - masses = lmp.numpy.extract_atom("mass") - atomic_mass_arr = masses[atom_type_np] - atomic_numbers = lookup_atomic_number_by_mass(atomic_mass_arr) - boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box() - cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz) - atomic_data = atomic_data_from_lammps_data( - x, atomic_numbers, nlocal, cell, periodicity, lmp._task_name - ) - results = lmp._predictor.predict(atomic_data) - assert "forces" in results, "forces must be in results" - f[:] = results["forces"].cpu().numpy()[:] - lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item()) - - # during NPT for example, box_change should be set to 1 by lammps to allow the cell to change - if box_change: - # stress is defined as virial/volume in lammps - assert "stress" in results, "stress must be in results to compute virial" - volume = torch.det(cell).abs().item() - v = (results["stress"].cpu() * volume)[0] - # virials need to be in this order: xx, yy, zz, xy, xz, yz. https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd - virial_arr = [v[0], v[4], v[8], v[1], v[2], v[5]] - lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr) +class FixExternalCallback: + def __init__(self, charge: int = 0, spin: int = 0): + self.charge = charge + self.spin = spin + + def __call__(self, lmp, ntimestep, nlocal, tag, x, f): + # force a copy here, otherwise we can accidentally modify the original array in lammps + # TODO: only need to get atomic numbers once and cache it? + # is there a way to check atom types are mapped correctly? 
+ atom_type_np = lmp.numpy.extract_atom("type") + masses = lmp.numpy.extract_atom("mass") + atomic_mass_arr = masses[atom_type_np] + atomic_numbers = lookup_atomic_number_by_mass(atomic_mass_arr) + boxlo, boxhi, xy, yz, xz, periodicity, box_change = lmp.extract_box() + cell = cell_from_lammps_box(boxlo, boxhi, xy, yz, xz) + + x_wrapped = wrap_positions( + x, cell=cell.squeeze().numpy(), pbc=periodicity, eps=0 + ) + + atomic_data = atomic_data_from_lammps_data( + x_wrapped, + atomic_numbers, + nlocal, + cell, + periodicity, + lmp._task_name, + charge=self.charge, + spin=self.spin, + ) + results = lmp._predictor.predict(atomic_data) + assert "forces" in results, "forces must be in results" + f[:] = results["forces"].cpu().numpy()[:] + lmp.fix_external_set_energy_global(FIX_EXT_ID, results["energy"].item()) + + # during NPT for example, box_change should be set to 1 by lammps to allow the cell to change + if box_change: + # stress is defined as -virial/volume in lammps + assert "stress" in results, "stress must be in results to compute virial" + volume = torch.det(cell).abs().item() + v = (-results["stress"].detach().cpu() * volume)[0].tolist() + # virials need to be in this order: xx, yy, zz, xy, xz, yz. 
https://docs.lammps.org/Library_utility.html#_CPPv437lammps_fix_external_set_virial_globalPvPKcPd + virial_arr = [v[0], v[4], v[8], v[1], v[2], v[5]] + lmp.fix_external_set_virial_global(FIX_EXT_ID, virial_arr) def run_lammps_with_fairchem( - predictor: MLIPPredictUnitProtocol, lammps_input_path: str, task_name: str + predictor: MLIPPredictUnitProtocol, + lammps_input_path: str, + task_name: str, + charge: int = 0, + spin: int = 0, ): machine = None if "LAMMPS_MACHINE_NAME" in os.environ: @@ -183,7 +212,7 @@ def run_lammps_with_fairchem( lmp = lammps(name=machine, cmdargs=["-nocite", "-log", "none", "-echo", "screen"]) lmp._predictor = predictor lmp._task_name = task_name - run_cmds = [] + # run_cmds = [] with open(lammps_input_path) as f: input_script = f.read() check_input_script(input_script) @@ -191,6 +220,7 @@ def run_lammps_with_fairchem( logging.info(f"Running input script: {input_script}") lmp.commands_list(script) lmp.command(FIX_EXTERNAL_CMD) + fix_external_call_back = FixExternalCallback(charge=charge, spin=spin) lmp.set_fix_external_callback(FIX_EXT_ID, fix_external_call_back, lmp) lmp.commands_list(run_cmds) return lmp @@ -203,7 +233,13 @@ def run_lammps_with_fairchem( ) def main(cfg: DictConfig): predict_unit = hydra.utils.instantiate(cfg.predict_unit) - lmp = run_lammps_with_fairchem(predict_unit, cfg.lmp_in, cfg.task_name) + lmp = run_lammps_with_fairchem( + predict_unit, + cfg.lmp_in, + cfg.task_name, + cfg.charge, + cfg.spin, + ) # this is required to cleanup the predictor del lmp._predictor diff --git a/src/fairchem/lammps/lammps_fc_config.yaml b/src/fairchem/lammps/lammps_fc_config.yaml index b0b48324e7..0a00e0522f 100644 --- a/src/fairchem/lammps/lammps_fc_config.yaml +++ b/src/fairchem/lammps/lammps_fc_config.yaml @@ -30,3 +30,5 @@ parallel_predict_unit: predict_unit: ${local_predict_unit} lmp_in: "lammps_in_example.file" task_name: "omol" +charge: 0 +spin: 0 diff --git a/tests/lammps/lammps_npt.file b/tests/lammps/lammps_npt.file new file 
mode 100644 index 0000000000..f8d49845e3 --- /dev/null +++ b/tests/lammps/lammps_npt.file @@ -0,0 +1,19 @@ +units metal # Use metal units (Angstroms, eV, ps) +atom_style atomic # Atoms have a single type and position +lattice fcc 3.567 +boundary p p p + +region simbox block 0 2 0 2 0 2 +create_box 1 simbox +create_atoms 1 region simbox +mass 1 12.011 + +velocity all create 300.0 12345 dist gaussian # Set initial velocities at 300 K + +timestep 0.001 # 1 fs +# Use NPT (isotropic) thermostat+barostat: target temp 300 K, target pressure 0 bar +# fix npt syntax: fix ID group-ID npt temp Tstart Tstop Tdamp iso Pstart Pstop Pdamp +fix 1 all npt temp 300.0 300.0 0.1 iso 0.0 0.0 1.0 +thermo_style custom step temp pe ke etotal press vol +thermo 1 +run 100 diff --git a/tests/lammps/test_ase_vs_lammps.py b/tests/lammps/test_ase_vs_lammps.py index d8ad7df032..444f81523a 100644 --- a/tests/lammps/test_ase_vs_lammps.py +++ b/tests/lammps/test_ase_vs_lammps.py @@ -5,12 +5,13 @@ from ase import units from ase.build import bulk from ase.md.langevin import Langevin +from ase.md.nose_hoover_chain import IsotropicMTKNPT from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.md.verlet import VelocityVerlet -from fairchem.lammps.lammps_fc import run_lammps_with_fairchem from fairchem.core import FAIRChemCalculator from fairchem.core.calculate import pretrained_mlip +from fairchem.lammps.lammps_fc import run_lammps_with_fairchem def run_ase_langevin(): @@ -74,6 +75,62 @@ def print_thermo(a=atoms): return atoms.get_kinetic_energy(), atoms.get_potential_energy() +def run_ase_npt(): + """Run a short NPT trajectory in ASE for comparison with the LAMMPS NPT run. + + The system is fcc carbon (a=3.567, cubic cell repeated 2x2x2) with + velocities drawn from a Maxwell-Boltzmann distribution at 300 K. Dynamics + use ASE's `IsotropicMTKNPT` integrator at 300 K and zero pressure. 
Damping times + mirror the Tdamp/Pdamp values of the LAMMPS `fix npt` command in + tests/lammps/lammps_npt.file. This is a lightweight smoke test that + exercises the predictor; returns (kinetic energy, potential energy). + """ + atoms = bulk("C", "fcc", a=3.567, cubic=True) + atoms = atoms.repeat((2, 2, 2)) + predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda") + atoms.calc = FAIRChemCalculator(predictor, task_name="omat") + initial_temperature_K = 300.0 + np.random.seed(12345) + MaxwellBoltzmannDistribution(atoms, temperature_K=initial_temperature_K) + # Use ASE's IsotropicMTKNPT integrator, which couples a thermostat and an + # isotropic barostat and updates the cell volume. We pick + # thermostat/barostat time constants that map to LAMMPS fix npt's + # Tdamp/Pdamp (given in ps for LAMMPS metal units). ASE uses its own + # internal time unit, so the values are converted below via ase.units.fs. + tdamp = 0.1 # ps (thermostat damping time for LAMMPS mapping) + pdamp = 1.0 # ps (barostat damping time for LAMMPS mapping) + + # Convert ps to ASE time units (1 ps = 1000 fs) for the tdamp/pdamp arguments + tdamp_fs = tdamp * 1000.0 * units.fs + pdamp_fs = pdamp * 1000.0 * units.fs + + # timestep is 1 fs expressed in ASE time units via units.fs; temperature_K is in K; + # pressure_au is the target pressure converted from bar via units.bar (here 0 means 0 pressure) + dyn = IsotropicMTKNPT( + atoms, + timestep=1.0 * units.fs, + temperature_K=300, + pressure_au=0.0 * units.bar, + tdamp=tdamp_fs, + pdamp=pdamp_fs, + ) + + def print_thermo(a=atoms): + ekin = a.get_kinetic_energy() + epot = a.get_potential_energy() + etot = ekin + epot + temp = ekin / (1.5 * units.kB) / len(a) + vol = a.get_volume() + print( + f"Step: {dyn.get_number_of_steps()}, Temp: {temp:.2f} K, " + f"Ekin: {ekin:.4f} eV, Epot: {epot:.4f} eV, Etot: {etot:.4f} eV, Vol: {vol:.4f} Å^3" + ) + + dyn.attach(print_thermo, interval=1) + dyn.run(100) + return atoms.get_kinetic_energy(), 
atoms.get_potential_energy() + + def run_lammps(input_file): predictor = pretrained_mlip.get_predict_unit("uma-s-1p1", device="cuda") lmp = run_lammps_with_fairchem(predictor, input_file, "omat") @@ -88,6 +145,14 @@ def test_ase_vs_lammps_nve(): assert np.isclose(ase_pot, lammps_pot, rtol=0.1) +@pytest.mark.gpu() +def test_ase_vs_lammps_npt(): + ase_kinetic, ase_pot = run_ase_npt() + lammps_kinetic, lammps_pot = run_lammps("tests/lammps/lammps_npt.file") + assert np.isclose(ase_kinetic, lammps_kinetic, rtol=0.5) + assert np.isclose(ase_pot, lammps_pot, rtol=0.5) + + @pytest.mark.xfail( reason="This is more demo purposes, need to configure the right parameters for ASE langevin to match lammps" )