@@ -9,19 +9,17 @@
 
 from fastNLP.api.utils import load_url
 from fastNLP.api.processor import ModelProcessor
-from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
-from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
-from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
+from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.core.metrics import SpanFPreRecMetric
 from fastNLP.api.processor import IndexerProcessor
 
 # TODO add pretrain urls
 model_urls = {
-    "cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
-    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
-    "parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
+    "cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
+    "pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
+    "parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
 }
 
 
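Note: these checkpoints back the high-level pipelines. A minimal usage sketch, assuming the fastNLP 0.x convention that an API object constructed without a local model path downloads its checkpoint from `model_urls` (assumed behavior, mirroring torch.utils.model_zoo-style load_url):

    from fastNLP.api import POS

    # No local path given, so the "pos" checkpoint above is fetched and cached
    # on first use (assumed behavior).
    pos = POS(device='cpu')
    print(pos.predict([['编者', '按', '：', '7月', '12日']]))
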
@@ -31,6 +29,16 @@ def __init__(self):
         self._dict = None
 
     def predict(self, *args, **kwargs):
+        """Do prediction for the given input.
+        """
+        raise NotImplementedError
+
+    def test(self, file_path):
+        """Test performance over the given data set.
+
+        :param str file_path: path to the test data set
+        :return: a dictionary of metric values
+        """
         raise NotImplementedError
 
     def load(self, path, device):
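Note: the base class now pins down one contract for CWS, POS, and Parser alike: `predict` for inference, `test(file_path)` returning a dict of metric values. A minimal sketch of a conforming subclass (`DummyAPI` and its placeholder tag scheme are invented for illustration, not part of this commit):

    class DummyAPI(API):
        """Illustrative subclass only."""

        def predict(self, content):
            # one placeholder tag per input token
            return [[w + "/x" for w in sent] for sent in content]

        def test(self, file_path):
            # conforming implementations return a dict of metric values
            return {"F1": 0.0, "precision": 0.0, "recall": 0.0}
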
@@ -69,12 +77,11 @@ def predict(self, content):
         if not hasattr(self, "pipeline"):
             raise ValueError("You have to load model first.")
 
-        sentence_list = []
+        sentence_list = content
         # 1. check the type of the input sentences
-        if isinstance(content, str):
-            sentence_list.append(content)
-        elif isinstance(content, list):
-            sentence_list = content
+        for sentence in sentence_list:
+            if not all(isinstance(obj, str) for obj in sentence):
+                raise ValueError("Input must be list of list of string.")
 
         # 2. build the dataset
         dataset = DataSet()
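Note the behavioral change here: `predict` no longer wraps a bare string for you. It now expects pre-segmented input, a list of token lists, and rejects anything else. Illustrative calls (tokens invented):

    pos = POS(device='cpu')

    pos.predict([['这里', '是', '复旦大学']])  # OK: list of list of str
    pos.predict([[1, 2, 3]])                   # ValueError: Input must be list of list of string.
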
@@ -83,36 +90,28 @@ def predict(self, content):
         # 3. run the pipeline
         self.pipeline(dataset)
 
-        def decode_tags(ins):
-            pred_tags = ins["tag"]
-            chars = ins["words"]
-            words = []
-            start_idx = 0
-            for idx, tag in enumerate(pred_tags):
-                if tag[0] == "S":
-                    words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
-                    start_idx = idx + 1
-                elif tag[0] == "E":
-                    words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
-                    start_idx = idx + 1
-            return words
-
-        dataset.apply(decode_tags, new_field_name="tag_output")
-
-        output = dataset.field_arrays["tag_output"].content
+        def merge_tag(words_list, tags_list):
+            rtn = []
+            for words, tags in zip(words_list, tags_list):
+                rtn.append([w + "/" + t for w, t in zip(words, tags)])
+            return rtn
+
+        output = dataset.field_arrays["tag"].content
         if isinstance(content, str):
             return output[0]
         elif isinstance(content, list):
-            return output
+            return merge_tag(content, output)
 
     def test(self, file_path):
-        test_data = ZhConllPOSReader().load(file_path)
+        test_data = ConllxDataLoader().load(file_path)
 
-        tag_vocab = self._dict["tag_vocab"]
-        pipeline = self._dict["pipeline"]
+        save_dict = self._dict
+        tag_vocab = save_dict["tag_vocab"]
+        pipeline = save_dict["pipeline"]
         index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
         pipeline.pipeline = [index_tag] + pipeline.pipeline
 
+        test_data.rename_field("pos_tags", "tag")
         pipeline(test_data)
         test_data.set_target("truth")
         prediction = test_data.field_arrays["predict"].content
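Note: `merge_tag` simply zips each input token with its predicted tag, so the output mirrors the input's nesting. A small trace (the tags are made-up placeholders, not real model output):

    words_list = [['编者', '按'], ['独坐', '幽篁', '里']]
    tags_list = [['NN', 'NN'], ['VV', 'NN', 'LC']]
    # merge_tag(words_list, tags_list)
    # -> [['编者/NN', '按/NN'], ['独坐/VV', '幽篁/NN', '里/LC']]
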
@@ -226,7 +225,7 @@ def test(self, filepath):
         rec = eval_res['BMESF1PreRecMetric']['rec']
         # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
 
-        return f1, pre, rec
+        return {"F1": f1, "precision": pre, "recall": rec}
 
 
 class Parser(API):
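Note: with this commit every `test` implementation hands back a dict rather than a bare tuple or float, matching the base-class docstring above. Hedged usage (the file path and metric numbers are placeholders):

    cws = CWS(device='cpu')
    metrics = cws.test("cws_test.conll")
    # e.g. {'F1': 0.93, 'precision': 0.94, 'recall': 0.92}
    print(metrics["F1"])
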
@@ -251,6 +250,7 @@ def predict(self, content):
         dataset.add_field('wp', pos_out)
         dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
         dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
+        dataset.rename_field("words", "raw_words")
 
         # 3. run the pipeline
         self.pipeline(dataset)
@@ -260,39 +260,82 @@ def predict(self, content):
         # output like: [['2/top', '0/root', '4/nn', '2/dep']]
         return dataset.field_arrays['output'].content
 
-    def test(self, filepath):
-        data = ConllxDataLoader().load(filepath)
-        ds = DataSet()
-        for ins1, ins2 in zip(add_seg_tag(data), data):
-            ds.append(Instance(words=ins1[0], tag=ins1[1],
-                               gold_words=ins2[0], gold_pos=ins2[1],
-                               gold_heads=ins2[2], gold_head_tags=ins2[3]))
+    def load_test_file(self, path):
+        def get_one(sample):
+            # transpose the token rows into per-column lists
+            sample = list(map(list, zip(*sample)))
+            if len(sample) == 0:
+                return None
+            for w in sample[7]:
+                if w == '_':
+                    print('Error Sample {}'.format(sample))
+                    return None
+            # return word_seq, pos_seq, head_seq, head_tag_seq
+            return sample[1], sample[3], list(map(int, sample[6])), sample[7]
+
+        datalist = []
+        with open(path, 'r', encoding='utf-8') as f:
+            sample = []
+            for line in f:
+                if line.startswith('\n'):
+                    datalist.append(sample)
+                    sample = []
+                elif line.startswith('#'):
+                    continue
+                else:
+                    sample.append(line.split('\t'))
+            if len(sample) > 0:
+                datalist.append(sample)
+
+        data = [get_one(sample) for sample in datalist]
+        data_list = list(filter(lambda x: x is not None, data))
+        return data_list
 
+    def test(self, filepath):
+        data = self.load_test_file(filepath)
+
+        def convert(data):
+            BOS = '<BOS>'
+            dataset = DataSet()
+            for sample in data:
+                word_seq = [BOS] + sample[0]
+                pos_seq = [BOS] + sample[1]
+                heads = [0] + sample[2]
+                head_tags = [BOS] + sample[3]
+                dataset.append(Instance(raw_words=word_seq,
+                                        pos=pos_seq,
+                                        gold_heads=heads,
+                                        arc_true=heads,
+                                        tags=head_tags))
+            return dataset
+
+        ds = convert(data)
         pp = self.pipeline
         for p in pp:
             if p.field_name == 'word_list':
                 p.field_name = 'gold_words'
             elif p.field_name == 'pos_list':
                 p.field_name = 'gold_pos'
+        # ds.rename_field("words", "raw_words")
+        # ds.rename_field("tag", "pos")
         pp(ds)
         head_cor, label_cor, total = 0, 0, 0
         for ins in ds:
             head_gold = ins['gold_heads']
-            head_pred = ins['heads']
+            head_pred = ins['arc_pred']
             length = len(head_gold)
             total += length
             for i in range(length):
                 head_cor += 1 if head_pred[i] == head_gold[i] else 0
         uas = head_cor / total
-        print('uas:{:.2f}'.format(uas))
+        # print('uas:{:.2f}'.format(uas))
 
         for p in pp:
             if p.field_name == 'gold_words':
                 p.field_name = 'word_list'
             elif p.field_name == 'gold_pos':
                 p.field_name = 'pos_list'
 
-        return uas
+        return {"UAS": round(uas, 5)}
 
 
 class Analyzer:
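Note: to make the new `load_test_file` concrete, it reads blank-line-separated CoNLL-X blocks, skips `#` comment lines, transposes each block with `zip(*sample)`, and keeps columns 1 (form), 3 (coarse POS), 6 (head index), and 7 (dependency relation). A sketch of one two-token block and what `get_one` extracts from it (the sentence is invented):

    rows = [
        "1\t他\t_\tr\tr\t_\t2\tnsubj\t_\t_".split('\t'),
        "2\t来\t_\tv\tv\t_\t0\troot\t_\t_".split('\t'),
    ]
    cols = list(map(list, zip(*rows)))
    # cols[1] -> ['他', '来']              word_seq
    # cols[3] -> ['r', 'v']                pos_seq
    # cols[6] -> ['2', '0'] -> [2, 0]      head_seq (cast to int)
    # cols[7] -> ['nsubj', 'root']         head_tag_seq
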