@@ -1,10 +1,10 @@
import torch.nn as nn
from utils import *
from embeddingLayer import EmbeddingLayer
+from constEmbeddingsGlove import ConstEmbeddingsGlove

-class Layers(nn.Module):
+class Layers(object):
    def __init__(self, initialLayer, intermediateLayers, finalLayer):
-        super().__init__()
        if finalLayer:
            self.outDim = finalLayer.outDim
        elif intermediateLayers:
@@ -37,10 +37,10 @@ def __str__(self):
        s += "final = " + finalLayer
        return s

-    def forward(self, sentence, constEmnbeddings, doDropout):
+    def forward(self, sentence, constEmbeddings, doDropout):
        if self.initialLayer.isEmpty:
            raise RuntimeError(f"ERROR: you can't call forward() on a Layers object that does not have an initial layer: {self}!")
-        states = self.initialLayer(sentence, constEmnbeddings, doDropout)
+        states = self.initialLayer(sentence, constEmbeddings, doDropout)
        for intermediateLayer in self.intermediateLayers:
            states = intermediateLayer(states, doDropout)
        if self.finalLayer.nonEmpty:
@@ -133,18 +133,103 @@ def loadX2i(cls, x2i):

        return cls(initialLayer, intermediateLayers, finalLayer)

-    def predictJointly(layers, sentence, constEmnbeddings):
-        TODO
-    def forwardForTask(layers, taskId, sentence, constEmnbeddings, doDropout):
-        TODO
-    def predict(layers, taskId, sentence, constEmnbeddings):
-        TODO
-    def predictWithScores(layers, taskId, sentence, constEmnbeddings):
-        TODO
-    def parse(layers, sentence, constEmnbeddings):
-        TODO
+    @staticmethod
+    def predictJointly(layers, sentence, constEmbeddings):
+        labelsPerTask = list()
+        # layers[0] contains the shared layers
+        if layers[0]:
+            sharedStates = layers[0].forward(sentence, constEmbeddings, doDropout=False)
+            for i in range(1, len(layers)):
+                states = layers[i].forwardFrom(sharedStates, sentence.headPositions, doDropout=False)
+                emissionScores = emissionScoresToArrays(states)
+                labels = layers[i].finalLayer.inference(emissionScores)
+                labelsPerTask += [labels]
+        # no shared layer
+        else:
+            for i in range(1, len(layers)):
+                states = layers[i].forward(sentence, constEmbeddings, doDropout=False)
+                emissionScores = emissionScoresToArrays(states)
+                labels = layers[i].finalLayer.inference(emissionScores)
+                labelsPerTask += [labels]
+
+        return labelsPerTask
+
+    @staticmethod
+    def forwardForTask(layers, taskId, sentence, constEmbeddings, doDropout):
+        if layers[0]:
+            sharedStates = layers[0].forward(sentence, constEmbeddings, doDropout)
+            states = layers[taskId+1].forwardFrom(sharedStates, sentence.headPositions, doDropout)
+        else:
+            states = layers[taskId+1].forward(sentence, constEmbeddings, doDropout)
+        return states
+
+    @staticmethod
+    def predict(layers, taskId, sentence, constEmbeddings):
+        states = Layers.forwardForTask(layers, taskId, sentence, constEmbeddings, doDropout=False)
+        emissionScores = emissionScoresToArrays(states)
+        return layers[taskId+1].finalLayer.inference(emissionScores)
+
+    @staticmethod
+    def predictWithScores(layers, taskId, sentence, constEmbeddings):
+        states = Layers.forwardForTask(layers, taskId, sentence, constEmbeddings, doDropout=False)
+        emissionScores = emissionScoresToArrays(states)
+        return layers[taskId+1].finalLayer.inferenceWithScores(emissionScores)
+
+    @staticmethod
+    def parse(layers, sentence, constEmbeddings):
+        #
+        # first get the output of the layers that are shared between the two tasks
+        #
+        assert(layers[0].nonEmpty)
+        sharedStates = layers[0].forward(sentence, constEmbeddings, doDropout=False)
+
+        #
+        # now predict the heads (first task)
+        #
+        headStates = layers[1].forwardFrom(sharedStates, None, doDropout=False)
+        headEmissionScores = emissionScoresToArrays(headStates)
+        headScores = layers[1].finalLayer.inferenceWithScores(headEmissionScores)
+
+        # store the head values here
+        heads = list()
+        for wi, predictionsForThisWord in enumerate(headScores):
+            # pick the prediction with the highest score that is within the boundaries of the current sentence
+            done = False
+            for hi, relative in enumerate(predictionsForThisWord):
+                if done:
+                    break
+                try:
+                    relativeHead = int(relative[0])
+                    if relativeHead == 0:
+                        heads.append(1)
+                        done = True
+                    else:
+                        headPosition = wi + relativeHead
+                        heads.append(headPosition)
+                        done = True
+                except ValueError:
+                    # some valid predictions may not be integers, e.g., "<STOP>" may be predicted by the sequence model; skip them
+                    continue
+            if not done:
+                # we should not be here, but let's be safe:
+                # if nothing good was found, assume root
+                heads.append(1)
+
+        #
+        # next, predict the labels using the predicted heads
+        #
+        labelStates = layers[2].forwardFrom(sharedStates, heads, doDropout=False)
+        emissionScores = emissionScoresToArrays(labelStates)
+        labels = layers[2].finalLayer.inference(emissionScores)
+        assert(len(labels) == len(heads))
+
+        return zip(heads, labels)
+
+    @staticmethod
    def loss(layers, taskId, sentence, goldLabels):
-        TODO
+        # Zheng: I am not sure this is the right way to load the embeddings; need help...
+        constEmbeddings = ConstEmbeddingsGlove().mkConstLookupParams(sentence.words)
+        states = Layers.forwardForTask(layers, taskId, sentence, constEmbeddings, doDropout=True)  # use dropout during training!
+        return layers[taskId+1].finalLayer.loss(states, goldLabels)



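For reference, a minimal, self-contained sketch of the relative-head decoding that parse() performs above. The headScores candidates here are made-up illustrations ((prediction, score) pairs sorted best-first), not model output; only the decoding logic mirrors the method in this commit.

# Hypothetical candidates: each word gets (prediction, score) pairs sorted by score.
# A prediction is a relative offset to the head word ("0" marks the root) or a
# non-integer tag such as "<STOP>" that must be skipped.
headScores = [
    [("2", 0.9), ("<STOP>", 0.1)],   # word 0: head is 2 positions to the right
    [("0", 0.8), ("-1", 0.2)],       # word 1: root
    [("-1", 0.7), ("1", 0.3)],       # word 2: head is 1 position to the left
]

heads = []
for wi, predictionsForThisWord in enumerate(headScores):
    done = False
    for prediction, score in predictionsForThisWord:
        try:
            relativeHead = int(prediction)
        except ValueError:
            continue  # skip non-integer predictions such as "<STOP>"
        # relativeHead == 0 is mapped the same way as in the parse() branch above
        heads.append(1 if relativeHead == 0 else wi + relativeHead)
        done = True
        break
    if not done:
        heads.append(1)  # nothing usable was found; fall back as parse() does

print(heads)  # [2, 1, 1]

The first usable candidate per word wins, and the fallback keeps len(heads) equal to the number of words, which the label-prediction step relies on.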