Skip to content

Commit a8e4ab0

Browse files
committed
Update all calculators to remove neighbors list workaround
1 parent 6f6d9bb commit a8e4ab0

File tree

3 files changed

+88
-34
lines changed

3 files changed

+88
-34
lines changed

rascaline/src/calculators/neighbor_list.rs

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ pub struct NeighborList {
3939
pub full_neighbor_list: bool,
4040
/// Should individual atoms be considered their own neighbor? Setting this
4141
/// to `true` will add "self pairs", i.e. pairs between an atom and itself,
42-
/// with the distance 0. The `pair_id` of such pairs is set to -1.
42+
/// with the distance 0.
4343
pub self_pairs: bool,
4444
}
4545

@@ -423,7 +423,8 @@ impl FullNeighborList {
423423
let cell_c = pair.cell_shift_indices[2];
424424

425425
if species_first == species_second {
426-
// same species for both atoms in the pair
426+
// same species for both atoms in the pair, add the pair
427+
// twice in both directions.
427428
if species[pair.first] == species_first.i32() && species[pair.second] == species_second.i32() {
428429
builder.add(&[
429430
LabelValue::from(system_i),
@@ -434,18 +435,14 @@ impl FullNeighborList {
434435
LabelValue::from(cell_c),
435436
]);
436437

437-
if pair.first != pair.second {
438-
// if the pair is between two different atoms,
439-
// also add the reversed (second -> first) pair.
440-
builder.add(&[
441-
LabelValue::from(system_i),
442-
LabelValue::from(pair.second),
443-
LabelValue::from(pair.first),
444-
LabelValue::from(-cell_a),
445-
LabelValue::from(-cell_b),
446-
LabelValue::from(-cell_c),
447-
]);
448-
}
438+
builder.add(&[
439+
LabelValue::from(system_i),
440+
LabelValue::from(pair.second),
441+
LabelValue::from(pair.first),
442+
LabelValue::from(-cell_a),
443+
LabelValue::from(-cell_b),
444+
LabelValue::from(-cell_c),
445+
]);
449446
}
450447
} else {
451448
// different species, find the right order for the pair
@@ -501,6 +498,11 @@ impl FullNeighborList {
501498
let species = system.species()?;
502499

503500
for pair in system.pairs()? {
501+
if pair.first == pair.second {
502+
// self pairs should not be part of the neighbors list
503+
assert_ne!(pair.cell_shift_indices, [0, 0, 0]);
504+
}
505+
504506
let first_block_i = descriptor.keys().position(&[
505507
species[pair.first].into(), species[pair.second].into()
506508
]);
@@ -565,11 +567,6 @@ impl FullNeighborList {
565567
}
566568
}
567569

568-
if pair.first == pair.second {
569-
// do not duplicate self pairs
570-
continue;
571-
}
572-
573570
// then the pair second -> first
574571
if let Some(second_block_i) = second_block_i {
575572
let mut block = descriptor.block_mut_by_id(second_block_i);
@@ -764,6 +761,75 @@ mod tests {
764761
assert_relative_eq!(array, expected, max_relative=1e-6);
765762
}
766763

764+
#[test]
765+
fn periodic_neighbor_list() {
766+
let mut calculator = Calculator::from(Box::new(NeighborList{
767+
cutoff: 12.0,
768+
full_neighbor_list: false,
769+
self_pairs: false,
770+
}) as Box<dyn CalculatorBase>);
771+
772+
let mut systems = test_systems(&["CH"]);
773+
774+
let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
775+
assert_eq!(*descriptor.keys(), Labels::new(
776+
["species_first_atom", "species_second_atom"],
777+
&[[1, 1], [1, 6], [6, 6]]
778+
));
779+
780+
// H-H block
781+
let block = descriptor.block_by_id(0);
782+
assert_eq!(block.samples(), Labels::new(
783+
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
784+
// the pairs only differ in cell shifts
785+
&[[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 1, 0], [0, 1, 1, 1, 0, 0]]
786+
));
787+
788+
let array = block.values().to_array();
789+
let expected = &ndarray::arr3(&[
790+
[[0.0], [0.0], [10.0]],
791+
[[0.0], [10.0], [0.0]],
792+
[[10.0], [0.0], [0.0]],
793+
]).into_dyn();
794+
assert_relative_eq!(array, expected, max_relative=1e-6);
795+
796+
// now a full NL
797+
let mut calculator = Calculator::from(Box::new(NeighborList{
798+
cutoff: 12.0,
799+
full_neighbor_list: true,
800+
self_pairs: false,
801+
}) as Box<dyn CalculatorBase>);
802+
803+
let descriptor = calculator.compute(&mut systems, Default::default()).unwrap();
804+
assert_eq!(*descriptor.keys(), Labels::new(
805+
["species_first_atom", "species_second_atom"],
806+
&[[1, 1], [1, 6], [6, 1], [6, 6]]
807+
));
808+
809+
// H-H block
810+
let block = descriptor.block_by_id(0);
811+
assert_eq!(block.samples(), Labels::new(
812+
["structure", "first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"],
813+
// twice as many pairs
814+
&[
815+
[0, 1, 1, 0, 0, 1], [0, 1, 1, 0, 0, -1],
816+
[0, 1, 1, 0, 1, 0], [0, 1, 1, 0, -1, 0],
817+
[0, 1, 1, 1, 0, 0], [0, 1, 1, -1, 0, 0],
818+
]
819+
));
820+
821+
let array = block.values().to_array();
822+
let expected = &ndarray::arr3(&[
823+
[[0.0], [0.0], [10.0]],
824+
[[0.0], [0.0], [-10.0]],
825+
[[0.0], [10.0], [0.0]],
826+
[[0.0], [-10.0], [0.0]],
827+
[[10.0], [0.0], [0.0]],
828+
[[-10.0], [0.0], [0.0]],
829+
]).into_dyn();
830+
assert_relative_eq!(array, expected, max_relative=1e-6);
831+
}
832+
767833
#[test]
768834
fn finite_differences_positions() {
769835
// half neighbor list

rascaline/src/calculators/soap/spherical_expansion.rs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,12 +241,6 @@ impl SphericalExpansion {
241241
}
242242
}
243243

244-
if pair.first == pair.second {
245-
// do not compute for the reversed pair if the pair is
246-
// between an atom and its image
247-
continue;
248-
}
249-
250244
if let Some(mapped_center) = result.centers_mapping[pair.second] {
251245
// add the pair contribution to the atomic environnement
252246
// corresponding to the **second** atom in the pair
@@ -778,7 +772,7 @@ mod tests {
778772

779773
fn parameters() -> SphericalExpansionParameters {
780774
SphericalExpansionParameters {
781-
cutoff: 3.5,
775+
cutoff: 7.8,
782776
max_radial: 6,
783777
max_angular: 6,
784778
atomic_gaussian_width: 0.3,

rascaline/src/calculators/soap/spherical_expansion_pair.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -755,13 +755,7 @@ impl CalculatorBase for SphericalExpansionByPair {
755755
}
756756
}
757757

758-
// also check for the block with a reversed pair, except if
759-
// we are handling a pair between an atom and it's own
760-
// periodic image
761-
if pair.first == pair.second {
762-
continue;
763-
}
764-
758+
// also check for the block with a reversed pair
765759
contribution.inverse_pair(&self.m_1_pow_l);
766760

767761
for spherical_harmonics_l in 0..=self.parameters.max_angular {
@@ -817,7 +811,7 @@ mod tests {
817811

818812
fn parameters() -> SphericalExpansionParameters {
819813
SphericalExpansionParameters {
820-
cutoff: 3.5,
814+
cutoff: 7.3,
821815
max_radial: 6,
822816
max_angular: 6,
823817
atomic_gaussian_width: 0.3,

0 commit comments

Comments
 (0)