Skip to content

Commit 8781329

Browse files
committed
fix comment
1 parent 89a776a commit 8781329

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

mmeval/metrics/one_minus_norm_edit_distance.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric):
2121
2222
- unchanged: Do not change prediction texts and labels.
2323
- upper: Convert prediction texts and labels into uppercase
24-
characters.
24+
characters.
2525
- lower: Convert prediction texts and labels into lowercase
26-
characters.
26+
characters.
2727
2828
Usually, it only works for English characters. Defaults to
2929
'unchanged'.
3030
invalid_symbol (str): A regular expression to filter out invalid or
31-
not cared characters. Defaults to '[^A-Z^a-z^0-9^\u4e00-\u9fa5]'.
31+
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'.
3232
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
3333
34-
Example:
34+
Examples:
3535
>>> from mmeval import OneMinusNormEditDistance
3636
>>> metric = OneMinusNormEditDistance()
3737
>>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
@@ -43,22 +43,22 @@ class OneMinusNormEditDistance(BaseMetric):
4343

4444
def __init__(self,
4545
letter_case: str = 'unchanged',
46-
invalid_symbol: str = '[^A-Z^a-z^0-9^\u4e00-\u9fa5]',
46+
invalid_symbol: str = '[^A-Za-z0-9\u4e00-\u9fa5]',
4747
**kwargs):
4848
super().__init__(**kwargs)
4949

5050
assert letter_case in ['unchanged', 'upper', 'lower']
5151
self.letter_case = letter_case
5252
self.invalid_symbol = re.compile(invalid_symbol)
5353

54-
def add(self, predictions: Sequence[str], labels: Sequence[str]): # type: ignore # yapf: disable # noqa: E501
54+
def add(self, predictions: Sequence[str], groundtruths: Sequence[str]): # type: ignore # yapf: disable # noqa: E501
5555
"""Process one batch of data and predictions.
5656
5757
Args:
5858
predictions (list[str]): The prediction texts.
59-
labels (list[str]): The ground truth texts.
59+
groundtruths (list[str]): The ground truth texts.
6060
"""
61-
for pred, label in zip(predictions, labels):
61+
for pred, label in zip(predictions, groundtruths):
6262
if self.letter_case in ['upper', 'lower']:
6363
pred = getattr(pred, self.letter_case)()
6464
label = getattr(label, self.letter_case)()
@@ -75,11 +75,12 @@ def compute_metric(self, results: List[float]) -> Dict:
7575
7676
Returns:
7777
dict[str, float]: Nested dicts as results.
78-
- 1-N.E.D (float): One minus the normalized edit distance.
78+
79+
- 1-N.E.D (float): One minus the normalized edit distance.
7980
"""
8081
gt_word_num = len(results)
8182
norm_ed_sum = sum(results)
8283
normalized_edit_distance = norm_ed_sum / max(1.0, gt_word_num)
83-
eval_res = {}
84-
eval_res['1-N.E.D'] = 1.0 - normalized_edit_distance
85-
return eval_res
84+
metric_results = {}
85+
metric_results['1-N.E.D'] = 1.0 - normalized_edit_distance
86+
return metric_results

0 commit comments

Comments
 (0)