diff --git a/lm_eval/api/metrics.py b/lm_eval/api/metrics.py
index f01b181860..4addf38311 100644
--- a/lm_eval/api/metrics.py
+++ b/lm_eval/api/metrics.py
@@ -1,18 +1,45 @@
+import ipaddress
+import json
 import logging
 import math
 import os
 import random
 import re
 import string
+import uuid
 from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar
 
 import numpy as np
 import sacrebleu
+from jsonschema import Draft202012Validator, FormatChecker, SchemaError, ValidationError
 
 from lm_eval.api.registry import register_aggregation, register_metric
 
 
+eval_logger = logging.getLogger(__name__)
+
+
+# Initialize the FormatChecker
+format_checker = FormatChecker()
+
+
+# Add custom format checkers
+@format_checker.checks("ipv4", raises=ValueError)
+def ipv4_check(value):
+    return ipaddress.IPv4Address(value)
+
+
+@format_checker.checks("ipv6", raises=ValueError)
+def ipv6_check(value):
+    return ipaddress.IPv6Address(value)
+
+
+@format_checker.checks("uuid", raises=ValueError)
+def uuid_check(value):
+    return uuid.UUID(value)
+
+
 T = TypeVar("T")
 
 eval_logger = logging.getLogger(__name__)
@@ -309,6 +336,161 @@ def bypass(items):
     return None
 
 
+
+def is_json_schema_valid(schema: dict):
+    """
+    Check if a JSON schema is valid.
+
+    :param schema: A JSON schema.
+    :return: True if the schema is valid, False otherwise.
+    """
+    try:
+        # Check if the schema is valid
+        Draft202012Validator.check_schema(schema)
+        return True
+    except SchemaError:
+        return False
+
+
+def schema_conform_with_format_checker(
+    instance: Dict[str, Any], schema: Dict[str, Any]
+) -> bool:
+    """
+    Validate a JSON instance against a schema with enhanced format checking.
+
+    :param schema: The JSON schema to validate against.
+    :param instance: The JSON instance to validate.
+    :raises ValidationError: If the validation fails.
+    """
+    # first check if the schema is valid
+    if not is_json_schema_valid(schema):
+        raise ValidationError("The JSON schema is invalid.")
+    validator = Draft202012Validator(schema, format_checker=format_checker)
+    try:
+        validator.validate(instance)
+    except ValidationError as e:
+        raise ValidationError(e.message)
+    return True
+
+
+@register_metric(
+    metric="json_validity",
+    higher_is_better=True,
+    output_type=["generate_until"],
+    aggregation="mean",
+)
+def json_validity(
+    references: list[str], predictions: list[str], strip: bool = True
+) -> bool:
+    assert len(predictions) == 1, (
+        "Currently, we don't support pass@k for JSON schema validation."
+    )
+    prediction = predictions[0]  # predictions is a single-element list
+
+    if strip:
+        prediction = prediction.strip().strip("```").strip("json").strip()
+
+    try:
+        json.loads(prediction)
+    except json.JSONDecodeError:
+        return False
+    return True
+
+
+@register_metric(
+    metric="grammar_compliance",
+    higher_is_better=True,
+    output_type=["generate_until"],
+    aggregation="mean",
+)
+def grammar_compliance(
+    references: list[str],
+    predictions: list[str],
+    grammar_file_path: str,
+    grammar_type: str,
+    tokenizer: Optional[str] = None,
+) -> bool:
+    assert len(references) == 1, (
+        "We only have one reference for this task, which is the JSON schema."
+    )
+    assert len(predictions) == 1, (
+        "Currently, we don't support pass@k for JSON schema validation."
+    )
+
+    prediction = predictions[0]  # predictions is a single-element list
+
+    with open(grammar_file_path, "r") as f:
+        grammar_str = f.read().strip()
+
+    if grammar_type == "json":
+        json_schema = json.loads(grammar_str)
+        try:
+            json_obj = json.loads(prediction.strip().strip("```").strip("json"))
+        except json.JSONDecodeError:
+            return False
+
+        try:
+            schema_conform = schema_conform_with_format_checker(json_obj, json_schema)
+        except Exception as e:
+            eval_logger.error(f"Error: {e}")
+            return False
+
+        return schema_conform
+
+    if grammar_type == "regex":
+        return bool(re.fullmatch(grammar_str, prediction.strip()))
+
+    if grammar_type == "gbnf":
+        try:
+            import xgrammar as xgr
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer or "meta-llama/Llama-3.2-1B")
+
+            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=tokenizer.vocab_size
+            )
+            grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
+            compiled_grammar = grammar_compiler.compile_grammar(grammar_str)
+            matcher = xgr.GrammarMatcher(compiled_grammar)
+
+            return matcher.accept_string(prediction.strip())
+
+        except Exception:
+            return False
+
+    raise ValueError(f"Unknown grammar type: {grammar_type}")
+
+
+@register_metric(
+    metric="json_answer_match",
+    higher_is_better=True,
+    output_type=["generate_until"],
+    aggregation="mean",
+)
+def json_answer_match(
+    predictions,
+    references,
+    target_field,
+):
+    extracted_predictions = [""] * len(predictions)
+    for i in range(len(predictions)):
+        try:
+            extracted_predictions[i] = json.loads(predictions[i].strip())[target_field]
+        except (json.JSONDecodeError, KeyError):
+            continue
+
+    # This is an ad hoc solution. We need to generalize it.
+    if isinstance(references[0], str):
+        extracted_predictions = list(map(str, extracted_predictions))
+
+    extracted_predictions = np.array(extracted_predictions)
+    references = np.array(references)
+
+    score_list = extracted_predictions == references
+
+    return {"json_answer_match": np.mean(score_list)}
+
+
 @register_metric(
     metric="mcc",
     higher_is_better=True,
diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py
index abedc5535e..d58244f397 100644
--- a/lm_eval/models/__init__.py
+++ b/lm_eval/models/__init__.py
@@ -5,6 +5,7 @@
     gguf,
     hf_audiolm,
     hf_steered,
+    hf_structured,
     hf_vlms,
     huggingface,
     ibm_watsonx_ai,
diff --git a/lm_eval/models/hf_structured.py b/lm_eval/models/hf_structured.py
new file mode 100644
index 0000000000..e814e87380
--- /dev/null
+++ b/lm_eval/models/hf_structured.py
@@ -0,0 +1,75 @@
+import logging
+
+import xgrammar as xgr
+
+from lm_eval.api.registry import register_model
+from lm_eval.models.huggingface import HFLM
+
+
+eval_logger = logging.getLogger(__name__)
+
+ALL_GRAMMAR_TYPES = ("gbnf", "json", "regex")
+
+
+@register_model("hf-structured")
+class HFStructuredLM(HFLM):
+    """
+    A Hugging Face model class for grammar-constrained generation via xgrammar (GBNF, JSON schema, or regex).
+    """
+
+    def _get_logits_processor(self, grammar_file_path, grammar_type):
+        if grammar_type not in ALL_GRAMMAR_TYPES:
+            raise ValueError(
+                f"Got invalid grammar_type '{grammar_type}', must be in '{','.join(ALL_GRAMMAR_TYPES)}'"
+            )
+
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+            self.tokenizer, vocab_size=self.config.vocab_size
+        )
+        compiler = xgr.GrammarCompiler(tokenizer_info)
+
+        with open(grammar_file_path, "r") as f:
+            grammar_str = f.read().strip()
+
+        if grammar_type == "gbnf":
+            compiled_grammar = compiler.compile_grammar(grammar_str)
+        elif grammar_type == "json":
+            compiled_grammar = compiler.compile_json_schema(grammar_str)
+        elif grammar_type == "regex":
+            compiled_grammar = compiler.compile_regex(grammar_str)
+
+        return xgr.contrib.hf.LogitsProcessor(compiled_grammar)
+
+    def _model_generate(
+        self,
+        context,
+        max_length,
+        stop,
+        grammar_file_path,
+        grammar_type,
+        **generation_kwargs,
+    ):
+        # temperature = 0.0 if not set
+        # if do_sample is false and temp==0.0:
+        # remove temperature, as do_sample=False takes care of this
+        # and we don't want a warning from HF
+        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
+        do_sample = generation_kwargs.get("do_sample", None)
+
+        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
+        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
+            generation_kwargs["do_sample"] = do_sample = False
+
+        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
+            generation_kwargs.pop("temperature")
+
+        logits_processor = self._get_logits_processor(grammar_file_path, grammar_type)
+
+        return self.model.generate(
+            input_ids=context,
+            max_length=max_length,
+            pad_token_id=self.tokenizer.pad_token_id,
+            use_cache=True,
+            logits_processor=[logits_processor],
+            **generation_kwargs,
+        )
diff --git a/pyproject.toml b/pyproject.toml
index c6dabf4c09..934d2063e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,8 @@ dependencies = [
     "datasets>=2.16.0,<4.0",
     "evaluate>=0.4.0",
    "jsonlines",
+    "jsonschema>4.18.0",
+    "xgrammar",
     "numexpr",
     "peft>=0.2.0",
     "pybind11>=2.6.2",
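A minimal usage sketch of the schema validators added in lm_eval/api/metrics.py above (illustrative only; the toy schema, UUID, and IP address are hypothetical and not part of the patch):

    import json
    from lm_eval.api.metrics import (
        is_json_schema_valid,
        json_validity,
        schema_conform_with_format_checker,
    )

    # A toy schema exercising the custom "uuid" and "ipv4" format checkers.
    schema = {
        "type": "object",
        "properties": {
            "id": {"type": "string", "format": "uuid"},
            "host": {"type": "string", "format": "ipv4"},
        },
        "required": ["id", "host"],
    }
    assert is_json_schema_valid(schema)

    # Raw model output, wrapped in a Markdown code fence as models often emit it.
    prediction = '```json\n{"id": "123e4567-e89b-12d3-a456-426614174000", "host": "10.0.0.1"}\n```'

    # json_validity strips the fence and checks that the payload parses as JSON.
    print(json_validity(references=["unused"], predictions=[prediction]))  # True

    # schema_conform_with_format_checker additionally enforces the declared formats;
    # it raises ValidationError on mismatch and returns True on success.
    instance = json.loads(prediction.strip().strip("`").strip("json").strip())
    print(schema_conform_with_format_checker(instance, schema))  # True

Because schema_conform_with_format_checker raises ValidationError instead of returning False, grammar_compliance wraps the call in a try/except and logs the error before scoring the sample as non-compliant.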