diff --git a/rfdiffusion/inference/model_runners.py b/rfdiffusion/inference/model_runners.py index f47d0e9..d51e11d 100644 --- a/rfdiffusion/inference/model_runners.py +++ b/rfdiffusion/inference/model_runners.py @@ -939,7 +939,28 @@ def sample_init(self): ### Get hotspots ### #################### self.hotspot_0idx=iu.get_idx0_hotspots(self.mappings, self.ppi_conf, self.binderlen) - + + ####################################### + ### Resolve cyclic peptide indicies ### + ####################################### + if self._conf.inference.cyclic: + if self._conf.inference.cyc_chains is None: + # default to all residues being cyclized + self.cyclic_reses = ~self.mask_str.to(self.device).squeeze() + else: + # use cyc_chains arg to determine cyclic_reses mask + assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be string' + cyc_chains = self._conf.inference.cyc_chains + cyc_chains = [i.upper() for i in cyc_chains] + hal_idx = self.contig_map.hal # the pdb indices of output, knowledge of different chains + is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty + for ch in cyc_chains: + ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool() + is_cyclized[ch_mask] = True # set this whole chain to be cyclic + self.cyclic_reses = is_cyclized + else: + self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() + ######################### ### Set up potentials ### #########################