From d1c284c90a00a01575d0b30d3fac7e3b30426df4 Mon Sep 17 00:00:00 2001
From: gpetho
Date: Tue, 30 Jul 2024 11:06:40 +0200
Subject: [PATCH 1/2] fix: score uses scorer batch size setting by default

---
 bert_score/scorer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/bert_score/scorer.py b/bert_score/scorer.py
index 48e7054..317cf94 100644
--- a/bert_score/scorer.py
+++ b/bert_score/scorer.py
@@ -1,6 +1,5 @@
 import os
-import pathlib
-import sys
+
 import time
 import warnings
 from collections import defaultdict
@@ -179,7 +178,7 @@ def compute_idf(self, sents):
 
         self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
 
-    def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
+    def score(self, cands, refs, verbose=False, batch_size=None, return_hash=False):
         """
         Args:
             - :param: `cands` (list of str): candidate sentences
@@ -217,6 +216,9 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
             idf_dict[self._tokenizer.sep_token_id] = 0
             idf_dict[self._tokenizer.cls_token_id] = 0
 
+        if batch_size is None:
+            batch_size = self.batch_size
+
         all_preds = bert_cos_score_idf(
             self._model,
             refs,

From 6ea1d5cf758abf9783ad65b3c32a5d9b9d438b63 Mon Sep 17 00:00:00 2001
From: gpetho
Date: Tue, 30 Jul 2024 11:24:15 +0200
Subject: [PATCH 2/2] removed non-existent params from docstring

---
 bert_score/scorer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/bert_score/scorer.py b/bert_score/scorer.py
index 317cf94..f4f08c9 100644
--- a/bert_score/scorer.py
+++ b/bert_score/scorer.py
@@ -43,7 +43,6 @@ def __init__(
                       `model_type` or `lang`
             - :param: `num_layers` (int): the layer of representation to use.
                       default using the number of layer tuned on WMT16 correlation data
-            - :param: `verbose` (bool): turn on intermediate status update
             - :param: `idf` (bool): a booling to specify whether to use idf or not (this should be True even if `idf_sents` is given)
             - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
             - :param: `device` (str): on which the contextual embedding model will be allocated on.
@@ -53,7 +52,6 @@ def __init__(
             - :param: `lang` (str): language of the sentences; has to specify
                       at least one of `model_type` or `lang`. `lang` needs to be specified
                       when `rescale_with_baseline` is True.
-            - :param: `return_hash` (bool): return hash code of the setting
            - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
            - :param: `baseline_path` (str): customized baseline file
            - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer
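
Note: a minimal usage sketch of the behavior patch 1/2 targets, assuming the BERTScorer
constructor accepts a batch_size argument (as the self.batch_size attribute used in the
diff suggests); the candidate and reference sentences below are hypothetical:

    from bert_score import BERTScorer

    # Scorer configured with a batch size of 16 at construction time.
    scorer = BERTScorer(lang="en", batch_size=16)

    cands = ["The cat sat on the mat."]       # hypothetical candidate sentence
    refs = ["A cat was sitting on the mat."]  # hypothetical reference sentence

    # With the patch applied, score() falls back to the scorer's batch_size (here 16)
    # when no batch_size argument is passed; previously it silently used 64.
    P, R, F1 = scorer.score(cands, refs)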