Commit 568dc66

Merge pull request #241 from choosewhatulike/dev0.5.0
[add] reproduction of multi-criteria cws
2 parents 16e9e3a + d5bea4a commit 568dc66

File tree: 12 files changed, +1882 -0 lines changed

@@ -0,0 +1,61 @@
# Multi-Criteria-CWS

An implementation of [Multi-Criteria Chinese Word Segmentation with Transformer](http://arxiv.org/abs/1906.12035) with fastNLP.

## Dataset

### Overview

We use the same datasets as listed in the paper:

- sighan2005
  - pku
  - msr
  - as
  - cityu
- sighan2008
  - ctb
  - ckip
  - cityu (combined with the sighan2005 data)
  - ncc
  - sxu
### Preprocess

First, install OpenCC to convert between Traditional Chinese and Simplified Chinese:

```shell
pip install opencc-python-reimplemented
```

Then set a path for the processed data and run the shell script to process the datasets:

```shell
export DATA_DIR=path/to/processed-data
bash make_data.sh path/to/sighan2005 path/to/sighan2008
```

The whole process takes a few minutes to finish.
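The script also produces character-level BMES files for each dataset (see `bmes_tag` in the data-preparation code below): every character gets one of the tags B/M/E/S, with single-character words tagged S. The snippet below is only an illustration of that tagging scheme, not part of the repository's code:

```python
# Illustrative only: assign B/M/E/S labels to the characters of a
# segmented sentence, mirroring the bmes_tag logic in the data script.
def bmes_labels(words):
    labels = []
    for word in words:
        if len(word) == 1:
            labels.append((word, "S"))
        else:
            labels.append((word[0], "B"))
            labels.extend((ch, "M") for ch in word[1:-1])
            labels.append((word[-1], "E"))
    return labels


print(bmes_labels(["我", "喜欢", "自然语言"]))
# [('我', 'S'), ('喜', 'B'), ('欢', 'E'), ('自', 'B'), ('然', 'M'), ('语', 'M'), ('言', 'E')]
```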
## Model

We use a Transformer to build the model, as described in the paper.
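The actual model is implemented in the repository; as a rough, non-authoritative sketch of the idea (a Transformer encoder over character embeddings with a per-character BMES classifier), here is a minimal PyTorch example with made-up hyper-parameters:

```python
# Minimal sketch (NOT the repository's model): a Transformer encoder that
# maps character ids to per-character BMES tag scores. Positional encodings
# and the criterion/dataset token handling are omitted for brevity.
import torch
import torch.nn as nn


class TransformerSegmenter(nn.Module):
    def __init__(self, vocab_size, num_tags=4, d_model=256, nhead=4, num_layers=6):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_tags)  # scores for B/M/E/S

    def forward(self, chars, padding_mask=None):
        # chars: (batch, seq_len) character ids
        x = self.embed(chars).transpose(0, 1)            # (seq_len, batch, d_model)
        h = self.encoder(x, src_key_padding_mask=padding_mask)
        return self.classifier(h).transpose(0, 1)        # (batch, seq_len, num_tags)


model = TransformerSegmenter(vocab_size=5000)
scores = model(torch.randint(0, 5000, (2, 20)))          # -> shape (2, 20, 4)
```

Refer to the paper and the repository code for the actual architecture, which includes details this sketch leaves out.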
## Train

Finally, to train the model, run the shell script `train.sh`.

It takes one argument, the IDs of the GPUs to use, for example:

```shell
bash train.sh 0,1
```

This command uses the GPUs with IDs 0 and 1.

Note: please refer to the paper for the hyper-parameter details, and modify the settings in `train.sh` to match your experiment environment.

Type

```shell
python main.py --help
```

to see all the arguments that can be specified for training.
## Performance

Results on the test sets of eight CWS datasets with multi-criteria learning.

| Model          | MSRA  | AS    | PKU   | CTB   | CKIP  | CITYU | NCC   | SXU   | Avg.  |
| -------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| Original paper | 98.05 | 96.44 | 96.41 | 96.99 | 96.51 | 96.91 | 96.04 | 97.61 | 96.87 |
| Ours           | 96.92 | 95.71 | 95.65 | 95.96 | 96.00 | 96.09 | 94.61 | 96.64 | 95.95 |
@@ -0,0 +1,262 @@
import os
import re
import argparse

from opencc import OpenCC

from utils import make_sure_path_exists, append_tags

# OpenCC converter: Traditional Chinese -> Simplified Chinese
cc = OpenCC("t2s")

sighan05_root = ""
sighan08_root = ""
data_path = ""

# Map full-width (Chinese) punctuation to half-width (ASCII) punctuation.
E_pun = ",.!?[]()<>\"\"'',"
C_pun = ",。!?【】()《》“”‘’、"
Table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)}
Table[12288] = 32  # full-width space -> ASCII space


def C_trans_to_E(string):
    return string.translate(Table)


def normalize(ustring):
    """Convert full-width characters to half-width."""
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:  # full-width space converts directly
            inside_code = 32
        elif 65281 <= inside_code <= 65374:  # other full-width chars shift by a fixed offset
            inside_code -= 65248

        rstring += chr(inside_code)
    return rstring


def preprocess(text):
    # Replace numbers with "0" and Latin-script tokens with "X".
    rNUM = r"(-|\+)?\d+((\.|·)\d+)?%?"
    rENG = r"[A-Za-z_]+.*"
    sent = normalize(C_trans_to_E(text.strip())).split()
    new_sent = []
    for word in sent:
        word = re.sub(r"\s+", "", word, flags=re.U)
        word = re.sub(rNUM, "0", word, flags=re.U)
        word = re.sub(rENG, "X", word)
        new_sent.append(word)
    return new_sent


def to_sentence_list(text, split_long_sentence=False):
    # Split a line into sentences at punctuation delimiters; optionally also
    # split once a sentence reaches 50 characters.
    text = preprocess(text)
    delimiter = set()
    delimiter.update("。!?:;…、,(),;!?、,\"'")
    delimiter.add("……")
    sent_list = []
    sent = []
    sent_len = 0
    for word in text:
        sent.append(word)
        sent_len += len(word)
        if word in delimiter or (split_long_sentence and sent_len >= 50):
            sent_list.append(sent)
            sent = []
            sent_len = 0

    if len(sent) > 0:
        sent_list.append(sent)

    return sent_list


def is_traditional(dataset):
    return dataset in ["as", "cityu", "ckip"]


def convert_file(
    src, des, need_cc=False, split_long_sentence=False, encode="utf-8-sig"
):
    with open(src, encoding=encode) as src, open(des, "w", encoding="utf-8") as des:
        for line in src:
            for sent in to_sentence_list(line, split_long_sentence):
                line = " ".join(sent) + "\n"
                if need_cc:
                    line = cc.convert(line)
                des.write(line)
                # if len(''.join(sent)) > 200:
                #     print(' '.join(sent))


def split_train_dev(dataset):
    # Hold out the last 10% of train-all.txt as the dev set.
    root = data_path + "/" + dataset + "/raw/"
    with open(root + "train-all.txt", encoding="UTF-8") as src, open(
        root + "train.txt", "w", encoding="UTF-8"
    ) as train, open(root + "dev.txt", "w", encoding="UTF-8") as dev:
        lines = src.readlines()
        idx = int(len(lines) * 0.9)
        for line in lines[:idx]:
            train.write(line)
        for line in lines[idx:]:
            dev.write(line)


def combine_files(one, two, out):
    if os.path.exists(out):
        os.remove(out)
    with open(one, encoding="utf-8") as one, open(two, encoding="utf-8") as two, open(
        out, "a", encoding="utf-8"
    ) as out:
        for line in one:
            out.write(line)
        for line in two:
            out.write(line)


def bmes_tag(input_file, output_file):
    # Convert space-segmented sentences into character-level B/M/E/S tags.
    # Tokens wrapped in angle brackets (e.g. dataset tags added by append_tags)
    # are kept whole and written with tag S.
    with open(input_file, encoding="utf-8") as input_data, open(
        output_file, "w", encoding="utf-8"
    ) as output_data:
        for line in input_data:
            word_list = line.strip().split()
            for word in word_list:
                if len(word) == 1 or (
                    len(word) > 2 and word[0] == "<" and word[-1] == ">"
                ):
                    output_data.write(word + "\tS\n")
                else:
                    output_data.write(word[0] + "\tB\n")
                    for w in word[1 : len(word) - 1]:
                        output_data.write(w + "\tM\n")
                    output_data.write(word[len(word) - 1] + "\tE\n")
            output_data.write("\n")


def make_bmes(dataset="pku"):
    path = data_path + "/" + dataset + "/"
    make_sure_path_exists(path + "bmes")
    bmes_tag(path + "raw/train.txt", path + "bmes/train.txt")
    bmes_tag(path + "raw/train-all.txt", path + "bmes/train-all.txt")
    bmes_tag(path + "raw/dev.txt", path + "bmes/dev.txt")
    bmes_tag(path + "raw/test.txt", path + "bmes/test.txt")


def convert_sighan2005_dataset(dataset):
    global sighan05_root
    root = os.path.join(data_path, dataset)
    make_sure_path_exists(root)
    make_sure_path_exists(root + "/raw")
    file_path = "{}/{}_training.utf8".format(sighan05_root, dataset)
    convert_file(
        file_path, "{}/raw/train-all.txt".format(root), is_traditional(dataset), True
    )
    if dataset == "as":
        file_path = "{}/{}_testing_gold.utf8".format(sighan05_root, dataset)
    else:
        file_path = "{}/{}_test_gold.utf8".format(sighan05_root, dataset)
    convert_file(
        file_path, "{}/raw/test.txt".format(root), is_traditional(dataset), False
    )
    split_train_dev(dataset)


def convert_sighan2008_dataset(dataset, utf=16):
    global sighan08_root
    root = os.path.join(data_path, dataset)
    make_sure_path_exists(root)
    make_sure_path_exists(root + "/raw")
    convert_file(
        "{}/{}_train_utf{}.seg".format(sighan08_root, dataset, utf),
        "{}/raw/train-all.txt".format(root),
        is_traditional(dataset),
        True,
        "utf-{}".format(utf),
    )
    convert_file(
        "{}/{}_seg_truth&resource/{}_truth_utf{}.seg".format(
            sighan08_root, dataset, dataset, utf
        ),
        "{}/raw/test.txt".format(root),
        is_traditional(dataset),
        False,
        "utf-{}".format(utf),
    )
    split_train_dev(dataset)


def extract_conll(src, out):
    # Extract the word column from a CoNLL-style file, one sentence per line.
    words = []
    with open(src, encoding="utf-8") as src, open(out, "w", encoding="utf-8") as out:
        for line in src:
            line = line.strip()
            if len(line) == 0:
                out.write(" ".join(words) + "\n")
                words = []
                continue
            cells = line.split()
            words.append(cells[1])


def make_joint_corpus(datasets, joint):
    # Merge the per-dataset files into one joint corpus, adding each dataset's
    # tag via append_tags (imported from utils).
    parts = ["dev", "test", "train", "train-all"]
    for part in parts:
        old_file = "{}/{}/raw/{}.txt".format(data_path, joint, part)
        if os.path.exists(old_file):
            os.remove(old_file)
        elif not os.path.exists(os.path.dirname(old_file)):
            os.makedirs(os.path.dirname(old_file))
        for name in datasets:
            append_tags(
                os.path.join(data_path, name, "raw"),
                os.path.dirname(old_file),
                name,
                part,
                encode="utf-8",
            )


def convert_all_sighan2005(datasets):
    for dataset in datasets:
        print("Converting sighan bakeoff 2005 corpus: {}".format(dataset))
        convert_sighan2005_dataset(dataset)
        make_bmes(dataset)


def convert_all_sighan2008(datasets):
    for dataset in datasets:
        print("Converting sighan bakeoff 2008 corpus: {}".format(dataset))
        convert_sighan2008_dataset(dataset, 16)
        make_bmes(dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("--sighan05", required=True, type=str, help="path to sighan2005 dataset")
    parser.add_argument("--sighan08", required=True, type=str, help="path to sighan2008 dataset")
    parser.add_argument("--data_path", required=True, type=str, help="path to save dataset")
    # fmt: on

    args, _ = parser.parse_known_args()
    sighan05_root = args.sighan05
    sighan08_root = args.sighan08
    data_path = args.data_path

    print("Converting sighan2005 Simplified Chinese corpus")
    datasets = "pku", "msr", "as", "cityu"
    convert_all_sighan2005(datasets)

    print("Combining sighan2005 corpus to one joint Simplified Chinese corpus")
    datasets = "pku", "msr", "as", "cityu"
    make_joint_corpus(datasets, "joint-sighan2005")
    make_bmes("joint-sighan2005")

    # Researchers who have access to the sighan2008 corpus should use the official corpora.
    print("Converting sighan2008 Simplified Chinese corpus")
    datasets = "ctb", "ckip", "cityu", "ncc", "sxu"
    convert_all_sighan2008(datasets)
    print("Combining those 8 sighan corpora to one joint corpus")
    datasets = "pku", "msr", "as", "ctb", "ckip", "cityu", "ncc", "sxu"
    make_joint_corpus(datasets, "joint-sighan2008")
    make_bmes("joint-sighan2008")
