10 changes: 5 additions & 5 deletions bert_score/scorer.py
@@ -1,6 +1,5 @@
 import os
 import pathlib
-import sys
 
 import time
 import warnings
 from collections import defaultdict
@@ -44,7 +43,6 @@ def __init__(
                   `model_type` or `lang`
         - :param: `num_layers` (int): the layer of representation to use;
                   defaults to the number of layers tuned on WMT16 correlation data
-        - :param: `verbose` (bool): turn on intermediate status update
         - :param: `idf` (bool): a boolean specifying whether to use idf weighting (this should be True even if `idf_sents` is given)
         - :param: `idf_sents` (List of str): list of sentences used to compute the idf weights
         - :param: `device` (str): the device on which the contextual embedding model will be allocated.
@@ -54,7 +52,6 @@ def __init__(
         - :param: `lang` (str): language of the sentences; at least one of
                   `model_type` or `lang` has to be specified. `lang` needs to be
                   specified when `rescale_with_baseline` is True.
-        - :param: `return_hash` (bool): return hash code of the setting
         - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
         - :param: `baseline_path` (str): customized baseline file
         - :param: `use_fast_tokenizer` (bool): `use_fast` parameter passed to HF tokenizer
@@ -179,7 +176,7 @@ def compute_idf(self, sents):
 
         self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
 
-    def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
+    def score(self, cands, refs, verbose=False, batch_size=None, return_hash=False):
         """
         Args:
             - :param: `cands` (list of str): candidate sentences
@@ -217,6 +214,9 @@ def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
             idf_dict[self._tokenizer.sep_token_id] = 0
             idf_dict[self._tokenizer.cls_token_id] = 0
 
+        if batch_size is None:
+            batch_size = self.batch_size
+
         all_preds = bert_cos_score_idf(
             self._model,
             refs,
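
With this change, `score` no longer hard-codes a batch size of 64: passing `batch_size=None` (the new default) makes it fall back to the `batch_size` the scorer was constructed with. A minimal usage sketch of the new behavior; the value 32 and the example sentences are illustrative, not from the PR:

```python
from bert_score import BERTScorer

# The constructor batch_size now doubles as the default for score().
scorer = BERTScorer(lang="en", batch_size=32)

cands = ["The cat sat on the mat."]
refs = ["A cat was sitting on the mat."]

# batch_size is left at None here, so self.batch_size (32) is used.
P, R, F1 = scorer.score(cands, refs)

# An explicit batch_size still overrides the constructor value.
P, R, F1 = scorer.score(cands, refs, batch_size=8)
```

Callers that previously relied on the implicit 64 get the same behavior only if they constructed the scorer with the default `batch_size=64`; anyone who set a custom constructor value now has it respected by `score` as well.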