Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions src/fairchem/core/calculate/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
import torch
from ase.calculators.calculator import Calculator
from ase.stress import full_3x3_to_voigt_6_stress
from torch.autograd import grad

from fairchem.core.calculate import pretrained_mlip
from fairchem.core.datasets import data_list_collater
Expand Down Expand Up @@ -225,6 +227,115 @@ def calculate(
stress_voigt = full_3x3_to_voigt_6_stress(stress)
self.results["stress"] = stress_voigt

def get_hessian(self, atoms: Atoms, vmap: bool = True) -> np.ndarray:
    """
    Compute the analytic Hessian (second derivatives of the energy with
    respect to atomic positions) by differentiating the predicted forces
    with autograd.

    Args:
        atoms (Atoms): The atomic structure to calculate the Hessian for.
        vmap (bool): If True, compute all rows of the Hessian in one
            vectorized vector-Jacobian product via ``torch.vmap`` (faster,
            higher peak memory). If False, loop over force components one
            backward pass at a time (slower, lower peak memory).
            Defaults to True.

    Returns:
        np.ndarray: The Hessian matrix of shape (3 * len(atoms), 3 * len(atoms)).
    """
    # The force head must keep the autograd graph of the first derivative
    # so a second backward pass is possible.
    # NOTE(review): toggling `.training` directly on the head is fragile —
    # presumably the head uses this flag to decide `create_graph`; confirm
    # against the head implementation.
    head = self.predictor.model.module.output_heads["energyandforcehead"].head
    head.training = True
    try:
        # Convert using the current a2g object, then batch and predict.
        data_object = self.a2g(atoms)
        batch = data_list_collater([data_object], otf_graph=True)
        pred = self.predictor.predict(batch)

        positions = batch["pos"]
        forces = pred["forces"].flatten()

        # Hessian of the energy is -dF/dx; each row/column is a
        # vector-Jacobian product of -forces with a unit vector.
        if vmap:
            hessian = (
                torch.vmap(
                    lambda vec: grad(
                        -forces,
                        positions,
                        grad_outputs=vec,
                        retain_graph=True,
                    )[0],
                )(torch.eye(forces.numel(), device=forces.device))
                .detach()
                .cpu()
                .numpy()
            )
        else:
            # Memory-friendly fallback: one backward pass per component.
            # This fills H[:, i] = d(-F_i)/dpos, the transpose of the vmap
            # branch's row ordering; the analytic Hessian is symmetric, so
            # the two agree up to model-gradient asymmetry.
            hessian = np.zeros((len(forces), len(forces)))
            for i in range(len(forces)):
                hessian[:, i] = (
                    grad(
                        -forces[i],
                        positions,
                        retain_graph=True,
                    )[0]
                    .flatten()
                    .detach()
                    .cpu()
                    .numpy()
                )
    finally:
        # Always restore the head — even if prediction or autograd raised —
        # so subsequent predictions do not build (and leak) autograd graphs.
        head.training = False

    return hessian.reshape(len(atoms) * 3, len(atoms) * 3)

def get_numerical_hessian(self, atoms: Atoms, eps: float = 1e-4) -> np.ndarray:
    """
    Compute the Hessian matrix by central finite differences of the
    predicted forces.

    All 2 * 3N displaced structures are evaluated in a single batched
    prediction, then columns of the Hessian are assembled as
    H[:, k] = (F(x - eps * e_k) - F(x + eps * e_k)) / (2 * eps),
    i.e. -dF/dx_k.

    Args:
        atoms (Atoms): The atomic structure to calculate the Hessian for.
        eps (float): The finite difference step size. Defaults to 1e-4.

    Returns:
        np.ndarray: The Hessian matrix of shape (3 * len(atoms), 3 * len(atoms)).
    """
    n_dof = 3 * len(atoms)

    # Create displaced atoms in batch: for each degree of freedom, a +eps
    # copy immediately followed by its -eps partner (order matters below).
    data_list = []
    for i in range(len(atoms)):
        for j in range(3):
            displaced_plus = atoms.copy()
            displaced_minus = atoms.copy()

            displaced_plus.positions[i, j] += eps
            displaced_minus.positions[i, j] -= eps

            data_list.append(self.a2g(displaced_plus))
            data_list.append(self.a2g(displaced_minus))

    # Batch and predict all displacements in one call.
    batch = data_list_collater(data_list, otf_graph=True)
    pred = self.predictor.predict(batch)

    # Move all forces off-device once, instead of a .cpu() per column
    # inside the loop (each transfer forces a device synchronization).
    # Row 2k holds F(x + eps*e_k), row 2k+1 holds F(x - eps*e_k).
    forces = pred["forces"].detach().cpu().numpy().reshape(-1, n_dof)

    # Assemble the Hessian column by column via central differences.
    hessian = np.zeros((n_dof, n_dof))
    for idx in range(n_dof):
        hessian[:, idx] = (forces[2 * idx + 1] - forces[2 * idx]) / (2 * eps)

    return hessian

def _get_single_atom_energies(self, atoms) -> dict:
"""
Populate output with single atom energies
Expand Down
2 changes: 1 addition & 1 deletion src/fairchem/core/units/mlip_unit/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def predict(
)

# this needs to be .clone() to avoid issues with graph parallel modifying this data with MOLE
data_device = data.to(self.device).clone()
data_device = data.to(self.device) # .clone()

if self.inference_mode.merge_mole:
if self.merged_on is None:
Expand Down
18 changes: 18 additions & 0 deletions tests/core/calculate/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,21 @@ def test_parallel_md(checkpointing):

calc = FAIRChemCalculator(predictor, task_name="omol")
run_md_simulation(calc, steps=10)


@pytest.mark.parametrize("vmap", [True, False])
def test_hessian(vmap):
    """Analytic Hessian is finite and has the expected (3N, 3N) shape."""
    atoms = molecule("H2O")
    calc = FAIRChemCalculator(predictor, task_name="omol")

    hessian = calc.get_hessian(atoms, vmap=vmap)
    # get_hessian reshapes to (3N, 3N) before returning.
    assert hessian.shape == (3 * len(atoms), 3 * len(atoms))
    assert np.isfinite(hessian).all()


def test_numerical_hessian():
    """Finite-difference Hessian is finite and has the expected (3N, 3N) shape."""
    atoms = molecule("H2O")
    calc = FAIRChemCalculator(predictor, task_name="omol")

    hessian = calc.get_numerical_hessian(atoms)
    # get_numerical_hessian builds a (3N, 3N) matrix directly.
    assert hessian.shape == (3 * len(atoms), 3 * len(atoms))
    assert np.isfinite(hessian).all()