Skip to content

Commit af1bee4

Browse files
CompRhysjanosh
andauthored
Updates integrator logic per #171. (#175)
* fea: updates integrator logic per #171. The tests of the function do not use the row vector cell so this suggests this is a bug that it's used in the integrators * change test_compare_single_vs_batched_integrators to use titled cell as suggested by @Seungwoo-Hwang #171 (comment) * fix pbc_wrap_general and pbc_wrap_batched in transforms.py to conform to row vector convention (M_row) for lattice vectors - improve docstrings to clarify assumptions and formulas used in periodic boundary condition calculations * new unit tests in test_transforms.py to ensure consistency with the new row vector approach, including manual wrapping calculations as to get reference values * fix: add casio3_sim_state triclinic sim state fixture, revert erroneous changes to wrap pbc * git: wanted to revert changes in transforms but they got lost in stash/pull * fix: revert other test change * fix: unbatched integrators, use modulo * fix: use new positions not state positions to mirror logic in batched. * fix basis-symbols mismatch in tio2_sim_state * cleanup test_pbc_wrap_general_param cases * test_pbc_wrap_general_param add more edge cases --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 45e9676 commit af1bee4

File tree

7 files changed

+127
-107
lines changed

7 files changed

+127
-107
lines changed

tests/conftest.py

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,10 @@ def ti_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
127127
def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
128128
"""Create crystalline TiO2 using ASE."""
129129
a, c = 4.60, 2.96
130-
symbols = ["Ti", "O", "O"]
131-
basis = [
132-
(0.5, 0.5, 0), # Ti
133-
(0.695679, 0.695679, 0.5), # O
134-
]
130+
basis = [("Ti", 0.5, 0.5, 0), ("O", 0.695679, 0.695679, 0.5)]
135131
atoms = crystal(
136-
symbols,
137-
basis=basis,
132+
symbols=[b[0] for b in basis],
133+
basis=[b[1:] for b in basis],
138134
spacegroup=136, # P4_2/mnm
139135
cellpar=[a, a, c, 90, 90, 90],
140136
)
@@ -145,13 +141,10 @@ def tio2_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
145141
def ga_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
146142
"""Create crystalline Ga using ASE."""
147143
a, b, c = 4.43, 7.60, 4.56
148-
symbols = ["Ga"]
149-
basis = [
150-
(0, 0.344304, 0.415401), # Ga
151-
]
144+
basis = [("Ga", 0, 0.344304, 0.415401)]
152145
atoms = crystal(
153-
symbols,
154-
basis=basis,
146+
symbols=[b[0] for b in basis],
147+
basis=[b[1:] for b in basis],
155148
spacegroup=64, # Cmce
156149
cellpar=[a, b, c, 90, 90, 90],
157150
)
@@ -163,14 +156,13 @@ def niti_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
163156
"""Create crystalline NiTi using ASE."""
164157
a, b, c = 2.89, 3.97, 4.83
165158
alpha, beta, gamma = 90.00, 105.23, 90.00
166-
symbols = ["Ni", "Ti"]
167159
basis = [
168-
(0.369548, 0.25, 0.217074), # Ni
169-
(0.076622, 0.25, 0.671102), # Ti
160+
("Ni", 0.369548, 0.25, 0.217074),
161+
("Ti", 0.076622, 0.25, 0.671102),
170162
]
171163
atoms = crystal(
172-
symbols,
173-
basis=basis,
164+
symbols=[b[0] for b in basis],
165+
basis=[b[1:] for b in basis],
174166
spacegroup=11,
175167
cellpar=[a, b, c, alpha, beta, gamma],
176168
)
@@ -215,6 +207,36 @@ def rattled_sio2_sim_state(
215207
return sim_state
216208

217209

210+
@pytest.fixture
211+
def casio3_sim_state(device: torch.device, dtype: torch.dtype) -> SimState:
212+
a, b, c = 7.9258, 7.3202, 7.0653
213+
alpha, beta, gamma = 90.055, 95.217, 103.426
214+
basis = [
215+
("Ca", 0.19831, 0.42266, 0.76060),
216+
("Ca", 0.20241, 0.92919, 0.76401),
217+
("Ca", 0.50333, 0.75040, 0.52691),
218+
("Si", 0.1851, 0.3875, 0.2684),
219+
("Si", 0.1849, 0.9542, 0.2691),
220+
("Si", 0.3973, 0.7236, 0.0561),
221+
("O", 0.3034, 0.4616, 0.4628),
222+
("O", 0.3014, 0.9385, 0.4641),
223+
("O", 0.5705, 0.7688, 0.1988),
224+
("O", 0.9832, 0.3739, 0.2655),
225+
("O", 0.9819, 0.8677, 0.2648),
226+
("O", 0.4018, 0.7266, 0.8296),
227+
("O", 0.2183, 0.1785, 0.2254),
228+
("O", 0.2713, 0.8704, 0.0938),
229+
("O", 0.2735, 0.5126, 0.0931),
230+
]
231+
atoms = crystal(
232+
symbols=[b[0] for b in basis],
233+
basis=[b[1:] for b in basis],
234+
spacegroup=2,
235+
cellpar=[a, b, c, alpha, beta, gamma],
236+
)
237+
return ts.io.atoms_to_state(atoms, device, dtype)
238+
239+
218240
@pytest.fixture
219241
def benzene_sim_state(
220242
benzene_atoms: Any, device: torch.device, dtype: torch.dtype

tests/models/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"rattled_sio2_sim_state",
2828
"ar_supercell_sim_state",
2929
"fe_supercell_sim_state",
30+
"casio3_sim_state",
3031
"benzene_sim_state",
3132
)
3233

tests/test_integrators.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any
22

3+
import pytest
34
import torch
45

56
from torch_sim.integrators import (
@@ -317,13 +318,23 @@ def test_nve(ar_double_sim_state: SimState, lj_model: LennardJonesModel):
317318
assert torch.allclose(energies_tensor[:, 1], energies_tensor[0, 1], atol=1e-4)
318319

319320

321+
@pytest.mark.parametrize(
322+
"sim_state_fixture_name", ["casio3_sim_state", "ar_supercell_sim_state"]
323+
)
320324
def test_compare_single_vs_batched_integrators(
321-
ar_supercell_sim_state: SimState, lj_model: Any
325+
sim_state_fixture_name: str, request: pytest.FixtureRequest, lj_model: Any
322326
) -> None:
323-
"""Test that single and batched integrators give the same results."""
327+
"""Test NVE single vs batched for a tilted cell to verify PBC wrapping.
328+
329+
NOTE: added triclinic cell after https://github.com/Radical-AI/torch-sim/issues/171.
330+
Although the addition doesn't fail if we do not add the changes suggested in issue.
331+
"""
332+
sim_state = request.getfixturevalue(sim_state_fixture_name)
333+
n_steps = 100
334+
324335
initial_states = {
325-
"single": ar_supercell_sim_state,
326-
"batched": concatenate_states([ar_supercell_sim_state, ar_supercell_sim_state]),
336+
"single": sim_state,
337+
"batched": concatenate_states([sim_state, sim_state]),
327338
}
328339

329340
final_states = {}
@@ -333,25 +344,31 @@ def test_compare_single_vs_batched_integrators(
333344
dt = torch.tensor(0.001) # Small timestep for stability
334345

335346
nve_init, nve_update = nve(model=lj_model, dt=dt, kT=kT)
336-
state = nve_init(state=state, seed=42)
337-
state.momenta = torch.zeros_like(state.momenta)
347+
# Initialize momenta (even if zero) and get forces
348+
state = nve_init(state=state, seed=42) # kT is ignored if momenta are set below
349+
# Ensure momenta start at zero AFTER init which might randomize them based on kT
350+
state.momenta = torch.zeros_like(state.momenta) # Start from rest
338351

339-
for _step in range(100):
352+
for _step in range(n_steps):
340353
state = nve_update(state=state, dt=dt)
341354

342355
final_states[state_name] = state
343356

344357
# Check energy conservation
345-
ar_single_state = final_states["single"]
346-
ar_batched_state_0 = final_states["batched"][0]
347-
ar_batched_state_1 = final_states["batched"][1]
348-
349-
for final_state in [ar_batched_state_0, ar_batched_state_1]:
350-
assert torch.allclose(ar_single_state.positions, final_state.positions)
351-
assert torch.allclose(ar_single_state.momenta, final_state.momenta)
352-
assert torch.allclose(ar_single_state.forces, final_state.forces)
353-
assert torch.allclose(ar_single_state.masses, final_state.masses)
354-
assert torch.allclose(ar_single_state.cell, final_state.cell)
358+
single_state = final_states["single"]
359+
batched_state_0 = final_states["batched"][0]
360+
batched_state_1 = final_states["batched"][1]
361+
362+
# Compare single state results with each part of the batched state
363+
for final_state in [batched_state_0, batched_state_1]:
364+
# Check positions first - most likely to fail with incorrect PBC
365+
torch.testing.assert_close(single_state.positions, final_state.positions)
366+
# Check other state components
367+
torch.testing.assert_close(single_state.momenta, final_state.momenta)
368+
torch.testing.assert_close(single_state.forces, final_state.forces)
369+
torch.testing.assert_close(single_state.masses, final_state.masses)
370+
torch.testing.assert_close(single_state.cell, final_state.cell)
371+
torch.testing.assert_close(single_state.energy, final_state.energy)
355372

356373

357374
def test_compute_cell_force_atoms_per_batch():

tests/test_transforms.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -94,30 +94,34 @@ def test_pbc_wrap_general_orthorhombic() -> None:
9494
assert torch.allclose(wrapped, expected)
9595

9696

97-
def test_pbc_wrap_general_triclinic() -> None:
98-
"""Test periodic boundary wrapping with triclinic cell.
99-
100-
Tests wrapping in a non-orthogonal cell where lattice vectors have
101-
off-diagonal components (tilt factors). This verifies the general
102-
matrix transformation approach works for arbitrary cell shapes.
103-
"""
104-
# Triclinic cell with tilt
105-
lattice = torch.tensor(
106-
[
107-
[2.0, 0.5, 0.0], # a vector with b-tilt
108-
[0.0, 2.0, 0.0], # b vector
109-
[0.0, 0.3, 2.0], # c vector with b-tilt
110-
]
111-
)
112-
113-
# Position outside triclinic box
114-
positions = torch.tensor([[2.5, 2.5, 2.5]])
115-
116-
# Correct expected wrapped position for this triclinic cell
117-
expected = torch.tensor([[2.0, 0.5, 0.2]])
118-
119-
wrapped = tst.pbc_wrap_general(positions, lattice)
120-
assert torch.allclose(wrapped, expected, atol=1e-6)
97+
@pytest.mark.parametrize(
98+
("cell", "shift"),
99+
[
100+
# Cubic cell, integer shift [1, 1, 1]
101+
(torch.eye(3, dtype=torch.float64) * 2.0, [1, 1, 1]),
102+
# Triclinic cell, integer shift [1, 1, 1]
103+
(([[2.0, 0.0, 0.0], [0.5, 2.0, 0.0], [0.0, 0.3, 2.0]]), [1, 1, 1]),
104+
# Triclinic cell, integer shift [-1, 2, 0]
105+
(([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-1, 2, 0]),
106+
# triclinic, all negative shift
107+
(([[2.0, 0.5, 0.0], [0.0, 2.0, 0.0], [0.0, 0.3, 2.0]]), [-2, -1, -3]),
108+
# cubic, large mixed shift
109+
(torch.eye(3, dtype=torch.float64) * 2.0, [5, 0, -10]),
110+
# highly tilted cell
111+
(([[1.3, 0.9, 0.8], [0.0, 1.0, 0.9], [0.0, 0.0, 1.0]]), [1, -2, 3]),
112+
# Left-handed cell
113+
(([[2.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 2.0]]), [1, 1, 1]),
114+
],
115+
)
116+
def test_pbc_wrap_general_param(cell: torch.Tensor, shift: torch.Tensor) -> None:
117+
"""Test periodic boundary wrapping for various cells and integer shifts."""
118+
cell = torch.as_tensor(cell, dtype=torch.float64)
119+
shift = torch.as_tensor(shift, dtype=torch.float64)
120+
base_frac = torch.tensor([[0.25, 0.5, 0.75]], dtype=torch.float64)
121+
base_cart = base_frac @ cell.T
122+
shifted_cart = base_cart + (shift @ cell.T)
123+
wrapped = tst.pbc_wrap_general(shifted_cart, cell)
124+
torch.testing.assert_close(wrapped, base_cart, rtol=1e-6, atol=1e-6)
121125

122126

123127
def test_pbc_wrap_general_edge_case() -> None:
@@ -277,35 +281,36 @@ def test_pbc_wrap_batched_orthorhombic(si_double_sim_state: SimState) -> None:
277281

278282
def test_pbc_wrap_batched_triclinic(device: torch.device) -> None:
279283
"""Test batched periodic boundary wrapping with triclinic cell."""
280-
# Create two triclinic cells with different tilt factors
284+
# Define cell matrices (M_row convention)
281285
cell1 = torch.tensor(
282286
[
283287
[2.0, 0.5, 0.0], # a vector with b-tilt
284288
[0.0, 2.0, 0.0], # b vector
285289
[0.0, 0.3, 2.0], # c vector with b-tilt
286290
],
291+
dtype=torch.float64,
287292
device=device,
288293
)
289-
290294
cell2 = torch.tensor(
291295
[
292296
[2.0, 0.0, 0.5], # a vector with c-tilt
293297
[0.3, 2.0, 0.0], # b vector with a-tilt
294298
[0.0, 0.0, 2.0], # c vector
295299
],
300+
dtype=torch.float64,
296301
device=device,
297302
)
303+
cell = torch.stack([cell1, cell2])
298304

299-
# Create positions for two atoms, one in each batch
305+
# Define positions (r_row convention)
300306
positions = torch.tensor(
301307
[
302-
[2.5, 2.5, 2.5], # First atom, outside batch 0's cell
303-
[2.7, 2.7, 2.7], # Second atom, outside batch 1's cell
308+
[2.5, 2.5, 2.5], # Atom 0 (batch 0)
309+
[2.7, 2.7, 2.7], # Atom 1 (batch 1)
304310
],
311+
dtype=torch.float64,
305312
device=device,
306313
)
307-
308-
# Create batch indices
309314
batch = torch.tensor([0, 1], device=device)
310315

311316
# Stack the cells for batched processing

torch_sim/integrators.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,7 @@ def position_step(state: MDState, dt: torch.Tensor) -> MDState:
168168

169169
if state.pbc:
170170
# Split positions and cells by batch
171-
new_positions = pbc_wrap_batched(
172-
new_positions, state.cell.swapaxes(1, 2), state.batch
173-
)
171+
new_positions = pbc_wrap_batched(new_positions, state.cell, state.batch)
174172

175173
state.positions = new_positions
176174
return state
@@ -1027,9 +1025,7 @@ def langevin_position_step(
10271025

10281026
# Apply periodic boundary conditions if needed
10291027
if state.pbc:
1030-
state.positions = pbc_wrap_batched(
1031-
state.positions, state.cell.swapaxes(1, 2), state.batch
1032-
)
1028+
state.positions = pbc_wrap_batched(state.positions, state.cell, state.batch)
10331029

10341030
return state
10351031

torch_sim/transforms.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,14 @@ def pbc_wrap_general(
9999
This implementation follows the general matrix-based approach for
100100
periodic boundary conditions in arbitrary triclinic cells:
101101
1. Transform positions to fractional coordinates using B = A^(-1)
102-
2. Wrap fractional coordinates to [0,1) using f - floor(f)
102+
2. Wrap fractional coordinates to [0,1) using modulo
103103
3. Transform back to real space using A
104104
105105
Args:
106106
positions (torch.Tensor): Tensor of shape (..., d)
107107
containing particle positions in real space.
108-
lattice_vectors (torch.Tensor): Tensor of shape (d, d)
109-
containing lattice vectors as columns (A matrix in the equations).
108+
lattice_vectors (torch.Tensor): Tensor of shape (d, d) containing
109+
lattice vectors as columns (A matrix in the equations).
110110
111111
Returns:
112112
torch.Tensor: Tensor of wrapped positions in real space with
@@ -124,23 +124,13 @@ def pbc_wrap_general(
124124
if positions.shape[-1] != lattice_vectors.shape[0]:
125125
raise ValueError("Position dimensionality must match lattice vectors.")
126126

127-
# Compute B = A^(-1) to transform to fractional coordinates
128-
B = torch.linalg.inv(lattice_vectors)
129-
130127
# Transform to fractional coordinates: f = Br
131-
frac_coords = positions @ B.T
132-
133-
# Wrap to reference cell [0,1) using f - floor(f)
134-
wrapped_frac = frac_coords - torch.floor(frac_coords)
128+
frac_coords = positions @ torch.linalg.inv(lattice_vectors).T
135129

136-
# Handle edge case of positions exactly on upper boundary
137-
wrapped_frac = torch.where(
138-
torch.isclose(wrapped_frac, torch.ones_like(wrapped_frac)),
139-
torch.zeros_like(wrapped_frac),
140-
wrapped_frac,
141-
)
130+
# Wrap to reference cell [0,1) using modulo
131+
wrapped_frac = frac_coords % 1.0
142132

143-
# Transform back to real space: t = Ag
133+
# Transform back to real space: r_row_wrapped = wrapped_f_row @ M_row
144134
return wrapped_frac @ lattice_vectors.T
145135

146136

@@ -157,7 +147,7 @@ def pbc_wrap_batched(
157147
positions (torch.Tensor): Tensor of shape (n_atoms, 3) containing
158148
particle positions in real space.
159149
cell (torch.Tensor): Tensor of shape (n_batches, 3, 3) containing
160-
lattice vectors for each batch.
150+
lattice vectors as column vectors.
161151
batch (torch.Tensor): Tensor of shape (n_atoms,) containing batch
162152
indices for each atom.
163153
@@ -191,15 +181,8 @@ def pbc_wrap_batched(
191181
# For each atom, multiply its position by its batch's inverse cell matrix
192182
frac_coords = torch.bmm(B_per_atom, positions.unsqueeze(2)).squeeze(2)
193183

194-
# Wrap to reference cell [0,1) using f - floor(f)
195-
wrapped_frac = frac_coords - torch.floor(frac_coords)
196-
197-
# Handle edge case of positions exactly on upper boundary
198-
wrapped_frac = torch.where(
199-
torch.isclose(wrapped_frac, torch.ones_like(wrapped_frac)),
200-
torch.zeros_like(wrapped_frac),
201-
wrapped_frac,
202-
)
184+
# Wrap to reference cell [0,1) using modulo
185+
wrapped_frac = frac_coords % 1.0
203186

204187
# Transform back to real space: r = A·f
205188
# Get the cell for each atom based on its batch index

0 commit comments

Comments
 (0)