
Commit 7984a90

Merge branch 'master' of https://github.com/fastnlp/fastNLP into local-fix-doc
2 parents 1c8bca5 + 27ae52c commit 7984a90

File tree

4 files changed (+161 −32 lines changed)

fastNLP/core/dataset.py (+2 −2)

@@ -254,15 +254,15 @@ def apply(self, func, new_field_name=None, **kwargs):
         :return results: if new_field_name is not passed, returned values of the function over all instances.
         """
         results = [func(ins) for ins in self._inner_iter()]
-        if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
-            raise ValueError("{} always return None.".format(get_func_signature(func=func)))

         extra_param = {}
         if 'is_input' in kwargs:
             extra_param['is_input'] = kwargs['is_input']
         if 'is_target' in kwargs:
             extra_param['is_target'] = kwargs['is_target']
         if new_field_name is not None:
+            if len(list(filter(lambda x: x is not None, results))) == 0:  # all None
+                raise ValueError("{} always return None.".format(get_func_signature(func=func)))
             if new_field_name in self.field_arrays:
                 # overwrite the field, keep same attributes
                 old_field = self.field_arrays[new_field_name]
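Note: after this reordering, the all-None check in DataSet.apply only fires when a new_field_name is given, so apply can be called purely for its side effects (the function returning None for every instance) without raising. A minimal sketch of the two usages, built only from calls that appear in this commit's new test file:

import fastNLP

ds = fastNLP.DataSet()
ds.append(fastNLP.Instance(word_seq=['a', 'b']))
ds.append(fastNLP.Instance(word_seq=['c']))
vocab = fastNLP.Vocabulary()

# Side-effect-only apply: the function returns None for every instance.
# With this change no ValueError is raised, because no new_field_name is given.
ds.apply(lambda x: vocab.add_word_lst(x['word_seq']))

# Writing a new field still requires at least one non-None result.
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens')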

fastNLP/io/embed_loader.py (+11 −2)

@@ -75,10 +75,18 @@ def load_embedding(emb_dim, emb_file, emb_type, vocab):

     @staticmethod
     def parse_glove_line(line):
-        line = list(filter(lambda w: len(w) > 0, line.strip().split(" ")))
+        line = line.split()
         if len(line) <= 2:
             raise RuntimeError("something goes wrong in parsing glove embedding")
-        return line[0], torch.Tensor(list(map(float, line[1:])))
+        return line[0], line[1:]
+
+    @staticmethod
+    def str_list_2_vec(line):
+        try:
+            return torch.Tensor(list(map(float, line)))
+        except Exception:
+            raise RuntimeError("something goes wrong in parsing glove embedding")
+

     @staticmethod
     def fast_load_embedding(emb_dim, emb_file, vocab):

@@ -99,6 +107,7 @@ def fast_load_embedding(emb_dim, emb_file, vocab):
         for line in f:
             word, vector = EmbedLoader.parse_glove_line(line)
             if word in vocab:
+                vector = EmbedLoader.str_list_2_vec(vector)
                 if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
                     raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,)))
                 embedding_matrix[vocab[word]] = vector
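Note: parse_glove_line now returns the raw string tokens, and the float conversion is deferred to the new str_list_2_vec helper, so it only runs for words that are actually kept (those found in the vocabulary). A small sketch of the resulting two-step flow (the example line is made up):

from fastNLP.io.embed_loader import EmbedLoader

line = "the 0.1 0.2 0.3"
word, tokens = EmbedLoader.parse_glove_line(line)  # word == 'the', tokens == ['0.1', '0.2', '0.3']

# Convert to a tensor only when the word is needed, mirroring fast_load_embedding.
vector = EmbedLoader.str_list_2_vec(tokens)  # torch.Tensor of shape [3]
assert vector.shape[0] == 3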

fastNLP/models/biaffine_parser.py (+66 −28)

@@ -1,5 +1,3 @@
-import sys, os
-sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 import copy
 import numpy as np
 import torch

@@ -11,6 +9,9 @@
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.utils import seq_mask
+from fastNLP.core.losses import LossFunc
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.utils import seq_lens_to_masks

 def mst(scores):
     """

@@ -121,9 +122,6 @@ class GraphParser(BaseModel):
     def __init__(self):
         super(GraphParser, self).__init__()

-    def forward(self, x):
-        raise NotImplementedError
-
     def _greedy_decoder(self, arc_matrix, mask=None):
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))

@@ -202,14 +200,14 @@ def __init__(self,
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
-                 word_hid_dim,
-                 pos_hid_dim,
-                 rnn_layers,
-                 rnn_hidden_size,
-                 arc_mlp_size,
-                 label_mlp_size,
                  num_label,
-                 dropout,
+                 word_hid_dim=100,
+                 pos_hid_dim=100,
+                 rnn_layers=1,
+                 rnn_hidden_size=200,
+                 arc_mlp_size=100,
+                 label_mlp_size=100,
+                 dropout=0.3,
                  use_var_lstm=False,
                  use_greedy_infer=False):

@@ -267,11 +265,11 @@ def reset_parameters(self):
             for p in m.parameters():
                 nn.init.normal_(p, 0, 0.1)

-    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
+    def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
         """
         :param word_seq: [batch_size, seq_len] sequence of word's indices
         :param pos_seq: [batch_size, seq_len] sequence of word's indices
-        :param word_seq_origin_len: [batch_size, seq_len] sequence of length masks
+        :param seq_lens: [batch_size, seq_len] sequence of length masks
         :param gold_heads: [batch_size, seq_len] sequence of golden heads
         :return dict: parsing results
             arc_pred: [batch_size, seq_len, seq_len]

@@ -283,12 +281,12 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         device = self.parameters().__next__().device
         word_seq = word_seq.long().to(device)
         pos_seq = pos_seq.long().to(device)
-        word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1)
+        seq_lens = seq_lens.long().to(device).view(-1)
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))

         # get sequence mask
-        mask = seq_mask(word_seq_origin_len, seq_len).long()
+        mask = seq_mask(seq_lens, seq_len).long()

         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]

@@ -298,7 +296,7 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         del word, pos

         # lstm, extract features
-        sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True)
+        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
         x = x[sort_idx]
         x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
         feat, _ = self.lstm(x)  # -> [N,L,C]

@@ -342,14 +340,15 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
             res_dict['head_pred'] = head_pred
         return res_dict

-    def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
+    @staticmethod
+    def loss(arc_pred, label_pred, arc_true, label_true, mask):
         """
         Compute loss.

         :param arc_pred: [batch_size, seq_len, seq_len]
         :param label_pred: [batch_size, seq_len, n_tags]
-        :param head_indices: [batch_size, seq_len]
-        :param head_labels: [batch_size, seq_len]
+        :param arc_true: [batch_size, seq_len]
+        :param label_true: [batch_size, seq_len]
         :param mask: [batch_size, seq_len]
         :return: loss value
         """

@@ -362,8 +361,8 @@ def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
         label_logits = F.log_softmax(label_pred, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
         child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
-        arc_loss = arc_logits[batch_index, child_index, head_indices]
-        label_loss = label_logits[batch_index, child_index, head_labels]
+        arc_loss = arc_logits[batch_index, child_index, arc_true]
+        label_loss = label_logits[batch_index, child_index, label_true]

         arc_loss = arc_loss[:, 1:]
         label_loss = label_loss[:, 1:]

@@ -373,19 +372,58 @@ def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
         label_nll = -(label_loss*float_mask).mean()
         return arc_nll + label_nll

-    def predict(self, word_seq, pos_seq, word_seq_origin_len):
+    def predict(self, word_seq, pos_seq, seq_lens):
         """

         :param word_seq:
         :param pos_seq:
-        :param word_seq_origin_len:
-        :return: head_pred: [B, L]
+        :param seq_lens:
+        :return: arc_pred: [B, L]
             label_pred: [B, L]
-            seq_len: [B,]
         """
-        res = self(word_seq, pos_seq, word_seq_origin_len)
+        res = self(word_seq, pos_seq, seq_lens)
         output = {}
-        output['head_pred'] = res.pop('head_pred')
+        output['arc_pred'] = res.pop('head_pred')
         _, label_pred = res.pop('label_pred').max(2)
         output['label_pred'] = label_pred
         return output
+
+
+class ParserLoss(LossFunc):
+    def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
+        super(ParserLoss, self).__init__(BiaffineParser.loss,
+                                         arc_pred=arc_pred,
+                                         label_pred=label_pred,
+                                         arc_true=arc_true,
+                                         label_true=label_true)
+
+
+class ParserMetric(MetricBase):
+    def __init__(self, arc_pred=None, label_pred=None,
+                 arc_true=None, label_true=None, seq_lens=None):
+        super().__init__()
+        self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
+                             arc_true=arc_true, label_true=label_true,
+                             seq_lens=seq_lens)
+        self.num_arc = 0
+        self.num_label = 0
+        self.num_sample = 0
+
+    def get_metric(self, reset=True):
+        res = {'UAS': self.num_arc*1.0 / self.num_sample, 'LAS': self.num_label*1.0 / self.num_sample}
+        if reset:
+            self.num_sample = self.num_label = self.num_arc = 0
+        return res
+
+    def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None):
+        """Evaluate the performance of prediction.
+        """
+        if seq_lens is None:
+            seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
+        else:
+            seq_mask = seq_lens_to_masks(seq_lens, float=False).long()
+        head_pred_correct = (arc_pred == arc_true).long() * seq_mask
+        label_pred_correct = (label_pred == label_true).long() * head_pred_correct
+        self.num_arc += head_pred_correct.sum().item()
+        self.num_label += label_pred_correct.sum().item()
+        self.num_sample += seq_mask.sum().item()
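Note on the additions above: ParserLoss wraps the now-static BiaffineParser.loss so a Trainer can be given loss=ParserLoss(), and its constructor keywords appear to allow remapping the expected names to differently named dataset fields, mirroring the _init_param_map call in ParserMetric. ParserMetric accumulates unlabeled (UAS) and labeled (LAS) attachment counts over batches; a token counts toward LAS only when its predicted head is also correct. A small hand-worked sketch of the metric (tensors made up for illustration):

import torch
from fastNLP.models.biaffine_parser import ParserMetric

metric = ParserMetric()
arc_true   = torch.tensor([[0, 2, 0, 2]])   # gold heads for one 4-token sentence
arc_pred   = torch.tensor([[0, 2, 0, 3]])   # head wrong on the last token
label_true = torch.tensor([[0, 1, 2, 1]])
label_pred = torch.tensor([[0, 1, 9, 1]])   # label wrong on token 2 (its head is correct)
seq_lens   = torch.tensor([4])

metric.evaluate(arc_pred, label_pred, arc_true, label_true, seq_lens)
print(metric.get_metric())  # {'UAS': 0.75, 'LAS': 0.5}: the last token's matching label
                            # does not count toward LAS because its head is wrong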

test/models/test_biaffine_parser.py (+82 −0)

@@ -0,0 +1,82 @@
+from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
+import fastNLP
+
+import unittest
+
+data_file = """
+1 The _ DET DT _ 3 det _ _
+2 new _ ADJ JJ _ 3 amod _ _
+3 rate _ NOUN NN _ 6 nsubj _ _
+4 will _ AUX MD _ 6 aux _ _
+5 be _ VERB VB _ 6 cop _ _
+6 payable _ ADJ JJ _ 0 root _ _
+9 cents _ NOUN NNS _ 4 nmod _ _
+10 from _ ADP IN _ 12 case _ _
+11 seven _ NUM CD _ 12 nummod _ _
+12 cents _ NOUN NNS _ 4 nmod _ _
+13 a _ DET DT _ 14 det _ _
+14 share _ NOUN NN _ 12 nmod:npmod _ _
+15 . _ PUNCT . _ 4 punct _ _
+
+1 The _ DET DT _ 3 det _ _
+2 new _ ADJ JJ _ 3 amod _ _
+3 rate _ NOUN NN _ 6 nsubj _ _
+4 will _ AUX MD _ 6 aux _ _
+5 be _ VERB VB _ 6 cop _ _
+6 payable _ ADJ JJ _ 0 root _ _
+7 Feb. _ PROPN NNP _ 6 nmod:tmod _ _
+8 15 _ NUM CD _ 7 nummod _ _
+9 . _ PUNCT . _ 6 punct _ _
+
+1 A _ DET DT _ 3 det _ _
+2 record _ NOUN NN _ 3 compound _ _
+3 date _ NOUN NN _ 7 nsubjpass _ _
+4 has _ AUX VBZ _ 7 aux _ _
+5 n't _ PART RB _ 7 neg _ _
+6 been _ AUX VBN _ 7 auxpass _ _
+7 set _ VERB VBN _ 0 root _ _
+8 . _ PUNCT . _ 7 punct _ _
+
+"""
+
+def init_data():
+    ds = fastNLP.DataSet()
+    v = {'word_seq': fastNLP.Vocabulary(),
+         'pos_seq': fastNLP.Vocabulary(),
+         'label_true': fastNLP.Vocabulary()}
+    data = []
+    for line in data_file.split('\n'):
+        line = line.split()
+        if len(line) == 0 and len(data) > 0:
+            data = list(zip(*data))
+            ds.append(fastNLP.Instance(word_seq=data[1],
+                                       pos_seq=data[4],
+                                       arc_true=data[6],
+                                       label_true=data[7]))
+            data = []
+        elif len(line) > 0:
+            data.append(line)
+
+    for name in ['word_seq', 'pos_seq', 'label_true']:
+        ds.apply(lambda x: ['<st>']+list(x[name])+['<ed>'], new_field_name=name)
+        ds.apply(lambda x: v[name].add_word_lst(x[name]))
+
+    for name in ['word_seq', 'pos_seq', 'label_true']:
+        ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name)
+
+    ds.apply(lambda x: [0]+list(map(int, x['arc_true']))+[1], new_field_name='arc_true')
+    ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens')
+    ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True)
+    ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True)
+    return ds, v['word_seq'], v['pos_seq'], v['label_true']
+
+class TestBiaffineParser(unittest.TestCase):
+    def test_train(self):
+        ds, v1, v2, v3 = init_data()
+        model = BiaffineParser(word_vocab_size=len(v1), word_emb_dim=30,
+                               pos_vocab_size=len(v2), pos_emb_dim=30,
+                               num_label=len(v3))
+        trainer = fastNLP.Trainer(model=model, train_data=ds, dev_data=ds,
+                                  loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
+                                  n_epochs=10, use_cuda=False, use_tqdm=False)
+        trainer.train(load_best_model=False)
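Beyond the end-to-end test above, the renamed predict output can be consumed directly: predict now returns 'arc_pred' instead of 'head_pred', matching the names ParserMetric expects. A hedged inference sketch, relying on the constructor defaults introduced in this commit (vocabulary sizes and indices are made up):

import torch
from fastNLP.models.biaffine_parser import BiaffineParser

model = BiaffineParser(word_vocab_size=10, word_emb_dim=30,
                       pos_vocab_size=10, pos_emb_dim=30, num_label=5)
model.eval()

# toy batch of one sentence, shapes [batch_size, seq_len]
word_seq = torch.tensor([[2, 5, 7, 3]])
pos_seq  = torch.tensor([[1, 4, 4, 2]])
seq_lens = torch.tensor([4])

output = model.predict(word_seq, pos_seq, seq_lens)
# output['arc_pred']:   [1, 4] predicted head index for each token
# output['label_pred']: [1, 4] predicted dependency label index for each token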
