
Conversation

@ericyuan00000 (Contributor)

No description provided.

@zulissimeta requested a review from mshuaibii July 14, 2025 23:46
@zulissimeta (Contributor) left a comment

Thanks for the PR! Please move most of the logic to a helper function on the predict unit so it's closer to similar logic there. Then we can keep get_hessian focused on the ASE logic.

Also, please add a unit test here for a simple molecule (like water) that you can check the result for. You can re-use the example molecules in the unit tests.
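A hedged sketch of what such a test could look like (the checkpoint path is a placeholder, and the (3N, 3N) return shape of get_hessian is assumed from this PR's draft; the repo's existing test fixtures would normally supply the model):

```python
import numpy as np
from ase.build import molecule
from fairchem.core import FAIRChemCalculator
from fairchem.core.units.mlip_unit import load_predict_unit


def test_get_hessian_water():
    # Placeholder checkpoint path; in the repo's tests this would come from an
    # existing fixture that loads one of the example models.
    predictor = load_predict_unit(path="checkpoint.pt", device="cpu")
    calc = FAIRChemCalculator(predictor)

    atoms = molecule("H2O")
    atoms.info = {"charge": 0, "spin": 1}
    atoms.calc = calc

    hessian = calc.get_hessian(atoms)

    n = 3 * len(atoms)
    # Assumed (3N, 3N) layout, matching the flattened-forces approach in the PR.
    assert hessian.shape == (n, n)
    # The analytical Hessian should be (numerically) symmetric.
    np.testing.assert_allclose(hessian, hessian.T, atol=1e-4)
```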


# Batch and predict
batch = data_list_collater([data_object], otf_graph=True)
pred = self.predictor.predict(batch)

Can you move the logic here to the predict unit?

forces = pred["forces"].flatten()

# Calculate the Hessian using autograd
if vmap:

Why branch here? What's the difference in behavior? Why not always use this (i.e. is there any reason to choose vmap=False)?

)[0].flatten().detach().cpu().numpy()

# Turn off create_graph for the first derivative
self.predictor.model.module.output_heads['energyandforcehead'].head.training = False

Why is this set after the result is predicted?

@ericyuan00000 (Contributor, Author)

Thanks for the comments. I'm working on writing tests now.
The awkward training toggle is necessary because the gradient graph is only created when the output head is in training mode. Enabling training for the entire model is undesirable because it activates dropout. I'm also concerned about the memory and time overhead of gradient-graph construction during standard calculate calls (for energy/forces/stress), which is why I disable training again after computing the Hessian.
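For concreteness, a minimal sketch of that toggle (not the final implementation; the attribute path mirrors this draft and is checkpoint-dependent, and `_autograd_hessian` is a hypothetical helper standing in for the row-by-row backprop):

```python
def hessian_with_head_toggle(calc, atoms):
    # Sketch only: flip the output head's `training` flag so the first force
    # backprop is built with create_graph=True, then restore it so ordinary
    # energy/force/stress calls skip the extra graph construction.
    # The attribute path mirrors this draft; other checkpoints expose different
    # head classes without a `.head` submodule (see the MLP_EFS_Head error below).
    head = calc.predictor.model.module.output_heads["energyandforcehead"].head
    head.training = True
    try:
        return _autograd_hessian(calc, atoms)  # hypothetical Hessian helper
    finally:
        head.training = False
```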
If we move this logic into MLIPPredictUnit, the gradient graph would be created regardless of whether the Hessian is being computed. Is that desirable behavior?
Also, what's the recommended way to implement this logic inside the predict unit? I could add an extra argument to MLIPPredictUnit.__init__ (simple but less clean), or introduce a new key in overrides or inference_settings, though that might require digging into Hydra.
The vmap option comes from the CUDA OOMs I saw, even for small molecules on 80 GB GPUs. The parallelized calculation, though faster, is much more memory intensive than calculating row by row.

@runtian97

Hi,

Thanks for the update to the Hessian calculation. I tried the code below to get the Hessian:

from ase.io import read
from fairchem.core import FAIRChemCalculator
from fairchem.core.units.mlip_unit import load_predict_unit

predictor = load_predict_unit(
    path="esen_sm_conserving_all.pt",
    device="cpu"
)
calc = FAIRChemCalculator(predictor)

atoms = read('/Users/nickgao/Desktop/pythonProject/local_code_template/test_xyz/CAMVES_a.xyz')  # or create ASE Atoms however you want
atoms.info = {"charge": 0, "spin": 1}
atoms.calc = calc
energy_eV = atoms.get_potential_energy()

hessian = calc.get_hessian(atoms)
print(hessian)

but I later got the error message:

AttributeError Traceback (most recent call last)
Cell In[1], line 16
13 atoms.calc = calc
14 energy_eV = atoms.get_potential_energy()
---> 16 hessian = calc.get_hessian(atoms)
17 print(energy_eV)

File /opt/anaconda3/envs/test/lib/python3.11/site-packages/fairchem/core/calculate/ase_calculator.py:215, in FAIRChemCalculator.get_hessian(self, atoms, vmap)
206 """
207 Get the Hessian matrix for the given atomic structure.
208 Args:
(...) 212 np.ndarray: The Hessian matrix.
213 """
214 # Turn on create_graph for the first derivative
--> 215 self.predictor.model.module.output_heads['energyandforcehead'].head.training = True
217 # Convert using the current a2g object
218 data_object = self.a2g(atoms)

File /opt/anaconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py:1928, in Module.__getattr__(self, name)
1926 if name in modules:
1927 return modules[name]
-> 1928 raise AttributeError(
1929 f"'{type(self).name}' object has no attribute '{name}'"
1930 )

AttributeError: 'MLP_EFS_Head' object has no attribute 'head'

It seems there is something wrong with this line of code:

self.predictor.model.module.output_heads['energyandforcehead'].head.training = True

Could you please provide a solution to it? Thanks!

@ericyuan00000 (Contributor, Author)

To update on the issue, we can consider three main ways of calculating the Hessian:

  1. Analytically, in sequence. PyTorch only lets each tensor hold one gradient at a time. In our case, we backprop from one force element to the position tensor and record its grad, then move on to the next force element and the same position tensor, and so on for all 3N force elements. Essentially, we calculate the Hessian matrix row by row in a loop. We do need create_graph=True here to allow the second backprop. It will be more memory demanding than a forward call at inference time, not because of the second backprop itself but because of the gradient-graph creation. It's unclear to me whether the memory benefit of the MoE during inference carries over to the second derivative.
  2. Analytically, in parallel. This is what many other MLIPs implement. We put all the operations from (1) into a vmap (a minimal sketch of (1) and (2) follows this list). Not having to loop over the 3N force elements means a 3N-fold speedup, but also a 3N-fold increase in memory demand. A 100-atom system hits CUDA OOM on a single 40 GB A100 GPU in FP32.
  3. Numerically, in parallel. Since we don't backprop a second time, the memory demand is probably the lowest. Finite-difference methods can be applied the same way ASE performs vibrational analysis. The 3N (or 6N with central differences) evaluations are embarrassingly parallel, so we should be able to evaluate all the displaced frames simultaneously in one forward call. It may not be as numerically stable as the analytical derivative, so FP64 is strongly preferred.
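For reference, a minimal plain-PyTorch sketch of (1) and (2), not the PR's code; `energy_fn` is a hypothetical closure mapping a (N, 3) position tensor to a scalar energy (e.g. a wrapper around the model's forward pass):

```python
import torch


def hessian_rows(energy_fn, pos):
    """Method (1): Hessian of the energy w.r.t. positions, one row per backprop."""
    pos = pos.detach().clone().requires_grad_(True)
    energy = energy_fn(pos)
    # create_graph=True keeps the forces differentiable for the second backprop
    forces = -torch.autograd.grad(energy, pos, create_graph=True)[0].reshape(-1)
    rows = []
    for i in range(forces.numel()):
        # one backward pass per force component; keep the graph for the next row
        (row,) = torch.autograd.grad(forces[i], pos, retain_graph=True)
        rows.append(-row.reshape(-1))
    return torch.stack(rows)  # (3N, 3N)


def hessian_batched(energy_fn, pos):
    """Method (2): the same quantity with the second backprop batched (vmapped)
    over all 3N force components at once, at roughly 3N times the memory."""
    pos = pos.detach().clone().requires_grad_(True)
    energy = energy_fn(pos)
    forces = -torch.autograd.grad(energy, pos, create_graph=True)[0].reshape(-1)
    eye = torch.eye(forces.numel(), dtype=pos.dtype, device=pos.device)
    # is_grads_batched=True vmaps the backward pass over the 3N unit vectors
    (df_dpos,) = torch.autograd.grad(forces, pos, grad_outputs=eye, is_grads_batched=True)
    return -df_dpos.reshape(forces.numel(), -1)  # (3N, 3N)
```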

@zulissimeta zulissimeta requested review from misko and rayg1234 October 16, 2025 13:14
@zulissimeta zulissimeta added enhancement New feature or request minor Minor version release labels Oct 16, 2025
@lbluque (Contributor) commented Oct 16, 2025

Thanks @ericyuan00000 for the detailed description!

I would say there is no need to implement (3), since users can rely on established packages like ASE or phonopy to compute finite differences.

I think having helper functions for (1) and (2) behind a single interface that allows users to choose the implementation would be best.

Thanks again for working on this contribution 😄

@samblau commented Oct 16, 2025

@lbluque My concern with relying on ASE or phonopy for (3) is that they won't batch the gradient calls into a single inference call on the GPU, right? So for, e.g., a 100-atom system, rather than doing one batched inference over 6 × 100 × 100 = 60k atoms, we would end up doing 600 inferences of 100 atoms each in serial, which would presumably be much slower. While the analytic approach is certainly desirable, the fact that we go OOM using approach (2) for a 100-atom system on a 40 GB GPU is not good, as (1) will be 3N times slower.

My understanding (please correct me if I'm wrong) is that the large memory footprint comes from the fact that all MoE parameters are active in the gradient graph, as would be necessary for training (right?), but which I believe should not be necessary for second-derivative inference. Hence, this is a particular issue for the MoLE setup of UMA.

Thoughts? Thank you!

@lbluque (Contributor) commented Oct 16, 2025

Hi @samblau 👋 That's a good point! You are correct that using ASE/phonopy will carry out the forward pass for each displacement serially. Batching or parallelization should be straightforward to implement, but again some degree of tuning will be needed to batch all the needed displacements without hitting OOM. My opinion is that this scenario is better suited for a downstream package (torch-sim 👀?).

For both cases (finite differences or autograd), unmerged MoLE models could easily hit OOM. But in both cases we should be merging before inference, which should greatly reduce the memory footprint (@ericyuan00000's implementation should remain the same either way). Though I can't say if that will be enough memory reduction to unblock the use cases you are targeting.
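As an illustration of the merging point, a minimal sketch; the `inference_settings="turbo"` preset and the checkpoint filename are assumptions here, not something verified in this thread:

```python
from fairchem.core.units.mlip_unit import load_predict_unit

# Assumption: the "turbo" preset merges the MoLE experts for the system being
# simulated before inference, so far fewer parameters end up in the gradient
# graph. Check the fairchem docs for the exact option name and its constraints
# (a merged model is typically tied to a fixed composition/task).
predictor = load_predict_unit(
    path="uma-s-1p1.pt",  # placeholder checkpoint path
    inference_settings="turbo",
    device="cuda",
)
```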

@ericyuan00000 (Contributor, Author)

While benchmarking the performance of the methods, I found that I have lost access to the pos used in the forward call. I believe it's related to this line, which makes the data["pos"] seen by the model here no longer point to the batch["pos"] that I need to do autograd against here.

I think returning the cloned positions is somehow necessary, but it doesn't seem to be trivial since there's no corresponding task here and here.

What would be the best way to access the positions now?

@lbluque (Contributor) commented Oct 20, 2025

Thanks for continuing to push on this @ericyuan00000 !

I would simply remove the batch clone line you linked above just for running benchmarks before finalizing the implementation.

My 2c for the final implementation would be to move the Hessian calculation into the actual model code (like we do for stress/forces) or into the MLIPPredictUnit.predict method.

Any thoughts on this @rayg1234 ?

@lbluque self-requested a review October 20, 2025 02:28
@ericyuan00000 (Contributor, Author)

Thanks @lbluque for the tip. It works now!

I have encountered another issue, though: my Hessian is NaN. It has something to do with this function when it takes in values of 1 or -1. I had to change the clamp values at this line. I'm not seeing large changes in the model predictions, but I'm not sure that's always the case. Putting the torch.where operation before the square root and the division could be another possible solution.
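A toy illustration of the failure mode (plain PyTorch, nothing UMA-specific): with the clamp exactly at ±1, the derivative of the subsequent square root divides by zero at the boundary, and that propagates into the Hessian.

```python
import torch

# Clamping exactly to [-1, 1] before the square root leaves a division by zero
# in the derivative at the boundary.
x = torch.tensor(1.0, requires_grad=True)
y = torch.sqrt(1 - x.clamp(-1.0, 1.0) ** 2)
(g,) = torch.autograd.grad(y, x)
print(g)  # non-finite: d/dx sqrt(1 - x^2) = -x / sqrt(1 - x^2) diverges at |x| = 1

# Clamping slightly inside the interval (or applying torch.where before the
# sqrt/division, as suggested above) keeps the derivatives finite.
x = torch.tensor(1.0, requires_grad=True)
y = torch.sqrt(1 - x.clamp(-1.0 + 1e-6, 1.0 - 1e-6) ** 2)
(g,) = torch.autograd.grad(y, x)
print(g)  # finite (zero here, since x lies outside the tightened clamp range)
```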

I've also benchmarked the performance of the three methods on Perlmutter's 40 GB A100 GPUs. Each line ends where it hits CUDA OOM. The system is periodic, so I'd expect it to be more expensive than non-periodic systems, but the trend should be the same. The analytical calculation with vmap is the fastest, unexpectedly, and the one without vmap is the most memory efficient. The numerical Hessian (with central differences) is no better than the vmap version memory-wise and is significantly slower. I think it's safe to conclude that the autograd methods (with or without vmap, depending on the memory demands) are the way to go.

[Benchmark plot of the three Hessian methods; each line ends at CUDA OOM]

@misko (Contributor) commented Oct 21, 2025

> I have encountered another issue, though: my Hessian is NaN. It has something to do with this function when it takes in values of 1 or -1. I had to change the clamp values at this line. I'm not seeing large changes in the model predictions, but I'm not sure that's always the case. Putting the torch.where operation before the square root and the division could be another possible solution.

Hi @ericyuan00000, you are right, there is some instability when there are y-aligned edges in the input :'( I relaxed the tolerance a bit in this region with a new commit (a few weeks back); our tests were passing, but in practice it drives NaNs in the presence of nearly y-aligned edges. Thank you for this clamp solution! You are right, clamping in this area fixes the issue. I am working on a non-clamping solution... which will hopefully ship in the next few days.

@samblau commented Oct 28, 2025

@misko Any update here? I've got two different projects waiting on UMA Hessians haha. Thanks!

@misko (Contributor) commented Oct 28, 2025

@samblau

> @misko Any update here? I've got two different projects waiting on UMA Hessians haha. Thanks!

It's been a drag :'( We have short- and long-term solutions mapped out. The short-term solution is here:
#1574

It looks like it's stable, but we're still waiting on some tests to pass: CI plus one more training run. It's not the prettiest, but it should unblock things once it's ready.

The long-term solution is to make sure edges are not y-aligned to begin with, which is trickier.

@samblau commented Oct 28, 2025

Thanks for the update @misko! You referenced an additional training run - does that mean that this solution won't work directly with UMA-s-1p1?

@misko (Contributor) commented Oct 28, 2025

> Thanks for the update @misko! You referenced an additional training run - does that mean that this solution won't work directly with UMA-s-1p1?

@samblau It should work fine with UMA-s-1p1; our tests are passing, so all should be good there. We're just seeing if we can retrain on a subset and get the same validation performance.

If this is blocking you, you can always try rotating the system so it has no y-aligned edges, if that makes sense. A random rotation usually does the trick.

Something like this should nudge everything away from y alignment:

import numpy as np
import torch


def random_rotation_transform(data_object: "AtomicData", config) -> "AtomicData":
    # Apply one random rigid rotation to positions, cell, forces, and stress.
    # "AtomicData" is fairchem's graph data class; import it from wherever it
    # lives in your fairchem version.
    alpha, beta, gamma = torch.rand(3) * 2 * np.pi
    Rz = torch.tensor([[torch.cos(alpha), -torch.sin(alpha), 0],
                       [torch.sin(alpha), torch.cos(alpha), 0],
                       [0, 0, 1]])
    Ry = torch.tensor([[torch.cos(beta), 0, torch.sin(beta)],
                       [0, 1, 0],
                       [-torch.sin(beta), 0, torch.cos(beta)]])
    Rx = torch.tensor([[1, 0, 0],
                       [0, torch.cos(gamma), -torch.sin(gamma)],
                       [0, torch.sin(gamma), torch.cos(gamma)]])
    R = Rz @ Ry @ Rx
    for k in data_object.keys():  # noqa: SIM118
        if "forces" in k:
            data_object[k] = data_object[k] @ R.T
        if "stress" in k and ("iso" not in k and "aniso" not in k):
            data_object[k] = (R @ data_object[k].reshape(3, 3) @ R.T).reshape(1, 9)
        if k == "pos":
            data_object.pos = data_object.pos @ R.T
        if k == "cell":
            data_object.cell = data_object.cell @ R.T
    return data_object

@samblau commented Oct 28, 2025

@misko That feels extremely hacky haha. If you think that we'll get an official solution merged soon such that we can wrap up the PR e.g. end of this week or early next week, then I can be patient. I'd really prefer not to be messing with stuff at the torch level since we're going to be calculating Hessians with UMA for thousands of systems.

@misko (Contributor) commented Oct 28, 2025

> @misko That feels extremely hacky haha. If you think that we'll get an official solution merged soon such that we can wrap up the PR e.g. end of this week or early next week, then I can be patient. I'd really prefer not to be messing with stuff at the torch level since we're going to be calculating Hessians with UMA for thousands of systems.

I think it should land tomorrow. I'm pretty confident it will be in by the end of the week. There's a chance there might still be some issues with y-aligned edges that are hard to detect. If, after the fix, you have any issues with a system containing a y-aligned edge, please post an issue and I will do my best to prioritize the longer-term fix.

Labels: cla signed, enhancement (New feature or request), minor (Minor version release)

7 participants