1- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_dataloaders.ipynb.
2-
3- # %% auto 0
4- __all__ = ['Uniclust30_Dataset' , 'make_dataloader' , 'DataCollatorForUniclust30Dataset' ]
5-
6- # %% ../nbs/03_dataloaders.ipynb 3
71from .utils import AA_TO_ID
82from .fim import NoFIM , SingleSpanFIM , MultipleSpanFIM
93import pickle
148from dataclasses import dataclass
159from typing import Dict , Sequence
1610
17-
18- # %% ../nbs/03_dataloaders.ipynb 4
1911# Make dataset
2012class Uniclust30_Dataset (Dataset ):
2113 """
@@ -78,14 +70,14 @@ def __init__(self, filename="encoded_MSAs_train.pkl",
7870 def __len__ (self ):
7971 return len (self .cluster_names )
8072
81- def __getitem__ (self , idx ):
73+ def __getitem__ (self , idx , shuffle = True ):
8274 # get all the sequences in the cluster
8375 sequences = self .get_sequences (idx )
8476 # get total number of sequences in the cluster and choose how many to sample
8577 orig_num_sequences = len (self .get_index_start_of_sequences (sequences ))
8678 num_sequences = np .random .randint (1 , orig_num_sequences + 1 ) if self .sample else orig_num_sequences
8779 # sample the sequences
88- sequences , position_ids = self .sample_sequences (sequences , num_sequences )
80+ sequences , position_ids = self .sample_sequences (sequences , num_sequences , shuffle = shuffle )
8981 # with probability 0.5, reverse the sequences and move the last token to the front
9082 sequences , position_ids = self .reverse_sequences (sequences , position_ids ) if (
9183 self .reverse and np .random .rand () > 0.5 ) else sequences , position_ids
@@ -128,7 +120,7 @@ def reverse_sequences(self, sequence, position_ids=None):
128120 return np .concatenate ([sequence [- 1 :], sequence [:- 1 ]]), np .concatenate (
129121 [position_ids [- 1 :], position_ids [:- 1 ]]) if position_ids is not None else None
130122
131- def sample_sequences (self , sequences , num_sequences , shuffle = True ):
123+ def sample_sequences (self , sequences , num_sequences , shuffle = True , which_seqs = None ):
132124 """Sample `num_sequences` from the sequences in the cluster."""
133125 L = len (sequences )
134126 # get the indexes of the start of each sequence
@@ -137,10 +129,11 @@ def sample_sequences(self, sequences, num_sequences, shuffle=True):
137129 assert len (inds ) > 0 , "No sequences found in cluster."
138130 assert len (inds ) >= num_sequences , "Not enough sequences in cluster."
139131 # sample n_sequences randomly from the sequences
140- if shuffle :
141- which_seqs = np .random .choice (np .arange (len (inds )), num_sequences , replace = False )
142- else :
143- which_seqs = np .arange (len (inds ))[- num_sequences :]
132+ if which_seqs is None :
133+ if shuffle :
134+ which_seqs = np .random .choice (np .arange (len (inds )), num_sequences , replace = False )
135+ else :
136+ which_seqs = np .arange (len (inds ))[- num_sequences :]
144137 # get the tuples of start and end indexes of the sequences
145138 tuples = [(inds [i ], inds [i + 1 ]) if i < len (inds ) - 1 else (inds [i ], L ) for i in which_seqs ]
146139 if self .troubleshoot :
@@ -154,9 +147,7 @@ def make_dataloader(dataset):
154147 """Basic function to make a dataloader.
155148 """
156149 dataloader = DataLoader (dataset )
157- return dataloader
158150
159- # %% ../nbs/03_dataloaders.ipynb 5
160151@dataclass
161152class DataCollatorForUniclust30Dataset (object ):
162153 """
0 commit comments