Skip to content

Commit 13faa2b

Browse files
authored
Merge pull request #132 from FengZiYjun/v0.3.1
fastNLP V0.3.1
2 parents 3fa95b6 + b66d7b8 commit 13faa2b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+3957
-4832
lines changed

codecov.yml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
ignore:
2+
- "reproduction" # ignore folders and all its contents
3+
- "setup.py"
4+
- "docs"
5+
- "tutorials"

docs/source/tutorials/fastnlp_10tmin_tutorial.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
2-
fastNLP上手教程
1+
fastNLP 10分钟上手教程
32
===============
43

4+
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_10min_tutorial.ipynb
5+
56
fastNLP提供方便的数据预处理,训练和测试模型的功能
67

78
DataSet & Instance

docs/source/tutorials/fastnlp_1_minute_tutorial.rst

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
FastNLP 1分钟上手教程
33
=====================
44

5+
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_1min_tutorial.ipynb
6+
57
step 1
68
------
79

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
fastNLP 进阶教程
2+
===============
3+
4+
教程原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb
5+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
fastNLP 开发者指南
2+
===============
3+
4+
原文见 https://github.com/fastnlp/fastNLP/blob/master/tutorials/tutorial_for_developer.md
5+

docs/source/user/installation.rst

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Installation
55
.. contents::
66
:local:
77

8+
Make sure your environment satisfies https://github.com/fastnlp/fastNLP/blob/master/requirements.txt .
89

910
Run the following commands to install fastNLP package:
1011

docs/source/user/quickstart.rst

+2
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ Quickstart
66

77
../tutorials/fastnlp_1_minute_tutorial
88
../tutorials/fastnlp_10tmin_tutorial
9+
../tutorials/fastnlp_advanced_tutorial
10+
../tutorials/fastnlp_developer_guide
911

fastNLP/api/README.md

+11-10
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,27 @@ print(cws.predict(text))
1818
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?']
1919
```
2020

21-
### 中文分词+词性标注
21+
### 词性标注
2222
```python
23-
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
24-
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
25-
'那么这款无人机到底有多厉害?']
23+
# 输入已分词序列
24+
text = [['编者', '按:', '7月', '12日', '', '英国', '航空', '航天', '系统', '公司', '公布', '', '', '公司',
25+
'研制', '', '第一款', '高科技', '隐形', '无人机', '雷电之神', ''],
26+
['那么', '', '', '无人机', '到底', '', '', '厉害', '']]
2627
from fastNLP.api import POS
2728
pos = POS(device='cpu')
2829
print(pos.predict(text))
29-
# [['编者/NN', '按/P', ':/PU', '7月/NT', '12日/NR', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一/OD', '款高/NN', '科技/NN', '隐形/NN', '无/VE', '人机/NN', '雷电/NN', '之/DEG', '神/NN', '。/PU'], ['这/DT', '款/NN', '飞行/VV', '从/P', '外型/NN', '上/LC', '来/MSP', '看/VV', '酷似/VV', '电影/NN', '中/LC', '的/DEG', '太空/NN', '飞行器/NN', ',/PU', '据/P', '英国/NR', '方面/NN', '介绍/VV', ',/PU', '可以/VV', '实现/VV', '洲际/NN', '远程/NN', '打击/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无/VE', '人机/NN', '到底/AD', '有/VE', '多/CD', '厉害/NN', '?/PU']]
30+
# [['编者/NN', '按:/NN', '7月/NT', '12日/NT', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一款/NN', '高科技/NN', '隐形/AD', '无人机/VV', '雷电之神/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无人机/VV', '到底/AD', '有/VE', '多/AD', '厉害/VA', '?/PU']]
3031
```
3132

32-
### 中文分词+词性标注+句法分析
33+
### 句法分析
3334
```python
34-
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
35-
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
36-
'那么这款无人机到底有多厉害?']
35+
text = [['编者', '按:', '7月', '12日', '', '英国', '航空', '航天', '系统', '公司', '公布', '', '', '公司',
36+
'研制', '', '第一款', '高科技', '隐形', '无人机', '雷电之神', ''],
37+
['那么', '', '', '无人机', '到底', '', '', '厉害', '']]
3738
from fastNLP.api import Parser
3839
parser = Parser(device='cpu')
3940
print(parser.predict(text))
40-
# [['12/nsubj', '12/prep', '2/punct', '5/nn', '2/pobj', '12/punct', '11/nn', '11/nn', '11/nn', '11/nn', '2/pobj', '0/root', '12/asp', '15/det', '16/nsubj', '21/rcmod', '16/cpm', '21/nummod', '21/nn', '21/nn', '22/top', '12/ccomp', '24/nn', '26/assmod', '24/assm', '22/dobj', '12/punct'], ['2/det', '8/xsubj', '8/mmod', '8/prep', '6/lobj', '4/plmod', '8/prtmod', '0/root', '8/ccomp', '11/lobj', '14/assmod', '11/assm', '14/nn', '9/dobj', '8/punct', '22/prep', '18/nn', '19/nsubj', '16/pccomp', '22/punct', '22/mmod', '8/dep', '25/nn', '25/nn', '22/dobj', '8/punct'], ['4/advmod', '3/det', '4/nsubj', '0/root', '4/dobj', '7/advmod', '4/conj', '9/nummod', '7/dobj', '4/punct']]
41+
# [['2/nn', '4/nn', '4/nn', '20/tmod', '11/punct', '10/nn', '10/nn', '10/nn', '10/nn', '11/nsubj', '20/dep', '11/asp', '14/det', '15/nsubj', '18/rcmod', '15/cpm', '18/nn', '11/dobj', '20/advmod', '0/root', '20/dobj', '20/punct'], ['4/advmod', '3/det', '8/xsubj', '8/dep', '8/advmod', '8/dep', '8/advmod', '0/root', '8/punct']]
4142
```
4243

4344
完整样例见`examples.py`

fastNLP/api/api.py

+86-43
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,17 @@
99

1010
from fastNLP.api.utils import load_url
1111
from fastNLP.api.processor import ModelProcessor
12-
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
13-
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
14-
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
12+
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader
1513
from fastNLP.core.instance import Instance
1614
from fastNLP.api.pipeline import Pipeline
1715
from fastNLP.core.metrics import SpanFPreRecMetric
1816
from fastNLP.api.processor import IndexerProcessor
1917

2018
# TODO add pretrain urls
2119
model_urls = {
22-
"cws": "http://123.206.98.91:8888/download/cws_crf_1_11-457fc899.pkl",
23-
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190108-f3c60ee5.pkl",
24-
"parser": "http://123.206.98.91:8888/download/biaffine_parser-3a2f052c.pkl"
20+
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl",
21+
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl",
22+
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl"
2523
}
2624

2725

@@ -31,6 +29,16 @@ def __init__(self):
3129
self._dict = None
3230

3331
def predict(self, *args, **kwargs):
32+
"""Do prediction for the given input.
33+
"""
34+
raise NotImplementedError
35+
36+
def test(self, file_path):
37+
"""Test performance over the given data set.
38+
39+
:param str file_path:
40+
:return: a dictionary of metric values
41+
"""
3442
raise NotImplementedError
3543

3644
def load(self, path, device):
@@ -69,12 +77,11 @@ def predict(self, content):
6977
if not hasattr(self, "pipeline"):
7078
raise ValueError("You have to load model first.")
7179

72-
sentence_list = []
80+
sentence_list = content
7381
# 1. 检查sentence的类型
74-
if isinstance(content, str):
75-
sentence_list.append(content)
76-
elif isinstance(content, list):
77-
sentence_list = content
82+
for sentence in sentence_list:
83+
if not all((type(obj) == str for obj in sentence)):
84+
raise ValueError("Input must be list of list of string.")
7885

7986
# 2. 组建dataset
8087
dataset = DataSet()
@@ -83,36 +90,28 @@ def predict(self, content):
8390
# 3. 使用pipeline
8491
self.pipeline(dataset)
8592

86-
def decode_tags(ins):
87-
pred_tags = ins["tag"]
88-
chars = ins["words"]
89-
words = []
90-
start_idx = 0
91-
for idx, tag in enumerate(pred_tags):
92-
if tag[0] == "S":
93-
words.append(chars[start_idx:idx + 1] + "/" + tag[2:])
94-
start_idx = idx + 1
95-
elif tag[0] == "E":
96-
words.append("".join(chars[start_idx:idx + 1]) + "/" + tag[2:])
97-
start_idx = idx + 1
98-
return words
99-
100-
dataset.apply(decode_tags, new_field_name="tag_output")
101-
102-
output = dataset.field_arrays["tag_output"].content
93+
def merge_tag(words_list, tags_list):
94+
rtn = []
95+
for words, tags in zip(words_list, tags_list):
96+
rtn.append([w + "/" + t for w, t in zip(words, tags)])
97+
return rtn
98+
99+
output = dataset.field_arrays["tag"].content
103100
if isinstance(content, str):
104101
return output[0]
105102
elif isinstance(content, list):
106-
return output
103+
return merge_tag(content, output)
107104

108105
def test(self, file_path):
109-
test_data = ZhConllPOSReader().load(file_path)
106+
test_data = ConllxDataLoader().load(file_path)
110107

111-
tag_vocab = self._dict["tag_vocab"]
112-
pipeline = self._dict["pipeline"]
108+
save_dict = self._dict
109+
tag_vocab = save_dict["tag_vocab"]
110+
pipeline = save_dict["pipeline"]
113111
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False)
114112
pipeline.pipeline = [index_tag] + pipeline.pipeline
115113

114+
test_data.rename_field("pos_tags", "tag")
116115
pipeline(test_data)
117116
test_data.set_target("truth")
118117
prediction = test_data.field_arrays["predict"].content
@@ -226,7 +225,7 @@ def test(self, filepath):
226225
rec = eval_res['BMESF1PreRecMetric']['rec']
227226
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
228227

229-
return f1, pre, rec
228+
return {"F1": f1, "precision": pre, "recall": rec}
230229

231230

232231
class Parser(API):
@@ -251,6 +250,7 @@ def predict(self, content):
251250
dataset.add_field('wp', pos_out)
252251
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words')
253252
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos')
253+
dataset.rename_field("words", "raw_words")
254254

255255
# 3. 使用pipeline
256256
self.pipeline(dataset)
@@ -260,39 +260,82 @@ def predict(self, content):
260260
# output like: [['2/top', '0/root', '4/nn', '2/dep']]
261261
return dataset.field_arrays['output'].content
262262

263-
def test(self, filepath):
264-
data = ConllxDataLoader().load(filepath)
265-
ds = DataSet()
266-
for ins1, ins2 in zip(add_seg_tag(data), data):
267-
ds.append(Instance(words=ins1[0], tag=ins1[1],
268-
gold_words=ins2[0], gold_pos=ins2[1],
269-
gold_heads=ins2[2], gold_head_tags=ins2[3]))
263+
def load_test_file(self, path):
264+
def get_one(sample):
265+
sample = list(map(list, zip(*sample)))
266+
if len(sample) == 0:
267+
return None
268+
for w in sample[7]:
269+
if w == '_':
270+
print('Error Sample {}'.format(sample))
271+
return None
272+
# return word_seq, pos_seq, head_seq, head_tag_seq
273+
return sample[1], sample[3], list(map(int, sample[6])), sample[7]
274+
275+
datalist = []
276+
with open(path, 'r', encoding='utf-8') as f:
277+
sample = []
278+
for line in f:
279+
if line.startswith('\n'):
280+
datalist.append(sample)
281+
sample = []
282+
elif line.startswith('#'):
283+
continue
284+
else:
285+
sample.append(line.split('\t'))
286+
if len(sample) > 0:
287+
datalist.append(sample)
288+
289+
data = [get_one(sample) for sample in datalist]
290+
data_list = list(filter(lambda x: x is not None, data))
291+
return data_list
270292

293+
def test(self, filepath):
294+
data = self.load_test_file(filepath)
295+
296+
def convert(data):
297+
BOS = '<BOS>'
298+
dataset = DataSet()
299+
for sample in data:
300+
word_seq = [BOS] + sample[0]
301+
pos_seq = [BOS] + sample[1]
302+
heads = [0] + sample[2]
303+
head_tags = [BOS] + sample[3]
304+
dataset.append(Instance(raw_words=word_seq,
305+
pos=pos_seq,
306+
gold_heads=heads,
307+
arc_true=heads,
308+
tags=head_tags))
309+
return dataset
310+
311+
ds = convert(data)
271312
pp = self.pipeline
272313
for p in pp:
273314
if p.field_name == 'word_list':
274315
p.field_name = 'gold_words'
275316
elif p.field_name == 'pos_list':
276317
p.field_name = 'gold_pos'
318+
# ds.rename_field("words", "raw_words")
319+
# ds.rename_field("tag", "pos")
277320
pp(ds)
278321
head_cor, label_cor, total = 0, 0, 0
279322
for ins in ds:
280323
head_gold = ins['gold_heads']
281-
head_pred = ins['heads']
324+
head_pred = ins['arc_pred']
282325
length = len(head_gold)
283326
total += length
284327
for i in range(length):
285328
head_cor += 1 if head_pred[i] == head_gold[i] else 0
286329
uas = head_cor / total
287-
print('uas:{:.2f}'.format(uas))
330+
# print('uas:{:.2f}'.format(uas))
288331

289332
for p in pp:
290333
if p.field_name == 'gold_words':
291334
p.field_name = 'word_list'
292335
elif p.field_name == 'gold_pos':
293336
p.field_name = 'pos_list'
294337

295-
return uas
338+
return {"USA": round(uas, 5)}
296339

297340

298341
class Analyzer:

fastNLP/api/examples.py

+27
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,42 @@ def chinese_word_segmentation():
1515
print(cws.predict(text))
1616

1717

18+
def chinese_word_segmentation_test():
19+
cws = CWS(device='cpu')
20+
print(cws.test("../../test/data_for_tests/zh_sample.conllx"))
21+
22+
1823
def pos_tagging():
24+
# 输入已分词序列
25+
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司',
26+
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
27+
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']]
1928
pos = POS(device='cpu')
2029
print(pos.predict(text))
2130

2231

32+
def pos_tagging_test():
33+
pos = POS(device='cpu')
34+
print(pos.test("../../test/data_for_tests/zh_sample.conllx"))
35+
36+
2337
def syntactic_parsing():
38+
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司',
39+
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'],
40+
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']]
2441
parser = Parser(device='cpu')
2542
print(parser.predict(text))
2643

2744

45+
def syntactic_parsing_test():
46+
parser = Parser(device='cpu')
47+
print(parser.test("../../test/data_for_tests/zh_sample.conllx"))
48+
49+
2850
if __name__ == "__main__":
51+
# chinese_word_segmentation()
52+
# chinese_word_segmentation_test()
53+
# pos_tagging()
54+
# pos_tagging_test()
2955
syntactic_parsing()
56+
# syntactic_parsing_test()

0 commit comments

Comments
 (0)