diff --git a/notebooks/batch_inference.py b/notebooks/batch_inference.py new file mode 100644 index 0000000..10f9b7f --- /dev/null +++ b/notebooks/batch_inference.py @@ -0,0 +1,126 @@ +import warnings +import argparse +import os +import json +import time +from tqdm import tqdm +from unifold.colab.model import colab_inference +from unifold.colab.data import validate_input, get_features +warnings.filterwarnings("ignore") + +MIN_SINGLE_SEQUENCE_LENGTH = 6 +MAX_SINGLE_SEQUENCE_LENGTH = 5000 +MAX_MULTIMER_LENGTH = 5000 + + +def process_batch_json(tasks, jobname, output_dir_base): + if isinstance(tasks, dict): + new_tasks = [] + for k, v in tasks.items(): + v['id'] = k + new_tasks.append(v) + tasks = new_tasks + + # check the input. + for idx, task in enumerate(tasks): + if 'id' not in task.keys(): + task['id'] = idx + + if 'sequence' not in task.keys(): + raise KeyError(f"number {idx+1}-th 'sequence' not found in dict keys: {task.keys()} in json.") + + target_id = f"{jobname}_{task['id']}" + input_sequences = task['sequence'].strip().split(';') + + task['target_id'] = target_id + + if 'symmetry' not in task.keys(): + task['symmetry'] = 'C1' + + symmetry_group = task['symmetry'] + # check the sequences + sequences, is_multimer, symmetry_group = validate_input( + input_sequences=input_sequences, + symmetry_group=symmetry_group, + min_length=MIN_SINGLE_SEQUENCE_LENGTH, + max_length=MAX_SINGLE_SEQUENCE_LENGTH, + max_multimer_length=MAX_MULTIMER_LENGTH) + task['is_multimer'] = is_multimer + + # save features to `output_dir_base` + feature_output_dir = get_features( + jobname=jobname, + target_id=target_id, + sequences=sequences, + output_dir_base=output_dir_base, + is_multimer=is_multimer, + msa_mode=args.msa_mode, + use_templates=True if args.use_templates > 0 else False + ) + + task['feature_output_dir'] = feature_output_dir + task['symmetry'] = task['symmetry'] if task['symmetry'] != 'C1' else None + + return tasks + +def manual_operations(): + # developers may operate on the pickle files here + # to customize the features for inference. + pass + +manual_operations() + + +def main(args): + output_dir_base = args.out_dir + os.makedirs(output_dir_base, exist_ok=True) + + input_json_path = args.input_json + with open(input_json_path, encoding="utf-8") as fp: + input_json = json.load(fp) + + all_tasks = process_batch_json(input_json, args.jobname, output_dir_base) + + for task in tqdm(all_tasks, desc='running Unifold'): + start = time.time() + best_result = colab_inference( + target_id=task['target_id'], + data_dir=task['feature_output_dir'], + param_dir='.', + output_dir=task['feature_output_dir'], + symmetry_group=task['symmetry'], + is_multimer=task['is_multimer'], + max_recycling_iters=args.max_recycling_iters, + num_ensembles=args.num_ensembles, + times=args.times, + manual_seed=args.manual_seed, + device=args.device, # do not change this on colab. + bf16=args.bf16 + ) + + task['best_plddt'] = best_result['plddt'].mean().item() + task['pae'] = best_result['pae'].mean().item() if best_result['pae'] is not None else None + task['best_results_path'] = best_result['best_results_path'] + task['run_time'] = (time.time() - start)/60 + + # incase oom + with open(os.path.join(output_dir_base, 'all_tasks_summary.json'), 'w') as f: + json.dump(all_tasks, f, indent=2) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('-i', '--input_json', type=str, required=True) + parser.add_argument('-o', '--out_dir', type=str, default="predictions") + parser.add_argument('--jobname', type=str, default="jobname") + parser.add_argument('--msa_mode', type=str, default="MMseqs2", choices=["MMseqs2","single_sequence"]) + parser.add_argument('--num_ensembles', type=int, default=2) + parser.add_argument('--max_recycling_iters', type=int, default=3) + parser.add_argument('--times', type=int, default=1) + parser.add_argument('--use_templates', type=int, default=1) + parser.add_argument('--manual_seed', type=int, default=42) + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--bf16', action='store_true') + args = parser.parse_args() + print(args) + main(args) \ No newline at end of file diff --git a/notebooks/unifold_batch.ipynb b/notebooks/unifold_batch.ipynb new file mode 100644 index 0000000..0d74549 --- /dev/null +++ b/notebooks/unifold_batch.ipynb @@ -0,0 +1,319 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Uni-Fold Batch Inference Notebook\n", + "\n", + "This notebook provides protein structure prediction service of [Uni-Fold](https://github.com/dptech-corp/Uni-Fold/) as well as [UF-Symmetry](https://www.biorxiv.org/content/10.1101/2022.08.30.505833v1). Predictions of both protein monomers and multimers are supported. The homology search process in this notebook is enabled with the [MMSeqs2](https://github.com/soedinglab/MMseqs2.git) server provided by [ColabFold](https://github.com/sokrypton/ColabFold). For more consistent results with the original AlphaFold(-Multimer), please refer to the open-source repository of [Uni-Fold](https://github.com/dptech-corp/Uni-Fold/), or our convenient web server at [Hermite™](https://hermite.dp.tech/).\n", + "\n", + "Please note that this notebook is provided as an early-access prototype, and is NOT an official product of DP Technology. It is provided for theoretical modeling only and caution should be exercised in its use. \n", + "\n", + "**Licenses**\n", + "\n", + "This Colab uses the [Uni-Fold model parameters](https://github.com/dptech-corp/Uni-Fold/#model-parameters-license) and its outputs are under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The Colab itself is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).\n", + "\n", + "\n", + "**Citations**\n", + "\n", + "Please cite the following papers if you use this notebook:\n", + "\n", + "* Ziyao Li, Xuyang Liu, Weijie Chen, Fan Shen, Hangrui Bi, Guolin Ke, Linfeng Zhang. \"[Uni-Fold: An Open-Source Platform for Developing Protein Folding Models beyond AlphaFold.](https://www.biorxiv.org/content/10.1101/2022.08.04.502811v1)\" biorxiv (2022)\n", + "* Ziyao Li, Shuwen Yang, Xuyang Liu, Weijie Chen, Han Wen, Fan Shen, Guolin Ke, Linfeng Zhang. \"[Uni-Fold Symmetry: Harnessing Symmetry in Folding Large Protein Complexes.](https://www.biorxiv.org/content/10.1101/2022.08.30.505833v1)\" bioRxiv (2022)\n", + "* Mirdita M, Schütze K, Moriwaki Y, Heo L, Ovchinnikov S and Steinegger M. \"[ColabFold: Making protein folding accessible to all.](https://www.nature.com/articles/s41592-022-01488-1)\" Nature Methods (2022)\n", + "\n", + "**Acknowledgements**\n", + "\n", + "The model architecture of Uni-Fold is largely based on [AlphaFold](https://doi.org/10.1038/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1). The design of this notebook refers directly to [ColabFold](https://www.nature.com/articles/s41592-022-01488-1). We specially thank [@sokrypton](https://twitter.com/sokrypton) for his helpful suggestions to this notebook.\n", + "\n", + "Copyright © 2022 DP Technology. All rights reserved." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "import os\n", + "import json\n", + "from unifold.colab.data import validate_input, get_features\n", + "\n", + "is_colab = False\n", + "#@title Provide the arguments here and hit `Run` -> `Run All Cells`\n", + "jobname = 'unifold_batch_colab' #@param {type:\"string\"}\n", + "use_templates = True #@param {type:\"boolean\"}\n", + "msa_mode = \"MMseqs2\" #@param [\"MMseqs2\",\"single_sequence\"]\n", + "#@markdown Parameters for model inference.\n", + "max_recycling_iters = 3 #@param {type:\"integer\"}\n", + "num_ensembles = 2 #@param {type:\"integer\"}\n", + "manual_seed = 42 #@param {type:\"integer\"}\n", + "times = 1 #@param {type:\"integer\"}\n", + "#@markdown Plotting parameters.\n", + "show_sidechains = False #@param {type:\"boolean\"}\n", + "dpi = 100 #@param {type:\"integer\"}\n", + "max_display_cnt = 3 #@param {type:\"integer\"}\n", + "\n", + "MIN_SINGLE_SEQUENCE_LENGTH = 6\n", + "MAX_SINGLE_SEQUENCE_LENGTH = 3000\n", + "MAX_MULTIMER_LENGTH = 3000" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Install Uni-Fold and third-party softwares\n", + "#@markdown Please execute this cell by pressing the _Play_ button \n", + "#@markdown on the left to download and import third-party software \n", + "#@markdown in this Colab notebook. (See the [acknowledgements](https://github.com/dptech-corp/Uni-Fold/#acknowledgements) in our readme.)\n", + "\n", + "#@markdown **Note**: This installs the software on the Colab \n", + "#@markdown notebook in the cloud and not on your computer.\n", + "%%bash\n", + "if [ ! -f ENV_READY ]; then\n", + " apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y -qq kalign\n", + "\n", + " # Install HHsuite.\n", + " wget -q https://github.com/soedinglab/hh-suite/releases/download/v3.3.0/hhsuite-3.3.0-AVX2-Linux.tar.gz; tar xfz hhsuite-3.3.0-AVX2-Linux.tar.gz; ln -s $(pwd)/bin/* /usr/bin \n", + "\n", + " pip3 -q install py3dmol gdown\n", + "\n", + " pip3 -q install libmsym\n", + "\n", + " touch ENV_READY\n", + "fi\n", + "\n", + "GIT_REPO='https://github.com/dptech-corp/Uni-Fold'\n", + "UNICORE_URL='https://github.com/dptech-corp/Uni-Core/releases/download/0.0.2/unicore-0.0.1+cu118torch2.0.0-cp310-cp310-linux_x86_64.whl'\n", + "PARAM_URL='https://github.com/dptech-corp/Uni-Fold/releases/download/v2.0.0/unifold_params_2022-08-01.tar.gz'\n", + "UF_SYMM_PARAM_URL='https://github.com/dptech-corp/Uni-Fold/releases/download/v2.2.0/uf_symmetry_params_2022-09-06.tar.gz'\n", + "\n", + "if [ ! -f UNIFOLD_READY ]; then\n", + " wget ${UNICORE_URL} \n", + " pip3 -q install \"unicore-0.0.1+cu118torch2.0.0-cp310-cp310-linux_x86_64.whl\"\n", + " git clone -b main ${GIT_REPO}\n", + " pip3 -q install ./Uni-Fold\n", + " wget ${PARAM_URL}\n", + " tar -xzf \"unifold_params_2022-08-01.tar.gz\"\n", + " wget ${UF_SYMM_PARAM_URL}\n", + " tar -xzf \"uf_symmetry_params_2022-09-06.tar.gz\"\n", + "\n", + " touch UNIFOLD_READY\n", + "fi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CONFIGURATION\n", + "Set up input contents (from file or directly filling `input_json`) and output path.\n", + "- `jobname (str)`: name of the job, served as prefix of output directories.\n", + "- `input_json_path (str)`: path of input json file, which contains a list or dict of proteins. *If it's a list, we take indices as IDs.* Each protein is a dict with keys:\n", + " - `symmetry`: protein's symmetry group. Use \"C1\" as default.\n", + " - `sequence`: the sequences of the asymmetric unit (splitted by \";\").\n", + " - `id` is optional. if not existed, it will be the order of the sequences.\n", + " - other thing you can add.\n", + "- `output_dir_base (str)`: root directory of output files.\n", + "\n", + "\n", + "examples of `list`:\n", + "```python\n", + "input_json = [\n", + " {'sequence': 'MGSSHHHHHHSSGLVPRGSHMEDRDPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEFFKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG'},\n", + " {'symmetry': 'C2', 'sequence': 'GSHMKNVLIGVQTNLGVNKTGTEFGPDDLIQAYPDTFDEMELISVERQKEDFNDKKLKFKNTVLDTCEKIAKRVNEAVIDGYRPILVGGDHSISLGSVSGVSLEKEIGVLWISAHGDMNTPESTLTGNIHGMPLALLQGLGDRELVNCFYEGAKLDSRNIVIFGAREIEVEERKIIEKTGVKIVYYDDILRKGIDNVLDEVKDYLKIDNLHISIDMNVFDPEIAPGVSVPVRRGMSYDEMFKSLKFAFKNYSVTSADITEFNPLNDINGKTAELVNGIVQYMMNPDY'},\n", + " {'symmetry': 'C2', 'sequence': 'GGSGGSGGSGGSLFCEQVTTVTNLFEKWNDCERTVVMYALLKRLRYPSLKFLQYSIDSNLTQNLGTSQTNLSSVVIDINANNPVYLQNLLNAYKTARKEDILHEVLNMLPLLKPGNEEAKLIYLTLIPVAVKDTMQQIVPTELVQQIFSYLLIHPAITSEDRRSLNIWLRHLEDHIQ;SVPSYGEDELQQAMRLLNAASRQRTEAANEDFGGT'},\n", + " {'symmetry': 'C3', 'sequence': 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'},\n", + " ]\n", + "```\n", + "\n", + "Another `dict` case is showed as followed:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_dir_base = \"./prediction\" if is_colab else \"./prediction\" #@param {type:\"string\"}\n", + "os.makedirs(output_dir_base, exist_ok=True)\n", + "\n", + "input_json_path = 'your_json_path.json'\n", + "\n", + "\n", + "if os.path.isfile(input_json_path):\n", + " with open(input_json_path, encoding=\"utf-8\") as fp:\n", + " input_json = json.load(fp)\n", + " default_list_case = False\n", + " default_dict_case = False\n", + "else: # A DEMO CASE (DICT). list case is above.\n", + " input_json = {\n", + " '7teu': {'sequence': 'MGSSHHHHHHSSGLVPRGSHMEDRDPTQFEERHLKFLQQLGKGNFGSVEMCRYDPLQDNTGEVVAVKKLQHSTEEHLRDFEREIEILKSLQHDNIVKYKGVCYSAGRRNLKLIMEYLPYGSLRDYLQKHKERIDHIKLLQYTSQICKGMEYLGTKRYIHRDLATRNILVENENRVKIGDFGLTKVLPQDKEFFKVKEPGESPIFWYAPESLTESKFSVASDVWSFGVVLYELFTYIEKSKSPPAEFMRMIGNDKQGQMIVFHLIELLKNNGRLPRPDGCPDEIYMIMTECWNNNVNQRPSFRDLALRVDQIRDNMAG'},\n", + " '8d27': {'symmetry': 'C2', 'sequence': 'GSHMKNVLIGVQTNLGVNKTGTEFGPDDLIQAYPDTFDEMELISVERQKEDFNDKKLKFKNTVLDTCEKIAKRVNEAVIDGYRPILVGGDHSISLGSVSGVSLEKEIGVLWISAHGDMNTPESTLTGNIHGMPLALLQGLGDRELVNCFYEGAKLDSRNIVIFGAREIEVEERKIIEKTGVKIVYYDDILRKGIDNVLDEVKDYLKIDNLHISIDMNVFDPEIAPGVSVPVRRGMSYDEMFKSLKFAFKNYSVTSADITEFNPLNDINGKTAELVNGIVQYMMNPDY'},\n", + " '8oij': {'symmetry': 'C2', 'sequence': 'GGSGGSGGSGGSLFCEQVTTVTNLFEKWNDCERTVVMYALLKRLRYPSLKFLQYSIDSNLTQNLGTSQTNLSSVVIDINANNPVYLQNLLNAYKTARKEDILHEVLNMLPLLKPGNEEAKLIYLTLIPVAVKDTMQQIVPTELVQQIFSYLLIHPAITSEDRRSLNIWLRHLEDHIQ;SVPSYGEDELQQAMRLLNAASRQRTEAANEDFGGT'},\n", + " 'c2404': {'symmetry': 'C3', 'sequence': 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVCTVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'},\n", + " }\n", + "\n", + "\n", + "def process_batch_json(tasks, jobname):\n", + " if isinstance(tasks, dict):\n", + " new_tasks = []\n", + " for k, v in tasks.items():\n", + " v['id'] = k\n", + " new_tasks.append(v)\n", + " tasks = new_tasks\n", + " \n", + " # check the input.\n", + " for idx, task in enumerate(tasks):\n", + " if 'id' not in task.keys():\n", + " task['id'] = idx\n", + " \n", + " if 'sequence' not in task.keys():\n", + " raise KeyError(f\"number {idx+1}-th 'sequence' not found in dict keys: {task.keys()} in json.\")\n", + " \n", + " target_id = f\"{jobname}_{task['id']}\"\n", + " input_sequences = task['sequence'].strip().split(';')\n", + " \n", + " task['target_id'] = target_id\n", + " \n", + " if 'symmetry' not in task.keys():\n", + " task['symmetry'] = 'C1'\n", + " \n", + " symmetry_group = task['symmetry'] \n", + " # check the sequences\n", + " sequences, is_multimer, symmetry_group = validate_input(\n", + " input_sequences=input_sequences,\n", + " symmetry_group=symmetry_group,\n", + " min_length=MIN_SINGLE_SEQUENCE_LENGTH,\n", + " max_length=MAX_SINGLE_SEQUENCE_LENGTH,\n", + " max_multimer_length=MAX_MULTIMER_LENGTH)\n", + " task['is_multimer'] = is_multimer\n", + " \n", + " # save features to `output_dir_base`\n", + " feature_output_dir = get_features(\n", + " jobname=jobname,\n", + " target_id=target_id,\n", + " sequences=sequences,\n", + " output_dir_base=output_dir_base,\n", + " is_multimer=is_multimer,\n", + " msa_mode=msa_mode,\n", + " use_templates=use_templates\n", + " )\n", + " \n", + " task['feature_output_dir'] = feature_output_dir\n", + " task['symmetry'] = task['symmetry'] if task['symmetry'] != 'C1' else None\n", + "\n", + " return tasks\n", + "\n", + "\n", + "all_tasks = process_batch_json(input_json, jobname)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Uni-Fold prediction on GPU.\n", + "import time\n", + "from tqdm import tqdm\n", + "from unifold.colab.model import colab_inference\n", + "\n", + "def manual_operations():\n", + " # developers may operate on the pickle files here\n", + " # to customize the features for inference.\n", + " pass\n", + "\n", + "manual_operations()\n", + "\n", + "for task in tqdm(all_tasks, desc='running Unifold'):\n", + " start = time.time()\n", + " best_result = colab_inference(\n", + " target_id=task['target_id'],\n", + " data_dir=task['feature_output_dir'],\n", + " param_dir='.',\n", + " output_dir=task['feature_output_dir'],\n", + " symmetry_group=task['symmetry'],\n", + " is_multimer=task['is_multimer'],\n", + " max_recycling_iters=max_recycling_iters,\n", + " num_ensembles=num_ensembles,\n", + " times=times,\n", + " manual_seed=manual_seed,\n", + " device=\"cuda:0\", # do not change this on colab.\n", + " )\n", + " \n", + " task['best_plddt'] = best_result['plddt'].mean().item()\n", + " task['pae'] = best_result['pae'].mean().item() if best_result['pae'] is not None else None\n", + " task['best_results_path'] = best_result['best_results_path']\n", + " task['protein'] = best_result['protein']\n", + " task['run_time'] = time.time() - start\n", + " print(f\"total time: {time.time() - start}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "task_best_proteins = []\n", + "with open(os.path.join(output_dir_base, 'all_tasks_summary.json'), 'w') as f:\n", + " # remove the protein for clean resluts config.\n", + " for item in all_tasks:\n", + " if 'protein' in item:\n", + " protein = item.pop('protein')\n", + " task_best_proteins.append({'id':item['id'], 'protein': protein})\n", + " json.dump(all_tasks, f, indent=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the one protein structure, select one to display.\n", + "display_cases_number = -1\n", + "\n", + "from unifold.colab.plot import colab_plot\n", + "\n", + "colab_plot(\n", + " best_result=task_best_proteins[display_cases_number],\n", + " output_dir=task_best_proteins[display_cases_number],\n", + " show_sidechains=show_sidechains,\n", + " dpi=dpi,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/unifold/colab/data.py b/unifold/colab/data.py index c317b4a..f880fb7 100644 --- a/unifold/colab/data.py +++ b/unifold/colab/data.py @@ -1,7 +1,14 @@ import hashlib import os from typing import Dict, List, Sequence, Tuple, Union, Any, Optional - +import pickle +import gzip +from pathlib import Path +from unifold.msa import pipeline, parsers +from unifold.data.protein import PDB_CHAIN_IDS +from unifold.data.utils import compress_features +from unifold.msa.utils import divide_multi_chains +from unifold.colab.mmseqs import get_msa_and_templates from unifold.data import residue_constants, protein from unifold.msa.utils import divide_multi_chains @@ -123,3 +130,84 @@ def load_feature_for_one_target( ) batch = UnifoldDataset.collater([batch]) return batch + + +def get_features( + jobname: str, + target_id: str, + sequences: List[str], + output_dir_base: str, + is_multimer: bool, + msa_mode: str, + use_templates: str, + ): + + # Validate the input. + + descriptions = ['> '+target_id+' seq'+str(ii) for ii in range(len(sequences))] + + if is_multimer: + divide_multi_chains(target_id, output_dir_base, sequences, descriptions) + + s = [] + for des, seq in zip(descriptions, sequences): + s += [des, seq] + + unique_sequences = [] + [unique_sequences.append(x) for x in sequences if x not in unique_sequences] + + if len(unique_sequences)==1: + homooligomers_num = len(sequences) + else: + homooligomers_num = 1 + + with open(f"{output_dir_base}/{jobname}.fasta", "w") as f: + f.write("\n".join(s)) + + result_dir = Path(output_dir_base) + output_dir = os.path.join(output_dir_base, target_id) + + ( + unpaired_msa, + paired_msa, + template_results, + ) = get_msa_and_templates( + target_id, + unique_sequences, + result_dir=result_dir, + msa_mode=msa_mode, + use_templates=use_templates, + homooligomers_num = homooligomers_num + ) + + for idx, seq in enumerate(unique_sequences): + chain_id = PDB_CHAIN_IDS[idx] + sequence_features = pipeline.make_sequence_features( + sequence=seq, description=f'> {jobname} seq {chain_id}', num_res=len(seq) + ) + monomer_msa = parsers.parse_a3m(unpaired_msa[idx]) + msa_features = pipeline.make_msa_features([monomer_msa]) + template_features = template_results[idx] + feature_dict = {**sequence_features, **msa_features, **template_features} + feature_dict = compress_features(feature_dict) + features_output_path = os.path.join( + output_dir, "{}.feature.pkl.gz".format(chain_id) + ) + pickle.dump( + feature_dict, + gzip.GzipFile(features_output_path, "wb"), + protocol=4 + ) + if is_multimer: + multimer_msa = parsers.parse_a3m(paired_msa[idx]) + pair_features = pipeline.make_msa_features([multimer_msa]) + pair_feature_dict = compress_features(pair_features) + uniprot_output_path = os.path.join( + output_dir, "{}.uniprot.pkl.gz".format(chain_id) + ) + pickle.dump( + pair_feature_dict, + gzip.GzipFile(uniprot_output_path, "wb"), + protocol=4, + ) + return output_dir \ No newline at end of file diff --git a/unifold/colab/model.py b/unifold/colab/model.py index 7f95b95..6f3adde 100644 --- a/unifold/colab/model.py +++ b/unifold/colab/model.py @@ -32,6 +32,7 @@ def colab_inference( times: int, manual_seed: int, device: str = "cuda:0", + bf16=False ): if symmetry_group is not None: @@ -60,6 +61,8 @@ def colab_inference( state_dict = torch.load(param_path)["ema"]["params"] state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()} model.load_state_dict(state_dict) + if bf16: + model = model.bfloat16() model = model.to(device) model.eval() model.inference_mode() @@ -87,7 +90,7 @@ def colab_inference( chunk_size, block_size = automatic_chunk_size( seq_len, device, - is_bf16=False, + is_bf16=bf16, ) model.globals.chunk_size = chunk_size model.globals.block_size = block_size @@ -139,7 +142,9 @@ def to_float(x): plddts[cur_save_name] = str(mean_plddt) if is_multimer and symmetry_group is None: ptms[cur_save_name] = str(np.mean(out["iptm+ptm"])) - with open(os.path.join(output_dir, cur_save_name + '.pdb'), "w") as f: + + best_results_path = os.path.join(output_dir, cur_save_name + '.pdb') + with open(best_results_path, "w") as f: f.write(protein.to_pdb(cur_protein)) if is_multimer and symmetry_group is None: @@ -148,14 +153,16 @@ def to_float(x): best_result = { "protein": cur_protein, "plddt": out["plddt"], - "pae": out["predicted_aligned_error"] + "pae": out["predicted_aligned_error"], + 'best_results_path': best_results_path, } else: if mean_plddt>best_score: best_result = { "protein": cur_protein, "plddt": out["plddt"], - "pae": None + "pae": None, + 'best_results_path': best_results_path, } print("plddts", plddts) diff --git a/unifold/inference.py b/unifold/inference.py index 7cd2a8e..fb3a2a1 100644 --- a/unifold/inference.py +++ b/unifold/inference.py @@ -42,7 +42,7 @@ def automatic_chunk_size(seq_len, device, is_bf16): chunk_size = 32 block_size = 512 else: - chunk_size = 4 + chunk_size = 1 block_size = 256 return chunk_size, block_size