@@ -21,17 +21,17 @@ class OneMinusNormEditDistance(BaseMetric):
2121
2222 - unchanged: Do not change prediction texts and labels.
2323 - upper: Convert prediction texts and labels into uppercase
24- characters.
24+ characters.
2525 - lower: Convert prediction texts and labels into lowercase
26- characters.
26+ characters.
2727
2828 Usually, it only works for English characters. Defaults to
2929 'unchanged'.
3030 invalid_symbol (str): A regular expression to filter out invalid or
31- not cared characters. Defaults to '[^A-Z^a-z^0-9^ \u4e00-\u9fa5]'.
31+ not cared characters. Defaults to '[^A-Za-z0-9 \u4e00-\u9fa5]'.
3232 **kwargs: Keyword parameters passed to :class:`BaseMetric`.
3333
34- Example :
34+ Examples :
3535 >>> from mmeval import OneMinusNormEditDistance
3636 >>> metric = OneMinusNormEditDistance()
3737 >>> metric(['helL', 'HEL'], ['hello', 'HELLO'])
@@ -43,22 +43,22 @@ class OneMinusNormEditDistance(BaseMetric):
4343
4444 def __init__ (self ,
4545 letter_case : str = 'unchanged' ,
46- invalid_symbol : str = '[^A-Z^a-z^0-9^ \u4e00 -\u9fa5 ]' ,
46+ invalid_symbol : str = '[^A-Za-z0-9 \u4e00 -\u9fa5 ]' ,
4747 ** kwargs ):
4848 super ().__init__ (** kwargs )
4949
5050 assert letter_case in ['unchanged' , 'upper' , 'lower' ]
5151 self .letter_case = letter_case
5252 self .invalid_symbol = re .compile (invalid_symbol )
5353
54- def add (self , predictions : Sequence [str ], labels : Sequence [str ]): # type: ignore # yapf: disable # noqa: E501
54+ def add (self , predictions : Sequence [str ], groundtruths : Sequence [str ]): # type: ignore # yapf: disable # noqa: E501
5555 """Process one batch of data and predictions.
5656
5757 Args:
5858 predictions (list[str]): The prediction texts.
59- labels (list[str]): The ground truth texts.
59+ groundtruths (list[str]): The ground truth texts.
6060 """
61- for pred , label in zip (predictions , labels ):
61+ for pred , label in zip (predictions , groundtruths ):
6262 if self .letter_case in ['upper' , 'lower' ]:
6363 pred = getattr (pred , self .letter_case )()
6464 label = getattr (label , self .letter_case )()
@@ -75,11 +75,12 @@ def compute_metric(self, results: List[float]) -> Dict:
7575
7676 Returns:
7777 dict[str, float]: Nested dicts as results.
78- - 1-N.E.D (float): One minus the normalized edit distance.
78+
79+ - 1-N.E.D (float): One minus the normalized edit distance.
7980 """
8081 gt_word_num = len (results )
8182 norm_ed_sum = sum (results )
8283 normalized_edit_distance = norm_ed_sum / max (1.0 , gt_word_num )
83- eval_res = {}
84- eval_res ['1-N.E.D' ] = 1.0 - normalized_edit_distance
85- return eval_res
84+ metric_results = {}
85+ metric_results ['1-N.E.D' ] = 1.0 - normalized_edit_distance
86+ return metric_results
0 commit comments