|
| 1 | +import random |
| 2 | +import math |
| 3 | +from sequences.columnReader import ColumnReader |
| 4 | + |
| 5 | +TYPE_BASIC = 0 |
| 6 | +TYPE_DUAL = 1 |
| 7 | + |
| 8 | +class TaskManager: |
| 9 | + |
| 10 | + def __init__(self, config, seed): |
| 11 | + |
| 12 | + self.config = config |
| 13 | + self.random = seed |
| 14 | + |
| 15 | + # How many shards to have per epoch |
| 16 | + self.shardsPerEpoch = config.get_int("mtl.shardsPerEpoch", 10) |
| 17 | + |
| 18 | + # Total number of epochs |
| 19 | + self.maxEpochs:Int = config.get_int("mtl.maxEpochs", 100) |
| 20 | + |
| 21 | + # Training patience in number of epochs |
| 22 | + self.epochPatience:Int = config.get_int("mtl.epochPatience", 5) |
| 23 | + |
| 24 | + # Array of all tasks to be managed |
| 25 | + self.tasks = self.readTasks() |
| 26 | + |
| 27 | + self.taskCount = len(self.tasks) |
| 28 | + self.indices = range(self.taskCount) |
| 29 | + |
| 30 | + # Training shards from all tasks |
| 31 | + self.shards = self.mkShards() |
| 32 | + |
| 33 | + # Construct training shards by interleaving shards from all tasks |
| 34 | + def mkShards(self): |
| 35 | + shardsByTasks = list() |
| 36 | + |
| 37 | + # construct the shards for each task |
| 38 | + for i in self.indices: |
| 39 | + shardsByTasks += [self.tasks[i].mkShards()] |
| 40 | + assert(len(shardsByTasks[i]) == self.shardsPerEpoch) |
| 41 | + |
| 42 | + # now interleave the tasks |
| 43 | + interleavedShards = list() |
| 44 | + for i in range(self.shardsPerEpoch): |
| 45 | + for j in self.indices: |
| 46 | + crtShard = shardsByTasks[j][i] |
| 47 | + interleavedShards += [crtShard] |
| 48 | + |
| 49 | + |
| 50 | + # print ("All shards:") |
| 51 | + # for(i <- interleavedShards.indices) |
| 52 | + # print (s"${interleavedShards(i)}") |
| 53 | + |
| 54 | + |
| 55 | + return interleavedShards |
| 56 | + |
| 57 | + # Iterator over all sentences coming from all interleaved shards |
| 58 | + def getSentences(self): |
| 59 | + return SentenceIterator(self.tasks, self.shards, self.random) |
| 60 | + |
| 61 | + # Reads all tasks from disk in memory |
| 62 | + def readTasks(self): |
| 63 | + numberOfTasks = self.config.get_int("mtl.numberOfTasks", None) |
| 64 | + tasks = list() |
| 65 | + for i in range(numberOfTasks): |
| 66 | + tasks += [self.readTask(i + 1)] |
| 67 | + |
| 68 | + print (f"Read {numberOfTasks} tasks from config file.") |
| 69 | + return tasks |
| 70 | + |
| 71 | + def readTask(self, taskNumber): |
| 72 | + taskName = self.config.get_string(f"mtl.task{taskNumber}.name", None) |
| 73 | + train = self.config.get_string(f"mtl.task{taskNumber}.train", None) |
| 74 | + |
| 75 | + dev = self.config.get_string(f"mtl.task{taskNumber}.dev", None) if f"mtl.task{taskNumber}.dev" in self.config else None |
| 76 | + test = self.config.get_string(f"mtl.task{taskNumber}.test", None) if f"mtl.task{taskNumber}.test" in self.config else None |
| 77 | + |
| 78 | + taskType = self.parseType(self.config.get_string(f"mtl.task{taskNumber}.type", "basic")) |
| 79 | + |
| 80 | + weight = self.config.get_float(f"mtl.task{taskNumber}.weight", 1.0) |
| 81 | + |
| 82 | + return Task(taskNumber - 1, taskName, taskType, self.shardsPerEpoch, weight, train, dev, test) |
| 83 | + |
| 84 | + def parseType(self, inf): |
| 85 | + if inf == "basic": return TYPE_BASIC |
| 86 | + elif inf == "dual": return TYPE_DUAL |
| 87 | + else: raise ValueError(f"ERROR: unknown task type {inf}!") |
| 88 | + |
| 89 | + def debugTraversal(self): |
| 90 | + for epoch in range(self.maxEpochs): |
| 91 | + print (f"Started epoch {epoch}") |
| 92 | + sentCount = 0 |
| 93 | + taskId = 0 |
| 94 | + totalSents = 0 |
| 95 | + for sentence in getSentences(): |
| 96 | + totalSents += 1 |
| 97 | + if(sentence[0] != taskId): |
| 98 | + print (f"Read {sentCount} sentences from task {taskId}") |
| 99 | + taskId = sentence[0] |
| 100 | + sentCount = 1 |
| 101 | + else: |
| 102 | + sentCount += 1 |
| 103 | + print (f"Read {sentCount} sentences from task {taskId}") |
| 104 | + print (f"Read {totalSents} sentences in epoch {epoch}.") |
| 105 | + |
| 106 | +class SentenceIterator(object): |
| 107 | + def __init__(tasks, shards, random): |
| 108 | + |
| 109 | + self.tasks = tasks |
| 110 | + self.shards = shards |
| 111 | + self.random = random #random seed |
| 112 | + |
| 113 | + # Offset in randomizedSentencePositions array |
| 114 | + self.sentenceOffset = 0 |
| 115 | + self.randomizedSentencePositions = randomizeSentences() |
| 116 | + |
| 117 | + class Sentence: |
| 118 | + def __init__(self, taskId, sentencePosition): |
| 119 | + self.taskId = taskId |
| 120 | + self.sentencePosition = sentencePosition |
| 121 | + |
| 122 | + # Randomizes all sentences across all tasks |
| 123 | + def randomizeSentences(): |
| 124 | + # first, randomize the shards |
| 125 | + random.seed(self.random) |
| 126 | + randomizedShards = random.shuffle(self.shards) |
| 127 | + randomizedSents = list() |
| 128 | + for shard in randomizedShards: |
| 129 | + # second, randomize the sentences inside each shard |
| 130 | + sents = random.shuffle(list(range(shard.startPosition, shard.endPosition))) |
| 131 | + for sent in sents: |
| 132 | + # store the randomized sentences |
| 133 | + randomizedSents += [Sentence(shard.taskId, sent)] |
| 134 | + return randomizedSents |
| 135 | + |
| 136 | + def __len__(self): |
| 137 | + return len(self.randomizedSentencePositions) |
| 138 | + |
| 139 | + def __iter__(self): |
| 140 | + return self |
| 141 | + |
| 142 | + def hasNext(self): return self.sentenceOffset < len(self.randomizedSentencePositions) |
| 143 | + |
| 144 | + def __next__(self): |
| 145 | + assert(self.sentenceOffset >= 0 and self.sentenceOffset < len(self.randomizedSentencePositions)) |
| 146 | + |
| 147 | + s = self.randomizedSentencePositions[sentenceOffset] |
| 148 | + tid = s.taskId |
| 149 | + sentence = self.tasks[tid].trainSentences[s.sentencePosition] |
| 150 | + self.sentenceOffset += 1 |
| 151 | + |
| 152 | + #print ("shardPosition = $shardPosition, sentencePosition = $sentencePosition") |
| 153 | + |
| 154 | + return (tid, sentence) |
| 155 | + |
| 156 | +class Shard: |
| 157 | + def __init__(self, taskId, startPosition, endPosition): |
| 158 | + self.taskId = taskId |
| 159 | + self.startPosition = startPosition |
| 160 | + self.endPosition = endPosition |
| 161 | + |
| 162 | +class Task: |
| 163 | + def __init__(self, |
| 164 | + taskId, # this starts at 0 so we can use it as an index in the array of tasks |
| 165 | + taskName:str, |
| 166 | + taskType:int, |
| 167 | + shardsPerEpoch:int, |
| 168 | + taskWeight:float, |
| 169 | + trainFileName:str, |
| 170 | + devFileName:str = None, |
| 171 | + testFileName:str = None): |
| 172 | + self.taskId = taskId |
| 173 | + taskNumber = taskId + 1 |
| 174 | + print (f"Reading task {taskNumber} ({taskName})...") |
| 175 | + self.trainSentences = ColumnReader.readColumns(trainFileName) |
| 176 | + self.devSentences = ColumnReader.readColumns(devFileName) if devFileName else None |
| 177 | + self.testSentences = ColumnReader.readColumns(testFileName) if testFileName else None |
| 178 | + |
| 179 | + self.isBasic:Boolean = taskType == TYPE_BASIC |
| 180 | + self.isDual:Boolean = taskType == TYPE_DUAL |
| 181 | + |
| 182 | + if taskType == TYPE_BASIC: |
| 183 | + self.prettyType = "basic" |
| 184 | + elif taskType == TYPE_DUAL: |
| 185 | + self.prettyType = "dual" |
| 186 | + else: |
| 187 | + self.prettyType = "unknown" |
| 188 | + |
| 189 | + # The size of the training shard for this task |
| 190 | + self.shardSize = math.ceil(len(self.trainSentences) / shardsPerEpoch) |
| 191 | + |
| 192 | + # Current position in the training sentences when we iterate during training |
| 193 | + currentTrainingSentencePosition = 0 |
| 194 | + |
| 195 | + print (f"============ starting task {taskNumber} ============") |
| 196 | + print (f"Read {len(self.trainSentences)} training sentences for task {taskNumber}, with shard size {self.shardSize}.") |
| 197 | + if(self.devSentences is not None): |
| 198 | + print (f"Read {len(self.devSentences)} development sentences for task {taskNumber}.") |
| 199 | + if(self.testSentences is not None): |
| 200 | + print (f"Read {len(self.testSentences)} testing sentences for task {taskNumber}.") |
| 201 | + print (f"Using taskWeight = {taskWeight}") |
| 202 | + print (f"Task type = {self.prettyType}.") |
| 203 | + print (f"============ completed task {taskNumber} ============") |
| 204 | + |
| 205 | + # Construct the shards from all training sentences in this task |
| 206 | + def mkShards(self): |
| 207 | + shards = list() |
| 208 | + crtPos = 0 |
| 209 | + while(crtPos < len(self.trainSentences)): |
| 210 | + endPos = min(crtPos + self.shardSize, len(self.trainSentences)) |
| 211 | + shards += [Shard(self.taskId, crtPos, endPos)] |
| 212 | + crtPos = endPos |
| 213 | + return shards |
0 commit comments