Skip to content

Commit 5800e02

Browse files
committed
Change the MWT dictionary lookup to only look for lowercasing if the original word matches one of a couple expected casing formats, in which case we can recreate those formats after using the dictionary lookup. Otherwise, you get unexpected tokenizations such as She's -> she 's. #1371
1 parent a15b981 commit 5800e02

File tree

2 files changed

+79
-11
lines changed

2 files changed

+79
-11
lines changed

stanza/models/mwt/trainer.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -111,29 +111,51 @@ def train_dict(self, pairs):
111111
seen.add(w)
112112
return
113113

114+
def dict_expansion(self, word):
    """
    Check the expansion dictionary for the word along with a couple common casings of the word

    (Leadingcase and UPPERCASE)

    The word is looked up exactly as given first.  The lowercased
    word is only looked up if the original is UPPERCASE or
    Leadingcase, since in those cases the original casing can be
    reapplied to the expansion.  Otherwise a lowercase dictionary
    entry would give unexpected results such as She's -> she 's

    Returns the expansion, or None if no usable entry was found
    """
    expansion = self.expansion_dict.get(word)
    if expansion is not None:
        return expansion

    if not word:
        # there is no casing to recover from an empty word,
        # and word[0] below would raise an IndexError
        return None

    if word.isupper():
        expansion = self.expansion_dict.get(word.lower())
        if expansion is not None:
            return expansion.upper()

    if word[0].isupper() and word[1:].islower():
        expansion = self.expansion_dict.get(word.lower())
        if expansion is not None:
            # slicing instead of indexing so that an empty expansion
            # is returned as is rather than raising an IndexError
            return expansion[:1].upper() + expansion[1:]

    # could build a truecasing model of some kind to handle cRaZyCaSe...
    # but that's probably too much effort
    return None
137+
114138
def predict_dict(self, words):
115139
""" Predict a list of expansions given words. """
116140
expansions = []
117141
for w in words:
118-
if w in self.expansion_dict:
119-
expansions += [self.expansion_dict[w]]
120-
elif w.lower() in self.expansion_dict:
121-
expansions += [self.expansion_dict[w.lower()]]
142+
expansion = self.dict_expansion(w)
143+
if expansion is not None:
144+
expansions.append(expansion)
122145
else:
123-
expansions += [w]
146+
expansions.append(w)
124147
return expansions
125148

126149
def ensemble(self, cands, other_preds):
127150
""" Ensemble the dict with statistical model predictions. """
128151
expansions = []
129152
assert len(cands) == len(other_preds)
130153
for c, pred in zip(cands, other_preds):
131-
if c in self.expansion_dict:
132-
expansions += [self.expansion_dict[c]]
133-
elif c.lower() in self.expansion_dict:
134-
expansions += [self.expansion_dict[c.lower()]]
154+
expansion = self.dict_expansion(c)
155+
if expansion is not None:
156+
expansions.append(expansion)
135157
else:
136-
expansions += [pred]
158+
expansions.append(pred)
137159
return expansions
138160

139161
def save(self, filename):

stanza/tests/mwt/test_english_corner_cases.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""
2-
Test that an unknown English character doesn't result in bizarre splits
2+
Test a couple English MWT corner cases which might be more widely applicable to other MWT languages
3+
4+
- unknown English character doesn't result in bizarre splits
5+
- Casing or CASING doesn't get lost in the dictionary lookup
36
47
In the English UD datasets, the MWT are composed exactly of the
58
subwords, so the MWT model should be chopping up the input text rather
69
than generating new text.
10+
11+
Furthermore, SHE'S and She's should be split "SHE 'S" and "She 's" respectively
712
"""
813

914
import pytest
@@ -40,3 +45,44 @@ def test_mwt_unknown_char():
4045
assert doc.sentences[0].tokens[3].text == possessive
4146
assert len(doc.sentences[0].tokens[3].words) == 2
4247
assert "".join(x.text for x in doc.sentences[0].tokens[3].words) == possessive
48+
49+
50+
def test_english_mwt_casing():
    """
    Check that the casing of a word survives the MWT dictionary lookup

    The MWT expander used to fall back to the lowercased dictionary
    entry for a word, producing splits such as
      SHE'S -> she 's

    That is a surprising tokenization for anyone who expects the
    output document to contain the original text
    """
    pipeline = stanza.Pipeline(processors='tokenize,mwt', dir=TEST_MODELS_DIR, lang='en', download_method=None)
    mwt_trainer = pipeline.processors['mwt']._trainer

    def word_texts(text):
        # the text of each word in the first sentence of the processed text
        return [word.text for word in pipeline(text).sentences[0].words]

    # possessives of names are a common pattern in the test cases,
    # so search for a variant which has not found its way
    # into the MWT dictionary in either casing
    unknown_name = None
    for repeats in range(1, 20):
        candidate = "jennife" + "r" * repeats + "'s"
        if candidate in mwt_trainer.expansion_dict or candidate.upper() in mwt_trainer.expansion_dict:
            continue
        unknown_name = candidate.upper()
        break
    if unknown_name is None:
        raise AssertionError("Need a new heuristic for the unknown word in the English MWT!")

    # the lowercase form is the entry the casing logic needs to recase
    assert "she's" in mwt_trainer.expansion_dict, "Expected |she's| to be in the English MWT expansion dict... perhaps find a different test case"

    # no MWT at all
    assert word_texts("JENNIFER HAS NICE ANTENNAE") == ['JENNIFER', 'HAS', 'NICE', 'ANTENNAE']

    # an MWT which is not in the dictionary
    assert word_texts(unknown_name + " GOT NICE ANTENNAE") == [unknown_name[:-2], "'S", 'GOT', 'NICE', 'ANTENNAE']

    # UPPERCASE and Leadingcase versions of a dictionary entry
    assert word_texts("SHE'S GOT NICE ANTENNAE") == ['SHE', "'S", 'GOT', 'NICE', 'ANTENNAE']
    assert word_texts("She's GOT NICE ANTENNAE") == ['She', "'s", 'GOT', 'NICE', 'ANTENNAE']
88+

0 commit comments

Comments
 (0)