Skip to content

Commit 4b81ad9

Browse files
rpachauri authored and copybara-github committed
Add CPU/thread control to MSA tools and update dependencies.
Addresses the following pull request:
* #358

PiperOrigin-RevId: 810894445
Change-Id: Ib424b3f331c5168f4505edd154b860420e89237e
1 parent 77816c7 commit 4b81ad9

File tree

5 files changed

+61
-16
lines changed

5 files changed

+61
-16
lines changed

alphafold/data/pipeline.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ class DataPipeline:
112112
"""Runs the alignment tools and assembles the input features."""
113113

114114
def __init__(self,
115+
*,
115116
jackhmmer_binary_path: str,
116117
hhblits_binary_path: str,
117118
uniref90_database_path: str,
@@ -124,23 +125,28 @@ def __init__(self,
124125
use_small_bfd: bool,
125126
mgnify_max_hits: int = 501,
126127
uniref_max_hits: int = 10000,
127-
use_precomputed_msas: bool = False):
128+
use_precomputed_msas: bool = False,
129+
msa_tools_n_cpu: int = 8):
128130
"""Initializes the data pipeline."""
129131
self._use_small_bfd = use_small_bfd
130132
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
131133
binary_path=jackhmmer_binary_path,
132-
database_path=uniref90_database_path)
134+
database_path=uniref90_database_path,
135+
n_cpu=msa_tools_n_cpu)
133136
if use_small_bfd:
134137
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
135138
binary_path=jackhmmer_binary_path,
136-
database_path=small_bfd_database_path)
139+
database_path=small_bfd_database_path,
140+
n_cpu=msa_tools_n_cpu)
137141
else:
138142
self.hhblits_bfd_uniref_runner = hhblits.HHBlits(
139143
binary_path=hhblits_binary_path,
140-
databases=[bfd_database_path, uniref30_database_path])
144+
databases=[bfd_database_path, uniref30_database_path],
145+
n_cpu=msa_tools_n_cpu)
141146
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
142147
binary_path=jackhmmer_binary_path,
143-
database_path=mgnify_database_path)
148+
database_path=mgnify_database_path,
149+
n_cpu=msa_tools_n_cpu)
144150
self.template_searcher = template_searcher
145151
self.template_featurizer = template_featurizer
146152
self.mgnify_max_hits = mgnify_max_hits

alphafold/data/pipeline_multimer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def add_assembly_features(
134134
# Group the chains by sequence
135135
seq_to_entity_id = {}
136136
grouped_chains = collections.defaultdict(list)
137-
for chain_id, chain_features in all_chain_features.items():
137+
for _, chain_features in all_chain_features.items():
138138
seq = str(chain_features['sequence'])
139139
if seq not in seq_to_entity_id:
140140
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
@@ -172,10 +172,12 @@ class DataPipeline:
172172

173173
def __init__(self,
174174
monomer_data_pipeline: pipeline.DataPipeline,
175+
*,
175176
jackhmmer_binary_path: str,
176177
uniprot_database_path: str,
177178
max_uniprot_hits: int = 50000,
178-
use_precomputed_msas: bool = False):
179+
use_precomputed_msas: bool = False,
180+
jackhmmer_n_cpu: int = 8):
179181
"""Initializes the data pipeline.
180182
181183
Args:
@@ -186,11 +188,13 @@ def __init__(self,
186188
will be searched with jackhmmer and used for MSA pairing.
187189
max_uniprot_hits: The maximum number of hits to return from uniprot.
188190
use_precomputed_msas: Whether to use pre-existing MSAs; see run_alphafold.
191+
jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer.
189192
"""
190193
self._monomer_data_pipeline = monomer_data_pipeline
191194
self._uniprot_msa_runner = jackhmmer.Jackhmmer(
192195
binary_path=jackhmmer_binary_path,
193-
database_path=uniprot_database_path)
196+
database_path=uniprot_database_path,
197+
n_cpu=jackhmmer_n_cpu)
194198
self._max_uniprot_hits = max_uniprot_hits
195199
self.use_precomputed_msas = use_precomputed_msas
196200

alphafold/data/tools/hhsearch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __init__(self,
3333
*,
3434
binary_path: str,
3535
databases: Sequence[str],
36-
maxseq: int = 1_000_000):
36+
maxseq: int = 1_000_000,
37+
cpu: int = 8):
3738
"""Initializes the Python HHsearch wrapper.
3839
3940
Args:
@@ -43,13 +44,15 @@ def __init__(self,
4344
_hhm.ffindex etc.)
4445
maxseq: The maximum number of rows in an input alignment. Note that this
4546
parameter is only supported in HHBlits version 3.1 and higher.
47+
cpu: The number of CPUs to use.
4648
4749
Raises:
4850
RuntimeError: If HHsearch binary not found within the path.
4951
"""
5052
self.binary_path = binary_path
5153
self.databases = databases
5254
self.maxseq = maxseq
55+
self.cpu = cpu
5356

5457
for database_path in self.databases:
5558
if not glob.glob(database_path + '_*'):
@@ -79,7 +82,8 @@ def query(self, a3m: str) -> str:
7982
cmd = [self.binary_path,
8083
'-i', input_path,
8184
'-o', hhr_path,
82-
'-maxseq', str(self.maxseq)
85+
'-maxseq', str(self.maxseq),
86+
'-cpu', str(self.cpu),
8387
] + db_cmd
8488

8589
logging.info('Launching subprocess "%s"', ' '.join(cmd))

alphafold/data/tools/hmmsearch.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ def __init__(self,
3333
binary_path: str,
3434
hmmbuild_binary_path: str,
3535
database_path: str,
36-
flags: Optional[Sequence[str]] = None):
36+
flags: Optional[Sequence[str]] = None,
37+
cpu: int = 8):
3738
"""Initializes the Python hmmsearch wrapper.
3839
3940
Args:
@@ -42,13 +43,15 @@ def __init__(self,
4243
an hmm from an input a3m.
4344
database_path: The path to the hmmsearch database (FASTA format).
4445
flags: List of flags to be used by hmmsearch.
46+
cpu: The number of CPUs to use for the hmmsearch query.
4547
4648
Raises:
4749
RuntimeError: If hmmsearch binary not found within the path.
4850
"""
4951
self.binary_path = binary_path
5052
self.hmmbuild_runner = hmmbuild.Hmmbuild(binary_path=hmmbuild_binary_path)
5153
self.database_path = database_path
54+
self.cpu = cpu
5255
if flags is None:
5356
# Default hmmsearch run settings.
5457
flags = ['--F1', '0.1',
@@ -89,7 +92,7 @@ def query_with_hmm(self, hmm: str) -> str:
8992
cmd = [
9093
self.binary_path,
9194
'--noali', # Don't include the alignment in stdout.
92-
'--cpu', '8'
95+
'--cpu', str(self.cpu),
9396
]
9497
# If adding flags, we have to do so before the output and input:
9598
if self.flags:

run_alphafold.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,30 @@ class ModelsToRelax(enum.Enum):
143143
'Relax on GPU can be much faster than CPU, so it is '
144144
'recommended to enable if possible. GPUs must be available'
145145
' if this setting is enabled.')
146+
flags.DEFINE_integer(
147+
'jackhmmer_n_cpu',
148+
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
149+
min(len(os.sched_getaffinity(0)), 8),
150+
'Number of CPUs to use for Jackhmmer. Defaults to min(cpu_count, 8). Going'
151+
' above 8 CPUs provides very little additional speedup.',
152+
lower_bound=0,
153+
)
154+
flags.DEFINE_integer(
155+
'hmmsearch_n_cpu',
156+
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
157+
min(len(os.sched_getaffinity(0)), 8),
158+
'Number of CPUs to use for HMMsearch. Defaults to min(cpu_count, 8). Going'
159+
' above 8 CPUs provides very little additional speedup.',
160+
lower_bound=0,
161+
)
162+
flags.DEFINE_integer(
163+
'hhsearch_n_cpu',
164+
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
165+
min(len(os.sched_getaffinity(0)), 8),
166+
'Number of CPUs to use for HHsearch. Defaults to min(cpu_count, 8). Going'
167+
' above 8 CPUs provides very little additional speedup.',
168+
lower_bound=0,
169+
)
146170

147171
FLAGS = flags.FLAGS
148172

@@ -464,7 +488,8 @@ def main(argv):
464488
template_searcher = hmmsearch.Hmmsearch(
465489
binary_path=FLAGS.hmmsearch_binary_path,
466490
hmmbuild_binary_path=FLAGS.hmmbuild_binary_path,
467-
database_path=FLAGS.pdb_seqres_database_path)
491+
database_path=FLAGS.pdb_seqres_database_path,
492+
cpu=FLAGS.hmmsearch_n_cpu)
468493
template_featurizer = templates.HmmsearchHitFeaturizer(
469494
mmcif_dir=FLAGS.template_mmcif_dir,
470495
max_template_date=FLAGS.max_template_date,
@@ -475,7 +500,8 @@ def main(argv):
475500
else:
476501
template_searcher = hhsearch.HHSearch(
477502
binary_path=FLAGS.hhsearch_binary_path,
478-
databases=[FLAGS.pdb70_database_path])
503+
databases=[FLAGS.pdb70_database_path],
504+
cpu=FLAGS.hhsearch_n_cpu)
479505
template_featurizer = templates.HhsearchHitFeaturizer(
480506
mmcif_dir=FLAGS.template_mmcif_dir,
481507
max_template_date=FLAGS.max_template_date,
@@ -495,15 +521,17 @@ def main(argv):
495521
template_searcher=template_searcher,
496522
template_featurizer=template_featurizer,
497523
use_small_bfd=use_small_bfd,
498-
use_precomputed_msas=FLAGS.use_precomputed_msas)
524+
use_precomputed_msas=FLAGS.use_precomputed_msas,
525+
msa_tools_n_cpu=FLAGS.jackhmmer_n_cpu)
499526

500527
if run_multimer_system:
501528
num_predictions_per_model = FLAGS.num_multimer_predictions_per_model
502529
data_pipeline = pipeline_multimer.DataPipeline(
503530
monomer_data_pipeline=monomer_data_pipeline,
504531
jackhmmer_binary_path=FLAGS.jackhmmer_binary_path,
505532
uniprot_database_path=FLAGS.uniprot_database_path,
506-
use_precomputed_msas=FLAGS.use_precomputed_msas)
533+
use_precomputed_msas=FLAGS.use_precomputed_msas,
534+
jackhmmer_n_cpu=FLAGS.jackhmmer_n_cpu)
507535
else:
508536
num_predictions_per_model = 1
509537
data_pipeline = monomer_data_pipeline

0 commit comments

Comments
 (0)