Skip to content

Commit 0c40a5c

Browse files
committed
fix comment
1 parent ef56e48 commit 0c40a5c

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

mmeval/metrics/word_accuracy.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class WordAccuracy(BaseMetric):
2323
not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'
2424
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
2525
26-
Example:
26+
Examples:
2727
>>> from mmeval import WordAccuracy
2828
>>> metric = WordAccuracy()
2929
>>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
@@ -46,19 +46,19 @@ def __init__(self,
4646
assert isinstance(mode, (str, list))
4747
if isinstance(mode, str):
4848
mode = [mode]
49-
assert all([isinstance(item, str) for item in mode])
50-
assert set(mode).issubset(
51-
{'exact', 'ignore_case', 'ignore_case_symbol'})
49+
assert all(isinstance(item, str) for item in mode)
5250
self.mode = set(mode) # type: ignore
51+
assert set(self.mode).issubset(
52+
{'exact', 'ignore_case', 'ignore_case_symbol'})
5353

54-
def add(self, predictions: Sequence[str], labels: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
54+
def add(self, predictions: Sequence[str], groundtruths: Sequence[str]) -> None: # type: ignore # yapf: disable # noqa: E501
5555
"""Process one batch of data and predictions.
5656
5757
Args:
5858
predictions (list[str]): The prediction texts.
59-
labels (list[str]): The ground truth texts.
59+
groundtruths (list[str]): The ground truth texts.
6060
"""
61-
for pred, label in zip(predictions, labels):
61+
for pred, label in zip(predictions, groundtruths):
6262
num, ignore_case_num, ignore_case_symbol_num = 0, 0, 0
6363
if 'exact' in self.mode:
6464
num = pred == label
@@ -85,22 +85,23 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
8585
8686
- accuracy (float): Accuracy at word level.
8787
- ignore_case_accuracy (float): Accuracy at word level, ignoring
88-
letter case.
88+
letter case.
8989
- ignore_case_symbol_accuracy (float): Accuracy at word level,
90-
ignoring letter case and symbol.
90+
ignoring letter case and symbol.
9191
"""
92-
eval_res = {}
92+
metric_results = {}
9393
gt_word_num = max(len(results), 1.0)
9494
exact_sum, ignore_case_sum, ignore_case_symbol_sum = 0.0, 0.0, 0.0
9595
for exact, ignore_case, ignore_case_symbol in results:
9696
exact_sum += exact
9797
ignore_case_sum += ignore_case
9898
ignore_case_symbol_sum += ignore_case_symbol
9999
if 'exact' in self.mode:
100-
eval_res['accuracy'] = exact_sum / gt_word_num
100+
metric_results['accuracy'] = exact_sum / gt_word_num
101101
if 'ignore_case' in self.mode:
102-
eval_res['ignore_case_accuracy'] = ignore_case_sum / gt_word_num
102+
metric_results[
103+
'ignore_case_accuracy'] = ignore_case_sum / gt_word_num
103104
if 'ignore_case_symbol' in self.mode:
104-
eval_res['ignore_case_symbol_accuracy'] =\
105+
metric_results['ignore_case_symbol_accuracy'] =\
105106
ignore_case_symbol_sum / gt_word_num
106-
return eval_res
107+
return metric_results

0 commit comments

Comments
 (0)