-
Notifications
You must be signed in to change notification settings - Fork 405
Add hessian calculation in calculator #1361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ericyuan00000
wants to merge
11
commits into
facebookresearch:main
Choose a base branch
from
ericyuan00000:patch-4
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
22d4f74
Add hessian calculation in calculator
ericyuan00000 d377acf
Modify atoms
ericyuan00000 0f3c204
Modify atoms
ericyuan00000 a4188c4
Merge branch 'main' into patch-4
zulissimeta 2ec857c
Merge branch 'main' into patch-4
zulissimeta 6f45aca
Update hessian calculations
ericyuan00000 07284bd
Merge branch 'main' into patch-4
ericyuan00000 c949ba7
update_lint
ericyuan00000 6377e1c
update_test
ericyuan00000 c8f9396
update_lint
ericyuan00000 60c88d0
update_lint
ericyuan00000 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,8 +13,10 @@ | |
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from ase.calculators.calculator import Calculator | ||
| from ase.stress import full_3x3_to_voigt_6_stress | ||
| from torch.autograd import grad | ||
|
|
||
| from fairchem.core.calculate import pretrained_mlip | ||
| from fairchem.core.datasets import data_list_collater | ||
|
|
@@ -225,6 +227,115 @@ def calculate( | |
| stress_voigt = full_3x3_to_voigt_6_stress(stress) | ||
| self.results["stress"] = stress_voigt | ||
|
|
||
| def get_hessian(self, atoms: Atoms, vmap: bool = True) -> np.ndarray: | ||
| """ | ||
| Get the Hessian matrix for the given atomic structure. | ||
|
|
||
| Args: | ||
| atoms (Atoms): The atomic structure to calculate the Hessian for. | ||
| vmap (bool): Whether to use vectorized mapping for Hessian calculation. Defaults to True. | ||
|
|
||
| Returns: | ||
| np.ndarray: The Hessian matrix. | ||
| """ | ||
| # Turn on create_graph for the first derivative | ||
| self.predictor.model.module.output_heads[ | ||
| "energyandforcehead" | ||
| ].head.training = True | ||
|
|
||
| # Convert using the current a2g object | ||
| data_object = self.a2g(atoms) | ||
|
|
||
| # Batch and predict | ||
| batch = data_list_collater([data_object], otf_graph=True) | ||
| pred = self.predictor.predict(batch) | ||
|
|
||
| # Get the forces and positions | ||
| positions = batch["pos"] | ||
| forces = pred["forces"].flatten() | ||
|
|
||
| # Calculate the Hessian using autograd | ||
| if vmap: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why branch here? What's the difference in behavior? Why not always use this (i.e. is there any reason to choose vmap=False)? |
||
| hessian = ( | ||
| torch.vmap( | ||
| lambda vec: grad( | ||
| -forces, | ||
| positions, | ||
| grad_outputs=vec, | ||
| retain_graph=True, | ||
| )[0], | ||
| )(torch.eye(forces.numel(), device=forces.device)) | ||
| .detach() | ||
| .cpu() | ||
| .numpy() | ||
| ) | ||
| else: | ||
| hessian = np.zeros((len(forces), len(forces))) | ||
| for i in range(len(forces)): | ||
| hessian[:, i] = ( | ||
| grad( | ||
| -forces[i], | ||
| positions, | ||
| retain_graph=True, | ||
| )[0] | ||
| .flatten() | ||
| .detach() | ||
| .cpu() | ||
| .numpy() | ||
| ) | ||
|
|
||
| # Turn off create_graph for the first derivative | ||
| self.predictor.model.module.output_heads[ | ||
| "energyandforcehead" | ||
| ].head.training = False | ||
|
|
||
| return hessian.reshape(len(atoms) * 3, len(atoms) * 3) | ||
|
|
||
| def get_numerical_hessian(self, atoms: Atoms, eps: float = 1e-4) -> np.ndarray: | ||
| """ | ||
| Get the Hessian matrix for the given atomic structure. | ||
|
|
||
| Args: | ||
| atoms (Atoms): The atomic structure to calculate the Hessian for. | ||
| eps (float): The finite difference step size. Defaults to 1e-4. | ||
|
|
||
| Returns: | ||
| np.ndarray: The Hessian matrix. | ||
| """ | ||
| # Create displaced atoms in batch | ||
| data_list = [] | ||
| for i in range(len(atoms)): | ||
| for j in range(3): | ||
| displaced_plus = atoms.copy() | ||
| displaced_minus = atoms.copy() | ||
|
|
||
| displaced_plus.positions[i, j] += eps | ||
| displaced_minus.positions[i, j] -= eps | ||
|
|
||
| data_plus = self.a2g(displaced_plus) | ||
| data_minus = self.a2g(displaced_minus) | ||
|
|
||
| data_list.append(data_plus) | ||
| data_list.append(data_minus) | ||
|
|
||
| # Batch and predict | ||
| batch = data_list_collater(data_list, otf_graph=True) | ||
| pred = self.predictor.predict(batch) | ||
|
|
||
| # Get the forces | ||
| forces = pred["forces"].reshape(-1, len(atoms), 3) | ||
|
|
||
| # Calculate the Hessian using finite differences | ||
| hessian = np.zeros((len(atoms) * 3, len(atoms) * 3)) | ||
| for i in range(len(atoms)): | ||
| for j in range(3): | ||
| idx = i * 3 + j | ||
| force_plus = forces[2 * idx].flatten().detach().cpu().numpy() | ||
| force_minus = forces[2 * idx + 1].flatten().detach().cpu().numpy() | ||
| hessian[:, idx] = (force_minus - force_plus) / (2 * eps) | ||
|
|
||
| return hessian | ||
|
|
||
| def _get_single_atom_energies(self, atoms) -> dict: | ||
| """ | ||
| Populate output with single atom energies | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you move the logic here to the predict unit?