@@ -1,5 +1,3 @@
-import sys, os
-sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
 import copy
 import numpy as np
 import torch
@@ -11,6 +9,9 @@
 from fastNLP.modules.dropout import TimestepDropout
 from fastNLP.models.base_model import BaseModel
 from fastNLP.modules.utils import seq_mask
+from fastNLP.core.losses import LossFunc
+from fastNLP.core.metrics import MetricBase
+from fastNLP.core.utils import seq_lens_to_masks
 
 def mst(scores):
     """
@@ -121,9 +122,6 @@ class GraphParser(BaseModel):
     def __init__(self):
         super(GraphParser, self).__init__()
 
-    def forward(self, x):
-        raise NotImplementedError
-
     def _greedy_decoder(self, arc_matrix, mask=None):
         _, seq_len, _ = arc_matrix.shape
         matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
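`_greedy_decoder` adds `-inf` on the diagonal so a token can never choose itself as its own head, then takes a per-token argmax over candidate heads. A self-contained sketch of just that decode step, on random scores:

```python
import numpy as np
import torch

arc_matrix = torch.randn(2, 5, 5)  # [batch, seq_len, seq_len] arc scores
seq_len = arc_matrix.size(1)
# -inf on the diagonal rules out self-loops
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
heads = matrix.max(dim=2)[1]  # greedy head index per token, [batch, seq_len]
```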
@@ -202,14 +200,14 @@ def __init__(self,
                  word_emb_dim,
                  pos_vocab_size,
                  pos_emb_dim,
-                 word_hid_dim,
-                 pos_hid_dim,
-                 rnn_layers,
-                 rnn_hidden_size,
-                 arc_mlp_size,
-                 label_mlp_size,
                  num_label,
-                 dropout,
+                 word_hid_dim=100,
+                 pos_hid_dim=100,
+                 rnn_layers=1,
+                 rnn_hidden_size=200,
+                 arc_mlp_size=100,
+                 label_mlp_size=100,
+                 dropout=0.3,
                  use_var_lstm=False,
                  use_greedy_infer=False):
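With defaults on the middle-of-the-stack hyperparameters, a parser can now be constructed from just the vocabulary sizes, embedding dims, and label count. A hypothetical call (the leading `word_vocab_size` argument is assumed; it sits above the lines this hunk shows):

```python
parser = BiaffineParser(word_vocab_size=10000,  # assumed leading parameter
                        word_emb_dim=100,
                        pos_vocab_size=50,
                        pos_emb_dim=50,
                        num_label=40)
# word_hid_dim=100, rnn_layers=1, rnn_hidden_size=200, dropout=0.3, etc.
# all fall back to the new defaults.
```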
@@ -267,11 +265,11 @@ def reset_parameters(self):
             for p in m.parameters():
                 nn.init.normal_(p, 0, 0.1)
 
-    def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
+    def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
         """
         :param word_seq: [batch_size, seq_len] sequence of word's indices
         :param pos_seq: [batch_size, seq_len] sequence of word's indices
-        :param word_seq_origin_len: [batch_size, seq_len] sequence of length masks
+        :param seq_lens: [batch_size] true lengths of the sequences
         :param gold_heads: [batch_size, seq_len] sequence of golden heads
         :return dict: parsing results
             arc_pred: [batch_size, seq_len, seq_len]
@@ -283,12 +281,12 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         device = self.parameters().__next__().device
         word_seq = word_seq.long().to(device)
         pos_seq = pos_seq.long().to(device)
-        word_seq_origin_len = word_seq_origin_len.long().to(device).view(-1)
+        seq_lens = seq_lens.long().to(device).view(-1)
         batch_size, seq_len = word_seq.shape
         # print('forward {} {}'.format(batch_size, seq_len))
 
         # get sequence mask
-        mask = seq_mask(word_seq_origin_len, seq_len).long()
+        mask = seq_mask(seq_lens, seq_len).long()
 
         word = self.normal_dropout(self.word_embedding(word_seq))  # [N,L] -> [N,L,C_0]
         pos = self.normal_dropout(self.pos_embedding(pos_seq))  # [N,L] -> [N,L,C_1]
@@ -298,7 +296,7 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
         del word, pos
 
         # lstm, extract features
-        sort_lens, sort_idx = torch.sort(word_seq_origin_len, dim=0, descending=True)
+        sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
         x = x[sort_idx]
         x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
         feat, _ = self.lstm(x)  # -> [N,L,C]
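`pack_padded_sequence` in the PyTorch versions this code targets requires lengths in descending order, which is why the batch is sorted first; the matching unsort happens below the lines shown here. A standalone sketch of the full round trip (`run_packed_lstm` is an illustrative name, not part of the model):

```python
import torch
from torch import nn

def run_packed_lstm(lstm, x, seq_lens):
    # sort by length so pack_padded_sequence accepts the batch
    sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
    packed = nn.utils.rnn.pack_padded_sequence(x[sort_idx], sort_lens,
                                               batch_first=True)
    feat, _ = lstm(packed)
    feat, _ = nn.utils.rnn.pad_packed_sequence(feat, batch_first=True)
    # invert the permutation to restore the original batch order
    _, unsort_idx = torch.sort(sort_idx, dim=0)
    return feat[unsort_idx]  # [N, L, C]

lstm = nn.LSTM(8, 16, batch_first=True)
out = run_packed_lstm(lstm, torch.randn(3, 5, 8), torch.tensor([5, 2, 4]))
print(out.shape)  # torch.Size([3, 5, 16])
```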
@@ -342,14 +340,15 @@ def forward(self, word_seq, pos_seq, word_seq_origin_len, gold_heads=None, **_):
             res_dict['head_pred'] = head_pred
         return res_dict
 
-    def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
+    @staticmethod
+    def loss(arc_pred, label_pred, arc_true, label_true, mask):
         """
         Compute loss.
 
         :param arc_pred: [batch_size, seq_len, seq_len]
         :param label_pred: [batch_size, seq_len, n_tags]
-        :param head_indices: [batch_size, seq_len]
-        :param head_labels: [batch_size, seq_len]
+        :param arc_true: [batch_size, seq_len]
+        :param label_true: [batch_size, seq_len]
         :param mask: [batch_size, seq_len]
         :return: loss value
         """
@@ -362,8 +361,8 @@ def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
         label_logits = F.log_softmax(label_pred, dim=2)
         batch_index = torch.arange(batch_size, device=arc_logits.device, dtype=torch.long).unsqueeze(1)
         child_index = torch.arange(seq_len, device=arc_logits.device, dtype=torch.long).unsqueeze(0)
-        arc_loss = arc_logits[batch_index, child_index, head_indices]
-        label_loss = label_logits[batch_index, child_index, head_labels]
+        arc_loss = arc_logits[batch_index, child_index, arc_true]
+        label_loss = label_logits[batch_index, child_index, label_true]
 
         arc_loss = arc_loss[:, 1:]
         label_loss = label_loss[:, 1:]
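The broadcasted index triple selects, for each (sentence, token) pair, the log-probability assigned to the gold head. A toy reproduction of just that gather on random tensors:

```python
import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 4
arc_pred = torch.randn(batch_size, seq_len, seq_len)
arc_true = torch.randint(seq_len, (batch_size, seq_len))

arc_logits = F.log_softmax(arc_pred, dim=2)
batch_index = torch.arange(batch_size, dtype=torch.long).unsqueeze(1)  # [B, 1]
child_index = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)     # [1, L]
gold = arc_logits[batch_index, child_index, arc_true]  # [B, L]
# equivalent gather formulation:
assert torch.allclose(gold, arc_logits.gather(2, arc_true.unsqueeze(2)).squeeze(2))
```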
@@ -373,19 +372,58 @@ def loss(self, arc_pred, label_pred, head_indices, head_labels, mask, **_):
         label_nll = -(label_loss * float_mask).mean()
         return arc_nll + label_nll
 
-    def predict(self, word_seq, pos_seq, word_seq_origin_len):
+    def predict(self, word_seq, pos_seq, seq_lens):
         """
 
         :param word_seq:
         :param pos_seq:
-        :param word_seq_origin_len:
-        :return: head_pred: [B, L]
+        :param seq_lens:
+        :return: arc_pred: [B, L]
                  label_pred: [B, L]
-                 seq_len: [B,]
         """
-        res = self(word_seq, pos_seq, word_seq_origin_len)
+        res = self(word_seq, pos_seq, seq_lens)
         output = {}
-        output['head_pred'] = res.pop('head_pred')
+        output['arc_pred'] = res.pop('head_pred')
         _, label_pred = res.pop('label_pred').max(2)
         output['label_pred'] = label_pred
         return output
+
+
+class ParserLoss(LossFunc):
+    def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
+        super(ParserLoss, self).__init__(BiaffineParser.loss,
+                                         arc_pred=arc_pred,
+                                         label_pred=label_pred,
+                                         arc_true=arc_true,
+                                         label_true=label_true)
+
+
+class ParserMetric(MetricBase):
+    def __init__(self, arc_pred=None, label_pred=None,
+                 arc_true=None, label_true=None, seq_lens=None):
+        super().__init__()
+        self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
+                             arc_true=arc_true, label_true=label_true,
+                             seq_lens=seq_lens)
+        self.num_arc = 0
+        self.num_label = 0
+        self.num_sample = 0
+
+    def get_metric(self, reset=True):
+        res = {'UAS': self.num_arc * 1.0 / self.num_sample, 'LAS': self.num_label * 1.0 / self.num_sample}
+        if reset:
+            self.num_sample = self.num_label = self.num_arc = 0
+        return res
+
+    def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None):
+        """Evaluate the performance of prediction.
+        """
+        if seq_lens is None:
+            seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
+        else:
+            seq_mask = seq_lens_to_masks(seq_lens, float=False).long()
+        head_pred_correct = (arc_pred == arc_true).long() * seq_mask
+        label_pred_correct = (label_pred == label_true).long() * head_pred_correct
+        self.num_arc += head_pred_correct.sum().item()
+        self.num_label += label_pred_correct.sum().item()
+        self.num_sample += seq_mask.sum().item()
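A quick sanity check of the UAS/LAS bookkeeping with hand-made tensors (this only exercises `evaluate` and `get_metric`, not a real fastNLP training run):

```python
import torch

metric = ParserMetric()
arc_true   = torch.tensor([[0, 2, 0, 2]])
arc_pred   = torch.tensor([[0, 2, 0, 1]])   # 3 of 4 heads correct
label_true = torch.tensor([[0, 1, 2, 3]])
label_pred = torch.tensor([[0, 9, 2, 3]])   # head right but label wrong at token 1

metric.evaluate(arc_pred, label_pred, arc_true, label_true,
                seq_lens=torch.tensor([4]))
print(metric.get_metric())  # {'UAS': 0.75, 'LAS': 0.5}
```

Note that a correct label only counts toward LAS when the head is also correct, which is why `label_pred_correct` is multiplied by `head_pred_correct`.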