Skip to content

Add method to write new snapshot that contains a subset of another snapshot's molecules. #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions cmeutils/gsd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,131 @@ def xml_to_gsd(xmlfile, gsdfile):
print(f"XML data written to {gsdfile}")


def trim_snapshot_molecules(parent_snapshot, mol_indices):
"""Given a snapshot of a system, trim the snapshot to only include
a subset of the molecules.

Parameters
----------
parent_snapshot : gsd.hoomd.Frame
The snapshot to read in.
mol_indices : list of np.ndarray
List of arrays where each array contains the indices
of the particles in a molecule to include.

Returns
-------
gsd.hoomd.Frame
The new snapshot with only the specified molecules.

Notes
-----
See cmetuils.gsd_utils.get_molecule_cluster for a method to obtain
mol_indices.

"""
new_snap = gsd.hoomd.Frame()
new_snap.configuration.box = parent_snapshot.configuration.box
new_snap.particles.N = sum(len(i) for i in mol_indices)

# Write out particle info
for attr in [
"position",
"mass",
"velocity",
"orientation",
"image",
"diameter",
"angmom",
"typeid",
]:
setattr(
new_snap.particles,
attr,
np.concatenate(
list(
getattr(parent_snapshot.particles, attr)[i]
for i in mol_indices
)
),
)
new_snap.particles.types = parent_snapshot.particles.types

particle_index_map = dict()
count = 0
for indices in mol_indices:
for i in indices:
particle_index_map[i] = count
count += 1

# Write out bond info
mol_bond_groups = []
mol_bond_ids = []
for indices in mol_indices:
mask = np.any(
np.isin(parent_snapshot.bonds.group, indices.flatten()), axis=1
)
parent_mol_bonds = parent_snapshot.bonds.group[np.where(mask)[0]]
parent_mol_bond_typeids = parent_snapshot.bonds.typeid[
np.where(mask)[0]
]
new_mol_bonds = np.vectorize(particle_index_map.get)(parent_mol_bonds)
mol_bond_groups.append(new_mol_bonds)
mol_bond_ids.append(parent_mol_bond_typeids)

new_snap.bonds.types = parent_snapshot.bonds.types
new_snap.bonds.group = np.concatenate(mol_bond_groups)
new_snap.bonds.typeid = np.concatenate(mol_bond_ids)
new_snap.bonds.N = sum(len(i) for i in mol_bond_ids)

# Write out angle info
mol_angle_groups = []
mol_angle_ids = []
for indices in mol_indices:
mask = np.any(
np.isin(parent_snapshot.angles.group, indices.flatten()), axis=1
)
parent_mol_angles = parent_snapshot.angles.group[np.where(mask)[0]]
parent_mol_angle_typeids = parent_snapshot.angles.typeid[
np.where(mask)[0]
]
new_mol_angles = np.vectorize(particle_index_map.get)(parent_mol_angles)
mol_angle_groups.append(new_mol_angles)
mol_angle_ids.append(parent_mol_angle_typeids)

new_snap.angles.types = parent_snapshot.angles.types
new_snap.angles.group = np.concatenate(mol_angle_groups)
new_snap.angles.typeid = np.concatenate(mol_angle_ids)
new_snap.angles.N = sum(len(i) for i in mol_angle_ids)

# Write out dihedral info
mol_dihedral_groups = []
mol_dihedral_ids = []
for indices in mol_indices:
mask = np.any(
np.isin(parent_snapshot.dihedrals.group, indices.flatten()), axis=1
)
parent_mol_dihedrals = parent_snapshot.dihedrals.group[
np.where(mask)[0]
]
parent_mol_dihedral_typeids = parent_snapshot.dihedrals.typeid[
np.where(mask)[0]
]
new_mol_dihedrals = np.vectorize(particle_index_map.get)(
parent_mol_dihedrals
)
mol_dihedral_groups.append(new_mol_dihedrals)
mol_dihedral_ids.append(parent_mol_dihedral_typeids)

new_snap.dihedrals.types = parent_snapshot.dihedrals.types
new_snap.dihedrals.group = np.concatenate(mol_dihedral_groups)
new_snap.dihedrals.typeid = np.concatenate(mol_dihedral_ids)
new_snap.dihedrals.N = sum(len(i) for i in mol_dihedral_ids)

new_snap.validate()
return new_snap


def identify_snapshot_connections(snapshot):
"""Identify angle and dihedral connections in a snapshot from bonds.

Expand Down