Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/en/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
SacreBLEU
1 change: 1 addition & 0 deletions docs/zh_cn/api/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ Metrics
KeypointEndPointError
KeypointAUC
KeypointNME
SacreBLEU
3 changes: 2 additions & 1 deletion mmeval/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .proposal_recall import ProposalRecall
from .psnr import PeakSignalNoiseRatio
from .rouge import ROUGE
from .sacre_bleu import SacreBLEU
from .sad import SumAbsoluteDifferences
from .single_label import SingleLabelMetric
from .snr import SignalNoiseRatio
Expand All @@ -41,7 +42,7 @@
'AveragePrecision', 'AVAMeanAP', 'BLEU', 'DOTAMeanAP',
'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
'KeypointAUC', 'KeypointNME'
'KeypointAUC', 'KeypointNME', 'SacreBLEU'
]

_deprecated_msg = (
Expand Down
22 changes: 5 additions & 17 deletions mmeval/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# <https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/bleu.py>`_.
import numpy as np
from collections import Counter
from typing import Callable, List, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple

from mmeval import BaseMetric
from mmeval.metrics.utils import get_n_gram, get_tokenizer, infer_language
from mmeval.metrics.utils import get_n_gram


def _get_brevity_penalty(pred_len: np.array,
Expand Down Expand Up @@ -40,7 +40,7 @@ class BLEU(BaseMetric):
ngram_weights (Sequence[float], optional): Weights used
for unigrams, bigrams, etc. to calculate BLEU score.
If not provided, uniform weights are used. Defaults to None.
tokenizer_fn (Union[Callable, str, None]): A user's own tokenizer function.
tokenizer_fn (Callable, optional): A user's own tokenizer function.
Defaults to None.
New in version 0.3.0.
**kwargs: Keyword parameters passed to :class:`BaseMetric`.
Expand All @@ -66,7 +66,7 @@ def __init__(self,
n_gram: int = 4,
smooth: bool = False,
ngram_weights: Optional[Sequence[float]] = None,
tokenizer_fn: Union[Callable, str, None] = None,
tokenizer_fn: Callable = None,
**kwargs) -> None:
super().__init__(**kwargs)
self.n_gram = n_gram
Expand All @@ -78,19 +78,10 @@ def __init__(self,
if ngram_weights is None:
ngram_weights = [1.0 / n_gram] * n_gram
self.ngram_weights = ngram_weights

# Select tokenizer according to the entered value.
self.tokenizer_fn = None
if callable(tokenizer_fn):
self.tokenizer_fn = tokenizer_fn
elif isinstance(tokenizer_fn, str):
self.tokenizer_fn = get_tokenizer(tokenizer_fn)
if self.tokenizer_fn is None:
raise ValueError('Right now, `tokenizer_fn` only supports '
"pre-defined 'en' or 'cn'.")
else:
assert tokenizer_fn is None, \
f'`tokenizer_fn` supports Callable, str or None, but not `{type(tokenizer_fn)}`' # noqa: E501
self.tokenizer_fn = str.split

def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -> None: # type: ignore # yapf: disable # noqa: E501
"""Add the intermediate results to ``self._results``.
Expand All @@ -100,9 +91,6 @@ def add(self, predictions: Sequence[str], references: Sequence[Sequence[str]]) -
references (Sequence[Sequence[str]): An iterable of
referenced sentences.
"""
if self.tokenizer_fn is None:
language = infer_language(predictions[0])
self.tokenizer_fn = get_tokenizer(language)
references_token: Sequence[Sequence[Sequence[str]]] = [
[self.tokenizer_fn(line) for line in r] for r in references
]
Expand Down
Loading