
Commit 586d771

Merge pull request #15 from Bitbol-Lab/proteingym_branch

Fix a number of issues with versioning, add Docker, etc.

2 parents: d6773c5 + ffec189

24 files changed: +352, -3723 lines

Dockerfile

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Use an official PyTorch runtime with CUDA support as the parent image (https://hub.docker.com/r/nvidia/cuda/)
+FROM pytorch/pytorch:2.3.1-cuda11.8-cudnn8-devel as base
+
+ENV HOST docker
+ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
+ENV TZ Europe/Paris
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# TO CHANGE IF NEEDED
+RUN useradd -m -U -s /bin/bash user
+RUN apt-get update && apt-get install -y git
+
+# Set the working directory: TO CHANGE IF NEEDED
+USER user
+
+# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
+ENV PIP_NO_CACHE_DIR=1
+# Install from requirements.txt
+COPY requirements.txt /home/user/requirements.txt
+
+# Install the Python dependencies
+RUN pip install --no-cache-dir -r /home/user/requirements.txt
+
+WORKDIR /home/user

ProtMamba_ssm/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 __version__ = "0.0.1"
-from .core import *
 from .dataloaders import *
 from .fim import *
 from .modules import *

ProtMamba_ssm/_modidx.py

Lines changed: 0 additions & 173 deletions
This file was deleted.

ProtMamba_ssm/dataloaders.py

Lines changed: 8 additions & 17 deletions
@@ -1,9 +1,3 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_dataloaders.ipynb.
-
-# %% auto 0
-__all__ = ['Uniclust30_Dataset', 'make_dataloader', 'DataCollatorForUniclust30Dataset']
-
-# %% ../nbs/03_dataloaders.ipynb 3
 from .utils import AA_TO_ID
 from .fim import NoFIM, SingleSpanFIM, MultipleSpanFIM
 import pickle
@@ -14,8 +8,6 @@
 from dataclasses import dataclass
 from typing import Dict, Sequence
 
-
-# %% ../nbs/03_dataloaders.ipynb 4
 # Make dataset
 class Uniclust30_Dataset(Dataset):
     """
@@ -78,14 +70,14 @@ def __init__(self, filename="encoded_MSAs_train.pkl",
     def __len__(self):
         return len(self.cluster_names)
 
-    def __getitem__(self, idx):
+    def __getitem__(self, idx, shuffle=True):
         # get all the sequences in the cluster
         sequences = self.get_sequences(idx)
         # get total number of sequences in the cluster and choose how many to sample
         orig_num_sequences = len(self.get_index_start_of_sequences(sequences))
         num_sequences = np.random.randint(1, orig_num_sequences + 1) if self.sample else orig_num_sequences
         # sample the sequences
-        sequences, position_ids = self.sample_sequences(sequences, num_sequences)
+        sequences, position_ids = self.sample_sequences(sequences, num_sequences, shuffle=shuffle)
         # with probability 0.5, reverse the sequences and move the last token to the front
         sequences, position_ids = self.reverse_sequences(sequences, position_ids) if (
                 self.reverse and np.random.rand() > 0.5) else sequences, position_ids
@@ -128,7 +120,7 @@ def reverse_sequences(self, sequence, position_ids=None):
         return np.concatenate([sequence[-1:], sequence[:-1]]), np.concatenate(
             [position_ids[-1:], position_ids[:-1]]) if position_ids is not None else None
 
-    def sample_sequences(self, sequences, num_sequences, shuffle=True):
+    def sample_sequences(self, sequences, num_sequences, shuffle=True, which_seqs=None):
         """Sample `num_sequences` from the sequences in the cluster."""
         L = len(sequences)
         # get the indexes of the start of each sequence
@@ -137,10 +129,11 @@ def sample_sequences(self, sequences, num_sequences, shuffle=True):
         assert len(inds) > 0, "No sequences found in cluster."
         assert len(inds) >= num_sequences, "Not enough sequences in cluster."
         # sample n_sequences randomly from the sequences
-        if shuffle:
-            which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
-        else:
-            which_seqs = np.arange(len(inds))[-num_sequences:]
+        if which_seqs is None:
+            if shuffle:
+                which_seqs = np.random.choice(np.arange(len(inds)), num_sequences, replace=False)
+            else:
+                which_seqs = np.arange(len(inds))[-num_sequences:]
         # get the tuples of start and end indexes of the sequences
         tuples = [(inds[i], inds[i + 1]) if i < len(inds) - 1 else (inds[i], L) for i in which_seqs]
         if self.troubleshoot:
@@ -154,9 +147,7 @@ def make_dataloader(dataset):
     """Basic function to make a dataloader.
     """
     dataloader = DataLoader(dataset)
-    return dataloader
 
-# %% ../nbs/03_dataloaders.ipynb 5
 @dataclass
 class DataCollatorForUniclust30Dataset(object):
     """

ProtMamba_ssm/fim.py

Lines changed: 0 additions & 7 deletions
@@ -1,13 +1,6 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_fim.ipynb.
-
-# %% auto 0
-__all__ = ['AbstractFIM', 'NoFIM', 'SingleSpanFIM', 'MultipleSpanFIM']
-
-# %% ../nbs/04_fim.ipynb 3
 from .utils import MASK_TO_ID, AA_TO_ID
 import numpy as np
 
-# %% ../nbs/04_fim.ipynb 4
 class AbstractFIM(object):
     def __init__(self,
                  max_patches=5,

ProtMamba_ssm/modules.py

Lines changed: 4 additions & 15 deletions
@@ -1,19 +1,7 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_modules.ipynb.
-
-# %% auto 0
-__all__ = ['MambaConfig', 'sample_safe', 'decode_safe', 'GenerationMixinSafe', 'CheckpointedModule', 'create_block',
-           'MixerModelSafe', 'MambaLMHeadModelSafe', 'MixerModelWithPosids', 'MixerModelWith2DPosids',
-           'MambaLMHeadModelwithPosids', 'MambaLMHeadModelwith2DPosids', 'load_model']
-
-# %% ../nbs/01_modules.ipynb 3
-import torch
-import torch.nn as nn
+import torch.nn as nn
 
 import json
 import os
-from collections import namedtuple
-from dataclasses import field, dataclass
-from functools import partial
 
 from mamba_ssm.models.config_mamba import MambaConfig
 from mamba_ssm.modules.block import Block
@@ -722,10 +710,11 @@ def protected_forward(self, input_ids, position_ids=None, inference_params=None,
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         lm_logits = self.lm_head(hidden_states)
-        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "hidden_states"])
         if len(save_layer) > 0:
+            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "hidden_states"])
             return CausalLMOutput(loss=None, logits=lm_logits, hidden_states=hidden_states)
-        return CausalLMOutput(loss=None, logits=lm_logits, hidden_states=None)
+        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
+        return CausalLMOutput(loss=None, logits=lm_logits)
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, checkpoint_mixer=False, **kwargs):
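
After this change the forward pass only carries hidden_states in its output when save_layer is non-empty; the default return is a two-field namedtuple, so callers that previously read .hidden_states (formerly None) will now hit an AttributeError. (Note that the first hunk drops the top-level `from collections import namedtuple`, so namedtuple must still reach this method through an import not shown in this diff.) A self-contained sketch of the new contract, with stand-in values in place of real tensors:

    from collections import namedtuple

    def forward_output(lm_logits, hidden_states, save_layer=()):
        # Mirrors the updated protected_forward return logic.
        if len(save_layer) > 0:
            CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits", "hidden_states"])
            return CausalLMOutput(loss=None, logits=lm_logits, hidden_states=hidden_states)
        CausalLMOutput = namedtuple("CausalLMOutput", ["loss", "logits"])
        return CausalLMOutput(loss=None, logits=lm_logits)

    out = forward_output("<logits>", "<hidden>")
    assert not hasattr(out, "hidden_states")  # default output no longer exposes the field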

ProtMamba_ssm/trainer.py

Lines changed: 1 addition & 7 deletions
@@ -1,10 +1,4 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_trainer.ipynb.
-
-# %% auto 0
-__all__ = ['PREFIX_CHECKPOINT_DIR', 'MambaTrainer', 'get_last_checkpoint', 'EarlyStoppingCallback']
-
-# %% ../nbs/02_trainer.ipynb 3
-from transformers import Trainer, TrainerCallback, TrainerState, TrainerControl
+from transformers import Trainer, TrainerCallback
 from .utils import *
 import re
 import torch

ProtMamba_ssm/utils.py

Lines changed: 0 additions & 6 deletions
@@ -1,13 +1,9 @@
-# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/99_utils.ipynb.
-
-# %% auto 0
 __all__ = ['AA_TO_ID', 'MASK_TO_ID', 'ID_TO_AA', 'encode_sequence', 'decode_sequence', 'clean_sequence', 'tokenizer',
            'reorder_masked_sequence', 'load_from_file', 'generate_sequence', 'prepare_dataset_for_fim_generation',
            'prepare_tokens', 'prepare_target', 'load_tensorboard_data', 'filter_datapoints', 'save_to_tensorboard',
            'merge_loggings', 'concatenate_loggings', 'print_number_of_parameters', 'find_fim_indices',
            'compute_metrics']
 
-# %% ../nbs/99_utils.ipynb 3
 # Constants
 AA_TO_ID = {'<cls>': 0,
             '<pad>': 1,
@@ -53,7 +49,6 @@
 
 ID_TO_AA = {v: k for k, v in AA_TO_ID.items()}
 
-# %% ../nbs/99_utils.ipynb 4
 import numpy as np
 import torch
 from Bio import SeqIO
@@ -271,7 +266,6 @@ def prepare_target(target, use_fim=None):
     assert new_target.shape[1] == new_pos_ids.shape[1]
     return new_target, new_pos_ids, is_fim_dict
 
-# %% ../nbs/99_utils.ipynb 5
 from tensorboard.backend.event_processing import event_accumulator
 from tensorboard.backend.event_processing.event_accumulator import ScalarEvent
 from torch.utils.tensorboard import SummaryWriter

configs/default_config.yaml

Lines changed: 8 additions & 8 deletions
@@ -1,12 +1,12 @@
-data_dir: ""
-output_dir: "results/"
+data_dir: "/home/user/data/" #"/home/malbrank/nas/protmamba2/data/" and "/nvme1/common/OpenProteinSet/"
+output_dir: "/home/user/results/"
 namedir: "test0"
-train_dataset_path: "encoded_MSAs_train.pkl"
+train_dataset_path: "encoded_MSAs_subset-100.pkl"
 eval_dataset_path: '...'
-batch_size: 4 # mamba trained with total 0.5M tokens per batch
+batch_size: 2 # mamba trained with total 0.5M tokens per batch
 d_model: 1024
-gradient_accumulation_steps: 8
-checkpoint_mixer: True
+gradient_accumulation_steps: 1
+checkpoint_mixer: False
 learning_rate: 0.0006 # mamba default (x5), decrease to 0.0006 for sizes > 100M params and 0.0002 for sizes > 1B params
 weight_decay: 0.1 # mamba default
 beta1: 0.9 # mamba default
@@ -25,9 +25,9 @@ seed_sequence_sampling: 42
 seed_datasets: 0
 save_steps: 250
 eval_steps: 50
-size_validation: 192
+size_validation: 4
 logging_steps: 10
-eval_accumulation_steps: 200
+eval_accumulation_steps: 10
 save_total_limit: 50
 dtype: "bfloat16"
 fim_strategy: "multiple_span" #["no-scramble", "one_span", "multiple_span"]
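
With batch_size halved to 2 and gradient_accumulation_steps cut from 8 to 1, the effective batch per optimizer step drops from 32 to 2 sequences, and the validation set shrinks to 4 samples; these new defaults read as a lightweight test configuration rather than the 0.5M-token batches the inline comment describes. A quick sketch for checking the effective batch when editing this file (PyYAML assumed):

    import yaml

    # Load the training config and report the effective batch size.
    with open("configs/default_config.yaml") as f:
        cfg = yaml.safe_load(f)

    effective = cfg["batch_size"] * cfg["gradient_accumulation_steps"]
    print(f"effective batch per optimizer step: {effective} sequences")  # 2 with these values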
