
Commit 03229a4

forward layer implementation
1 parent ddcf223 commit 03229a4

File tree

6 files changed: +216 -1 lines changed


main/src/main/python/pytorch/embeddingLayer.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def mkEmbeddings(self, words, constEmbeddings, tags=None, nes=None, headPosition
         #
         # biLSTM over character embeddings
         #
-        charEmbedding = torch.cat([mkCharacterEmbedding(word, c2i, self.charLookupParameters, self.charRnnBuilder) for word in words])
+        charEmbedding = torch.stack([mkCharacterEmbedding(word, c2i, self.charLookupParameters, self.charRnnBuilder) for word in words])

         #
         # POS tag embedding
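The switch from torch.cat to torch.stack changes the output shape: assuming mkCharacterEmbedding returns one 1-D vector per word, cat would splice all words into a single flat vector, while stack adds a new leading dimension and yields one row per word. A minimal sketch of the difference (the vector size 4 is illustrative only):

import torch

wordVectors = [torch.zeros(4), torch.ones(4), torch.full((4,), 2.0)]
print(torch.cat(wordVectors).shape)    # torch.Size([12])   -> one flat vector
print(torch.stack(wordVectors).shape)  # torch.Size([3, 4]) -> one row per word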
main/src/main/python/pytorch/finalLayer.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

class FinalLayer(nn.Module):

    def __init__(self):
        super().__init__()
        self.inDim = None
        self.outDim = None

    def forward(self, inputExpressions, headPositionsOpt, doDropout):
        raise NotImplementedError

    def loss(self, emissionScoresAsExpression, goldLabels):
        raise NotImplementedError

    def inference(self, emissionScores):
        raise NotImplementedError

    def inferenceWithScores(self, emissionScores):
        raise NotImplementedError
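FinalLayer only fixes the interface (forward, loss, inference, inferenceWithScores); concrete output layers are expected to subclass it. A hypothetical toy subclass, not the repo's GreedyForwardLayer, sketching how the four methods are meant to be filled in:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyGreedyLayer(FinalLayer):
    def __init__(self, inDim, outDim):
        super().__init__()
        self.inDim = inDim
        self.outDim = outDim
        self.proj = nn.Linear(inDim, outDim)

    def forward(self, inputExpressions, headPositionsOpt, doDropout):
        # one row of label scores per token
        return self.proj(inputExpressions)

    def loss(self, emissionScoresAsExpression, goldLabels):
        return F.cross_entropy(emissionScoresAsExpression, goldLabels)

    def inference(self, emissionScores):
        # greedy decoding: highest-scoring label per token
        return emissionScores.argmax(dim=-1).tolist()

    def inferenceWithScores(self, emissionScores):
        scores, labels = emissionScores.max(dim=-1)
        return labels.tolist(), scores.tolist()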
main/src/main/python/pytorch/forwardLayer.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
import torch
import torch.nn as nn

from finalLayer import FinalLayer
from greedyForwardLayer import GreedyForwardLayer
from viterbiForwardLayer import ViterbiForwardLayer

from utils import *

class ForwardLayer(FinalLayer):

    def __init__(self, inputSize, isDual, t2i, i2t, actualInputSize, nonlinearity, dropoutProb, spans=None):
        super().__init__()
        self.inputSize = inputSize
        self.isDual = isDual
        self.t2i = t2i
        self.i2t = i2t
        self.spans = spans
        self.nonlinearity = nonlinearity

        self.pH = nn.Linear(actualInputSize, len(t2i))
        self.pRoot = torch.rand(inputSize)  # TODO: not sure about the shape here
        self.dropoutProb = dropoutProb

        self.inDim = spanLength(spans) if spans is not None else inputSize
        self.outDim = len(t2i)

    def pickSpan(self, v):
        if self.spans is None:
            return v
        else:
            # Zheng: Will spans overlap?
            vs = list()
            for span in self.spans:
                e = torch.index_select(v, 0, torch.tensor([span[0], span[1]]))
                vs.append(e)
            return torch.cat(vs)

    def forward(self, inputExpressions, doDropout, headPositionsOpt=None):
        emissionScores = list()
        if not self.isDual:
            # Zheng: Why the for loop here? Can we just use matrix manipulation?
            for i, e in enumerate(inputExpressions):
                argExp = expressionDropout(self.pickSpan(e), self.dropoutProb, doDropout)
                l1 = expressionDropout(self.pH(argExp), self.dropoutProb, doDropout)
                if self.nonlinearity == NONLIN_TANH:
                    l1 = torch.tanh(l1)
                elif self.nonlinearity == NONLIN_RELU:
                    l1 = torch.relu(l1)
                emissionScores.append(l1)
        else:
            if headPositionsOpt is None:
                raise RuntimeError("ERROR: dual task without information about head positions!")
            for i, e in enumerate(inputExpressions):
                headPosition = headPositionsOpt[i]
                argExp = expressionDropout(self.pickSpan(e), self.dropoutProb, doDropout)
                if headPosition >= 0:
                    # there is an explicit head in the sentence
                    predExp = expressionDropout(self.pickSpan(inputExpressions[headPosition]), self.dropoutProb, doDropout)
                else:
                    # the head is root; we use a dedicated parameter for root
                    # Zheng: Why not add a root node to the input sequence at the beginning?
                    predExp = expressionDropout(self.pickSpan(self.pRoot), self.dropoutProb, doDropout)
                ss = torch.cat([argExp, predExp])
                l1 = expressionDropout(self.pH(ss), self.dropoutProb, doDropout)
                if self.nonlinearity == NONLIN_TANH:
                    l1 = torch.tanh(l1)
                elif self.nonlinearity == NONLIN_RELU:
                    l1 = torch.relu(l1)
                emissionScores.append(l1)
        return torch.stack(emissionScores)

    @staticmethod
    def load(x2i):
        inferenceType = x2i["inferenceType"]
        if inferenceType == TYPE_VITERBI:
            pass
            # TODO
            # return ViterbiForwardLayer.load(x2i)
        elif inferenceType == TYPE_GREEDY:
            return GreedyForwardLayer.load(x2i)
        else:
            raise RuntimeError(f"ERROR: unknown forward layer type {inferenceType}!")

    @staticmethod
    def initialize(config, paramPrefix, labelCounter, isDual, inputSize):
        if paramPrefix not in config:
            return None

        inferenceType = config.get_string(paramPrefix + ".inference", "greedy")
        dropoutProb = config.get_float(paramPrefix + ".dropoutProb", DEFAULT_DROPOUT_PROBABILITY)

        nonlinAsString = config.get_string(paramPrefix + ".nonlinearity", "")
        if nonlinAsString in nonlin_map:
            nonlin = nonlin_map[nonlinAsString]
        else:
            raise RuntimeError(f"ERROR: unknown non-linearity {nonlinAsString}!")

        t2i = {t: i for i, t in enumerate(labelCounter.keys())}
        i2t = {i: t for t, i in t2i.items()}

        spanConfig = config.get_string(paramPrefix + ".span", "")
        if spanConfig == "":
            span = None
        else:
            span = parseSpan(spanConfig)

        if span:
            l = spanLength(span)
            actualInputSize = 2 * l if isDual else l
        else:
            actualInputSize = 2 * inputSize if isDual else inputSize

        if inferenceType == TYPE_GREEDY_STRING:
            # arguments ordered to match ForwardLayer.__init__
            return GreedyForwardLayer(inputSize, isDual, t2i, i2t, actualInputSize, nonlin, dropoutProb, span)
        elif inferenceType == TYPE_VITERBI_STRING:
            pass
            # TODO
            # layer = ViterbiForwardLayer(inputSize, isDual, t2i, i2t, actualInputSize, nonlin, dropoutProb, span)
            # layer.initializeTransitions()
            # return layer
        else:
            raise RuntimeError(f"ERROR: unknown inference type {inferenceType}!")

def spanLength(spans):
    s = 0
    for x in spans:
        s += x[1] - x[0]
    return s

def parseSpan(spanParam):
    spans = list()
    spanParamTokens = spanParam.split(",")
    for spanParamToken in spanParamTokens:
        spanTokens = spanParamToken.split('-')
        assert(len(spanTokens) == 2)
        spans.append((int(spanTokens[0]), int(spanTokens[1])))
    return spans

def spanToString(spans):
    s = ""
    first = True
    for span in spans:
        if not first:
            s += ","
        s += f"{span[0]}-{span[1]}"
        first = False
    return s
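A hypothetical usage sketch of the non-dual path of ForwardLayer.forward; all dimensions, labels, and the dropout probability below are made up for illustration, and NONLIN_TANH and expressionDropout are assumed to come from utils:

import torch

t2i = {"O": 0, "B-ARG": 1, "I-ARG": 2}
i2t = {i: t for t, i in t2i.items()}
layer = ForwardLayer(inputSize=10, isDual=False, t2i=t2i, i2t=i2t,
                     actualInputSize=10, nonlinearity=NONLIN_TANH,
                     dropoutProb=0.1, spans=None)

tokens = torch.rand(5, 10)                  # 5 tokens, 10-dimensional representations
emissions = layer(tokens, doDropout=False)  # shape (5, 3): one label-score row per token
print(emissions.shape)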

main/src/main/python/pytorch/greedyForwardLayer.py

Whitespace-only changes.

main/src/main/python/pytorch/utils.py

Lines changed: 21 additions & 0 deletions
@@ -20,6 +20,20 @@
 
 IS_DYNET_INITIALIZED = False
 
+TYPE_VITERBI = 1
+TYPE_GREEDY = 2
+
+NONLIN_NONE = 0
+NONLIN_RELU = 1
+NONLIN_TANH = 2
+
+nonlin_map = {"relu": NONLIN_RELU, "tanh": NONLIN_TANH, "": NONLIN_NONE}
+
+TYPE_GREEDY_STRING = "greedy"
+TYPE_VITERBI_STRING = "viterbi"
+
+DEFAULT_IS_DUAL = 0
+
 def save(file, values, comment):
     file.write("# " + comment + "\n")
     for key, value in values.items():
@@ -71,5 +85,12 @@ def transduce(embeddings, builder):
 
     return output, result
 
+def expressionDropout(expression, dropoutProb, doDropout):
+    if doDropout and dropoutProb > 0:
+        dropout = nn.Dropout(dropoutProb)
+        return dropout(expression)
+    else:
+        return expression
+
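expressionDropout builds a fresh nn.Dropout module on every call (and assumes torch.nn is available as nn in utils.py); an equivalent functional sketch using torch.nn.functional avoids constructing a module per call:

import torch.nn.functional as F

def expressionDropoutFunctional(expression, dropoutProb, doDropout):
    # same behaviour: apply dropout only when requested and the probability is positive
    if doDropout and dropoutProb > 0:
        return F.dropout(expression, p=dropoutProb, training=True)
    return expression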

main/src/main/python/pytorch/viterbiForwardLayer.py

Whitespace-only changes.
