@@ -23,7 +23,7 @@ class WordAccuracy(BaseMetric):
2323 not cared characters. Defaults to '[^A-Za-z0-9\u4e00-\u9fa5]'
2424 **kwargs: Keyword parameters passed to :class:`BaseMetric`.
2525
26- Example :
26+ Examples :
2727 >>> from mmeval import WordAccuracy
2828 >>> metric = WordAccuracy()
2929 >>> metric(['hello', 'hello', 'hello'], ['hello', 'HELLO', '$HELLO$'])
@@ -46,19 +46,19 @@ def __init__(self,
4646 assert isinstance (mode , (str , list ))
4747 if isinstance (mode , str ):
4848 mode = [mode ]
49- assert all ([isinstance (item , str ) for item in mode ])
50- assert set (mode ).issubset (
51- {'exact' , 'ignore_case' , 'ignore_case_symbol' })
49+ assert all (isinstance (item , str ) for item in mode )
5250 self .mode = set (mode ) # type: ignore
51+ assert set (self .mode ).issubset (
52+ {'exact' , 'ignore_case' , 'ignore_case_symbol' })
5353
54- def add (self , predictions : Sequence [str ], labels : Sequence [str ]) -> None : # type: ignore # yapf: disable # noqa: E501
54+ def add (self , predictions : Sequence [str ], groundtruths : Sequence [str ]) -> None : # type: ignore # yapf: disable # noqa: E501
5555 """Process one batch of data and predictions.
5656
5757 Args:
5858 predictions (list[str]): The prediction texts.
59- labels (list[str]): The ground truth texts.
59+ groundtruths (list[str]): The ground truth texts.
6060 """
61- for pred , label in zip (predictions , labels ):
61+ for pred , label in zip (predictions , groundtruths ):
6262 num , ignore_case_num , ignore_case_symbol_num = 0 , 0 , 0
6363 if 'exact' in self .mode :
6464 num = pred == label
@@ -85,22 +85,23 @@ def compute_metric(self, results: List[Tuple[int, int, int]]) -> Dict:
8585
8686 - accuracy (float): Accuracy at word level.
8787 - ignore_case_accuracy (float): Accuracy at word level, ignoring
88- letter case.
88+ letter case.
8989 - ignore_case_symbol_accuracy (float): Accuracy at word level,
90- ignoring letter case and symbol.
90+ ignoring letter case and symbol.
9191 """
92- eval_res = {}
92+ metric_results = {}
9393 gt_word_num = max (len (results ), 1.0 )
9494 exact_sum , ignore_case_sum , ignore_case_symbol_sum = 0.0 , 0.0 , 0.0
9595 for exact , ignore_case , ignore_case_symbol in results :
9696 exact_sum += exact
9797 ignore_case_sum += ignore_case
9898 ignore_case_symbol_sum += ignore_case_symbol
9999 if 'exact' in self .mode :
100- eval_res ['accuracy' ] = exact_sum / gt_word_num
100+ metric_results ['accuracy' ] = exact_sum / gt_word_num
101101 if 'ignore_case' in self .mode :
102- eval_res ['ignore_case_accuracy' ] = ignore_case_sum / gt_word_num
102+ metric_results [
103+ 'ignore_case_accuracy' ] = ignore_case_sum / gt_word_num
103104 if 'ignore_case_symbol' in self .mode :
104- eval_res ['ignore_case_symbol_accuracy' ] = \
105+ metric_results ['ignore_case_symbol_accuracy' ] = \
105106 ignore_case_symbol_sum / gt_word_num
106- return eval_res
107+ return metric_results
0 commit comments