diff --git a/t5/data/glue_utils.py b/t5/data/glue_utils.py index 66378ea0..bb933bde 100644 --- a/t5/data/glue_utils.py +++ b/t5/data/glue_utils.py @@ -113,6 +113,7 @@ def get_glue_postprocess_fn(builder_config): return functools.partial( postprocessors.string_label_to_class_id, label_classes=builder_config.label_classes, + use_default=False, ) GLUE_METRICS = collections.OrderedDict([ diff --git a/t5/data/postprocessors.py b/t5/data/postprocessors.py index 5c4aaa31..7f413e4b 100644 --- a/t5/data/postprocessors.py +++ b/t5/data/postprocessors.py @@ -41,12 +41,16 @@ def lower_text(string, **unused_kwargs): def string_label_to_class_id(string_label, label_classes, default=-1, + use_default=True, **unused_kwargs): """Returns index of string_label in label_classes or default if not found.""" if string_label in label_classes: return label_classes.index(string_label) else: - return default + if use_default: + return default + else: + return string_label def multirc(string_label, example=None, is_target=False): diff --git a/t5/data/postprocessors_test.py b/t5/data/postprocessors_test.py index c315d6bb..9b254718 100644 --- a/t5/data/postprocessors_test.py +++ b/t5/data/postprocessors_test.py @@ -36,6 +36,8 @@ def test_string_label_to_class_id(self): self.assertEqual(postprocessors.string_label_to_class_id("two", cls), 1) self.assertEqual(postprocessors.string_label_to_class_id("foo", cls), -1) self.assertEqual(postprocessors.string_label_to_class_id("foo", cls, 2), 2) + self.assertEqual(postprocessors.string_label_to_class_id( + "foo", cls, use_default=False), "foo") def test_multirc(self): self.assertDictEqual(