Skip to content

Commit d86fccb

Browse files
committed
init code
1 parent d4618cf commit d86fccb

File tree

7 files changed

+286
-0
lines changed

7 files changed

+286
-0
lines changed

main/src/main/python/__init__.py

Whitespace-only changes.

main/src/main/python/pytorch/__init__.py

Whitespace-only changes.

main/src/main/python/pytorch/metal.py

Whitespace-only changes.
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
import random
2+
import math
3+
from sequences.columnReader import ColumnReader
4+
5+
# Task type identifiers: produced by TaskManager.parseType and compared against in Task.
TYPE_BASIC = 0
TYPE_DUAL = 1
7+
8+
class TaskManager:
    """Reads the tasks described under the ``mtl.*`` config keys and interleaves their training shards."""

    def __init__(self, config, seed):
        """Load all tasks and build the interleaved list of training shards.

        :param config: configuration object queried through ``get_int``, ``get_string`` and ``get_float``
        :param seed: random seed, forwarded to every SentenceIterator created by getSentences()
        """
        self.config = config
        # Stored under the name `random`, but this is the seed value, not a generator
        self.random = seed

        # How many shards to have per epoch
        self.shardsPerEpoch: int = config.get_int("mtl.shardsPerEpoch", 10)

        # Total number of epochs
        self.maxEpochs: int = config.get_int("mtl.maxEpochs", 100)

        # Training patience in number of epochs
        self.epochPatience: int = config.get_int("mtl.epochPatience", 5)

        # Array of all tasks to be managed
        self.tasks = self.readTasks()

        self.taskCount = len(self.tasks)
        self.indices = range(self.taskCount)

        # Training shards from all tasks
        self.shards = self.mkShards()

    # Construct training shards by interleaving shards from all tasks
    def mkShards(self):
        """Return one flat list of shards: shard 0 of every task, then shard 1 of every task, and so on.

        Every task must produce exactly ``shardsPerEpoch`` shards; this is asserted.
        """
        # construct the shards for each task
        shardsByTasks = []
        for i in self.indices:
            taskShards = self.tasks[i].mkShards()
            assert len(taskShards) == self.shardsPerEpoch, \
                f"task {i} produced {len(taskShards)} shards, expected {self.shardsPerEpoch}"
            shardsByTasks.append(taskShards)

        # now interleave the tasks
        return [
            taskShards[i]
            for i in range(self.shardsPerEpoch)
            for taskShards in shardsByTasks
        ]

    # Iterator over all sentences coming from all interleaved shards
    def getSentences(self):
        """Return a fresh SentenceIterator over all tasks and shards, built with the stored seed."""
        return SentenceIterator(self.tasks, self.shards, self.random)

    # Reads all tasks from disk in memory
    def readTasks(self):
        """Read tasks 1..``mtl.numberOfTasks`` and return them as a list.

        :raises ValueError: if ``mtl.numberOfTasks`` is not present in the config
        """
        numberOfTasks = self.config.get_int("mtl.numberOfTasks", None)
        if numberOfTasks is None:
            # Without this check the failure would be an opaque TypeError from range(None)
            raise ValueError("ERROR: mtl.numberOfTasks is missing from the config file!")

        # Task numbers in the config are 1-based
        tasks = [self.readTask(i + 1) for i in range(numberOfTasks)]

        print (f"Read {numberOfTasks} tasks from config file.")
        return tasks

    def readTask(self, taskNumber):
        """Build the Task configured under ``mtl.task<taskNumber>.*``.

        :param taskNumber: 1-based task number as used in the config keys
        :raises ValueError: if the training file is not configured or the task type is unknown
        """
        prefix = f"mtl.task{taskNumber}"
        taskName = self.config.get_string(f"{prefix}.name", None)
        train = self.config.get_string(f"{prefix}.train", None)
        if train is None:
            raise ValueError(f"ERROR: {prefix}.train is missing from the config file!")

        # dev and test are optional; the None default already covers a missing key
        dev = self.config.get_string(f"{prefix}.dev", None)
        test = self.config.get_string(f"{prefix}.test", None)

        taskType = self.parseType(self.config.get_string(f"{prefix}.type", "basic"))

        weight = self.config.get_float(f"{prefix}.weight", 1.0)

        # Task ids are 0-based so they can index self.tasks
        return Task(taskNumber - 1, taskName, taskType, self.shardsPerEpoch, weight, train, dev, test)

    def parseType(self, inf):
        """Map the config strings "basic" / "dual" to TYPE_BASIC / TYPE_DUAL.

        :raises ValueError: for any other string
        """
        if inf == "basic":
            return TYPE_BASIC
        elif inf == "dual":
            return TYPE_DUAL
        else:
            raise ValueError(f"ERROR: unknown task type {inf}!")

    def debugTraversal(self):
        """Iterate over all sentences for ``maxEpochs`` epochs, printing how many come from each run of a task."""
        for epoch in range(self.maxEpochs):
            print (f"Started epoch {epoch}")
            sentCount = 0
            taskId = 0
            totalSents = 0
            # Each item is a (taskId, sentence) pair
            for sentence in self.getSentences():
                totalSents += 1
                if sentence[0] != taskId:
                    # The task changed: report the run that just ended
                    print (f"Read {sentCount} sentences from task {taskId}")
                    taskId = sentence[0]
                    sentCount = 1
                else:
                    sentCount += 1
            print (f"Read {sentCount} sentences from task {taskId}")
            print (f"Read {totalSents} sentences in epoch {epoch}.")
105+
106+
class SentenceIterator(object):
    """Iterates over the training sentences of all tasks in a shuffled order.

    Shards are shuffled first, then the sentence positions inside each shard, so all
    sentences of a shard stay together. Each item is a ``(taskId, sentence)`` tuple.
    """

    class Sentence:
        """Position of one training sentence: the owning task and the index in its trainSentences."""

        def __init__(self, taskId, sentencePosition):
            self.taskId = taskId
            self.sentencePosition = sentencePosition

    def __init__(self, tasks, shards, random):
        """
        :param tasks: indexable by task id; each item exposes ``trainSentences``
        :param shards: items exposing ``taskId``, ``startPosition`` and ``endPosition``
        :param random: random seed used for shuffling
        """
        self.tasks = tasks
        self.shards = shards
        self.random = random  # random seed

        # Offset in randomizedSentencePositions array
        self.sentenceOffset = 0
        self.randomizedSentencePositions = self.randomizeSentences()

    # Randomizes all sentences across all tasks
    def randomizeSentences(self):
        """Return the list of Sentence positions in the order they will be served."""
        # A private generator gives the same sequence as seeding the module-level one,
        # without disturbing the global random state.
        # NOTE(review): the same seed gives the same order for every iterator built
        # from it, i.e. for every epoch - confirm this is intended.
        rng = random.Random(self.random)

        # first, randomize the shards; shuffle works in place and returns None,
        # so shuffle a copy and leave the caller's list untouched
        randomizedShards = list(self.shards)
        rng.shuffle(randomizedShards)

        randomizedSents = []
        for shard in randomizedShards:
            # second, randomize the sentences inside each shard
            sents = list(range(shard.startPosition, shard.endPosition))
            rng.shuffle(sents)
            # store the randomized sentences
            randomizedSents.extend(self.Sentence(shard.taskId, sent) for sent in sents)
        return randomizedSents

    def __len__(self):
        return len(self.randomizedSentencePositions)

    def __iter__(self):
        return self

    def hasNext(self):
        """Return True while there are sentences left to serve."""
        return self.sentenceOffset < len(self.randomizedSentencePositions)

    def __next__(self):
        """Return the next ``(taskId, sentence)`` pair.

        :raises StopIteration: when all sentences have been served
        """
        # The iterator protocol requires StopIteration so that for-loops end cleanly
        if not self.hasNext():
            raise StopIteration

        s = self.randomizedSentencePositions[self.sentenceOffset]
        tid = s.taskId
        sentence = self.tasks[tid].trainSentences[s.sentencePosition]
        self.sentenceOffset += 1

        return (tid, sentence)
155+
156+
class Shard:
    """A range of training sentence positions, from startPosition up to (excluding) endPosition, in one task."""

    def __init__(self, taskId, startPosition, endPosition):
        self.taskId = taskId
        self.startPosition = startPosition
        self.endPosition = endPosition

    def __repr__(self):
        # Readable form for debugging printouts of shard lists
        return f"Shard(taskId={self.taskId}, startPosition={self.startPosition}, endPosition={self.endPosition})"
161+
162+
class Task:
    """One task: its train/dev/test sentences read from disk, its type and its shard size."""

    def __init__(self,
                 taskId,  # this starts at 0 so we can use it as an index in the array of tasks
                 taskName: str,
                 taskType: int,
                 shardsPerEpoch: int,
                 taskWeight: float,
                 trainFileName: str,
                 devFileName: str = None,
                 testFileName: str = None):
        """Read the data files of the task and compute its shard size.

        :param taskId: 0-based task id
        :param taskName: name of the task, used in log messages
        :param taskType: TYPE_BASIC or TYPE_DUAL
        :param shardsPerEpoch: number of shards the training sentences are divided into; must be positive
        :param taskWeight: weight of the task
        :param trainFileName: training file, read with ColumnReader.readColumns
        :param devFileName: optional development file
        :param testFileName: optional testing file
        :raises ValueError: if shardsPerEpoch is not positive
        """
        # Fail before any file is read: zero would divide by zero below and a
        # negative value would make mkShards() loop forever
        if shardsPerEpoch <= 0:
            raise ValueError(f"ERROR: shardsPerEpoch must be positive, got {shardsPerEpoch}!")

        self.taskId = taskId
        self.taskName = taskName
        self.taskType = taskType
        self.shardsPerEpoch = shardsPerEpoch
        self.taskWeight = taskWeight

        # Task numbers shown to the user are 1-based
        taskNumber = taskId + 1
        print (f"Reading task {taskNumber} ({taskName})...")
        self.trainSentences = ColumnReader.readColumns(trainFileName)
        self.devSentences = ColumnReader.readColumns(devFileName) if devFileName else None
        self.testSentences = ColumnReader.readColumns(testFileName) if testFileName else None

        self.isBasic: bool = taskType == TYPE_BASIC
        self.isDual: bool = taskType == TYPE_DUAL

        if taskType == TYPE_BASIC:
            self.prettyType = "basic"
        elif taskType == TYPE_DUAL:
            self.prettyType = "dual"
        else:
            self.prettyType = "unknown"

        # The size of the training shard for this task
        self.shardSize = math.ceil(len(self.trainSentences) / shardsPerEpoch)

        # Current position in the training sentences when we iterate during training
        self.currentTrainingSentencePosition = 0

        print (f"============ starting task {taskNumber} ============")
        print (f"Read {len(self.trainSentences)} training sentences for task {taskNumber}, with shard size {self.shardSize}.")
        if self.devSentences is not None:
            print (f"Read {len(self.devSentences)} development sentences for task {taskNumber}.")
        if self.testSentences is not None:
            print (f"Read {len(self.testSentences)} testing sentences for task {taskNumber}.")
        print (f"Using taskWeight = {taskWeight}")
        print (f"Task type = {self.prettyType}.")
        print (f"============ completed task {taskNumber} ============")

    # Construct the shards from all training sentences in this task
    def mkShards(self):
        """Split the training sentences into consecutive shards of ``shardSize`` (the last may be shorter).

        Because the shard size is rounded up, fewer than ``shardsPerEpoch`` shards can
        come out (e.g. 5 sentences and 4 shards per epoch give a shard size of 2 and
        only 3 shards).
        """
        shards = []
        sentenceCount = len(self.trainSentences)
        crtPos = 0
        while crtPos < sentenceCount:
            endPos = min(crtPos + self.shardSize, sentenceCount)
            shards.append(Shard(self.taskId, crtPos, endPos))
            crtPos = endPos
        return shards

main/src/main/python/run.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from pyhocon import ConfigFactory
2+
import argparse
3+
from pytorch.taskManager import TaskManager
4+
5+
# Command line entry point: parses the arguments and, in training mode, loads the
# configuration and builds the TaskManager. Test and shell modes are not implemented yet.
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_file', type=str, help='Filename of the model.')
    parser.add_argument('--train', action='store_true', help='Set the code to training purpose.')
    parser.add_argument('--test', action='store_true', help='Set the code to testing purpose.')
    parser.add_argument('--shell', action='store_true', help='Set the code to shell mode.')
    parser.add_argument('--config', type=str, help='Filename of the configuration.')
    parser.add_argument('--seed', type=int, default=1234)
    args = parser.parse_args()

    if args.train:
        # Without --config the path below would silently become ".../None.conf"
        if not args.config:
            parser.error('--config is required together with --train')
        config = ConfigFactory.parse_file(f'../resources/org/clulab/{args.config}.conf')
        taskManager = TaskManager(config, args.seed)
        # TODO: build the Metal model from taskManager and train it into args.model_file
    elif args.test:
        pass
    elif args.shell:
        pass

main/src/main/python/sequences/__init__.py

Whitespace-only changes.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#-----------------------------------------------------------
2+
# Reads the CoNLL-like column format
3+
#-----------------------------------------------------------
4+
class ColumnReader:
    """Reader for the CoNLL-like column format: one token per line, blank lines between sentences."""

    @staticmethod
    def readColumns(source):
        """Read all sentences from ``source``.

        :param source: a file name, or an iterable of lines that has a ``close()`` method
        :return: list of sentences, each a list of Row objects (one per non-blank line)
        :raises RuntimeError: if a non-blank line has fewer than two whitespace-separated fields

        ``source`` is closed before returning, including when an error is raised.
        """
        if isinstance(source, str):
            source = open(source)
        sentences = []
        sentence = []
        try:
            for line in source:
                l = line.strip()
                if not l:
                    # end of sentence
                    if sentence:
                        sentences.append(sentence)
                        sentence = []
                else:
                    # within the same sentence; split() with no argument splits on
                    # any run of whitespace
                    bits = l.split()
                    if len(bits) < 2:
                        raise RuntimeError(f"ERROR: invalid line {l}!")
                    sentence.append(Row(bits))

            # the last sentence may not be followed by a blank line
            if sentence:
                sentences.append(sentence)
        finally:
            source.close()
        return sentences
31+
32+
# -----------------------------------------------------------
33+
# Stores training data for sequence modeling
34+
# Mandatory columns: 0 - word, 1 - label
35+
# Optional columns: 2 - POS tag, 3+ SRL arguments
36+
# @param tokens
37+
# -----------------------------------------------------------
38+
39+
class Row:
    """The fields of one line of a column file, kept in ``tokens`` with their count in ``length``."""

    def __init__(self, tokens):
        """
        :param tokens: the fields of the line, in column order
        """
        self.tokens = tokens
        self.length = len(tokens)

    def get(self, idx):
        """Return the field at position ``idx``.

        :raises RuntimeError: if ``idx`` is at or beyond the number of fields
        """
        if idx >= self.length:
            raise RuntimeError(f"ERROR: trying to read field #{idx}, which does not exist in this row: {self.tokens}!")
        return self.tokens[idx]

0 commit comments

Comments
 (0)