184 changes: 183 additions & 1 deletion lm_eval/api/metrics.py
@@ -1,18 +1,45 @@
import ipaddress
import json
import logging
import math
import os
import random
import re
import string
import uuid
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, TypeVar
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar

import numpy as np
import sacrebleu
from jsonschema import Draft202012Validator, FormatChecker, SchemaError, ValidationError

from lm_eval.api.registry import register_aggregation, register_metric


eval_logger = logging.getLogger(__name__)


# Initialize the FormatChecker
format_checker = FormatChecker()


# Add custom format checkers. A checker must return a truthy value for valid
# input; exceptions listed in `raises` are treated as validation failures.
@format_checker.checks("ipv4", raises=ValueError)
def ipv4_check(value):
    ipaddress.IPv4Address(value)
    return True


@format_checker.checks("ipv6", raises=ValueError)
def ipv6_check(value):
    ipaddress.IPv6Address(value)
    return True


@format_checker.checks("uuid", raises=ValueError)
def uuid_check(value):
    uuid.UUID(value)
    return True
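

# A quick sanity check of the custom checkers (illustrative values only):
# >>> format_checker.conforms("127.0.0.1", "ipv4")
# True
# >>> format_checker.conforms("not-an-ip", "ipv4")
# False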


T = TypeVar("T")

eval_logger = logging.getLogger(__name__)
@@ -309,6 +336,161 @@ def bypass(items):
    return None


def is_json_schema_valid(schema: dict):
    """
    Check if a JSON schema is valid.

    :param schema: A JSON schema.
    :return: True if the schema is valid, False otherwise.
    """
    try:
        # Check if the schema is valid
        Draft202012Validator.check_schema(schema)
        return True
    except SchemaError:
        return False


def schema_conform_with_format_checker(
    instance: Dict[str, Any], schema: Dict[str, Any]
) -> bool:
    """
    Validate a JSON instance against a schema with enhanced format checking.

    :param instance: The JSON instance to validate.
    :param schema: The JSON schema to validate against.
    :return: True if the instance conforms to the schema.
    :raises ValidationError: If the schema is invalid or the instance does
        not conform.
    """
    # First check that the schema itself is valid
    if not is_json_schema_valid(schema):
        raise ValidationError("The JSON schema is invalid.")
    validator = Draft202012Validator(schema, format_checker=format_checker)
    validator.validate(instance)
    return True
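

# Example usage (illustrative; hypothetical inline schema): validate an object
# whose field carries a format annotation handled by the checkers above.
# >>> _schema = {
# ...     "type": "object",
# ...     "properties": {"ip": {"type": "string", "format": "ipv4"}},
# ... }
# >>> schema_conform_with_format_checker({"ip": "192.168.0.1"}, _schema)
# True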


@register_metric(
    metric="json_validity",
    higher_is_better=True,
    output_type=["generate_until"],
    aggregation="mean",
)
def json_validity(
    references: list[str], predictions: list[str], strip: bool = True
) -> bool:
    assert len(predictions) == 1, (
        "Currently, we don't support pass@k for JSON validity."
    )
    prediction = predictions[0]  # predictions holds a single generation

    if strip:
        # Remove surrounding markdown code fences such as ```json ... ```
        prediction = prediction.strip().strip("`").strip()
        if prediction.startswith("json"):
            prediction = prediction[len("json") :].strip()

    try:
        json.loads(prediction)
    except json.JSONDecodeError:
        return False
    return True
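

# For instance (illustrative), a code-fenced generation still counts as valid
# JSON when strip=True:
# >>> json_validity(["{}"], ['```json\n{"a": 1}\n```'])
# True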


@register_metric(
    metric="grammar_compliance",
    higher_is_better=True,
    output_type=["generate_until"],
    aggregation="mean",
)
def grammar_compliance(
    references: list[str],
    predictions: list[str],
    grammar_file_path: str,
    grammar_type: str,
    tokenizer: Optional[str] = None,
) -> bool:
    assert len(references) == 1, (
        "We only have one reference for this task, which is the grammar."
    )
    assert len(predictions) == 1, (
        "Currently, we don't support pass@k for grammar compliance."
    )

    prediction = predictions[0]  # predictions holds a single generation

    with open(grammar_file_path, "r") as f:
        grammar_str = f.read().strip()

    if grammar_type == "json":
        json_schema = json.loads(grammar_str)
        # Remove surrounding markdown code fences such as ```json ... ```
        stripped = prediction.strip().strip("`").strip()
        if stripped.startswith("json"):
            stripped = stripped[len("json") :].strip()
        try:
            json_obj = json.loads(stripped)
        except json.JSONDecodeError:
            return False

        try:
            return schema_conform_with_format_checker(json_obj, json_schema)
        except Exception as e:
            eval_logger.error(f"Error: {e}")
            return False

    if grammar_type == "regex":
        return bool(re.fullmatch(grammar_str, prediction.strip()))

    if grammar_type == "gbnf":
        try:
            import xgrammar as xgr
            from transformers import AutoTokenizer

            # Use the provided tokenizer name, falling back to a default
            hf_tokenizer = AutoTokenizer.from_pretrained(
                tokenizer or "meta-llama/Llama-3.2-1B"
            )

            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
                hf_tokenizer, vocab_size=hf_tokenizer.vocab_size
            )
            grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
            compiled_grammar = grammar_compiler.compile_grammar(grammar_str)
            matcher = xgr.GrammarMatcher(compiled_grammar)

            return matcher.accept_string(prediction.strip())

        except Exception as e:
            eval_logger.error(f"Error: {e}")
            return False

    raise ValueError(f"Unknown grammar type: {grammar_type}")
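

# Example (illustrative; assumes a hypothetical file "digits.regex" whose
# contents are the pattern [0-9]+):
# >>> grammar_compliance(["<grammar>"], ["12345"], "digits.regex", "regex")
# True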


@register_metric(
    metric="json_answer_match",
    higher_is_better=True,
    output_type=["generate_until"],
    aggregation="mean",
)
def json_answer_match(
    predictions,
    references,
    target_field,
):
    extracted_predictions = [""] * len(predictions)
    for i in range(len(predictions)):
        try:
            extracted_predictions[i] = json.loads(predictions[i].strip())[target_field]
        except (json.JSONDecodeError, KeyError, TypeError):
            # Keep the empty-string placeholder when the prediction is not a
            # JSON object or lacks the target field
            continue

    # Ad hoc: when the references are strings, cast the extracted predictions
    # to str so the elementwise comparison below is type-consistent. This
    # needs to be generalized.
    if isinstance(references[0], str):
        extracted_predictions = list(map(str, extracted_predictions))

    extracted_predictions = np.array(extracted_predictions)
    references = np.array(references)

    score_list = extracted_predictions == references

    return {"json_answer_match": np.mean(score_list)}
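

# Example (illustrative): the target field is extracted from the predicted
# JSON and compared against the reference as a string.
# >>> json_answer_match(['{"answer": 42}'], ["42"], "answer")
# {'json_answer_match': 1.0}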


@register_metric(
    metric="mcc",
    higher_is_better=True,
1 change: 1 addition & 0 deletions lm_eval/models/__init__.py
@@ -5,6 +5,7 @@
    gguf,
    hf_audiolm,
    hf_steered,
    hf_structured,
    hf_vlms,
    huggingface,
    ibm_watsonx_ai,
75 changes: 75 additions & 0 deletions lm_eval/models/hf_structured.py
@@ -0,0 +1,75 @@
import logging

import xgrammar as xgr

from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM


eval_logger = logging.getLogger(__name__)

ALL_GRAMMAR_TYPES = ("gbnf", "json", "regex")


@register_model("hf-structured")
class HFStructuredLM(HFLM):
    """
    A Hugging Face model class for structured LMs: generation is constrained
    to a grammar (GBNF, JSON schema, or regex) via xgrammar.
    """

    def _get_logits_processor(self, grammar_file_path, grammar_type):
        if grammar_type not in ALL_GRAMMAR_TYPES:
            raise ValueError(
                f"Got invalid grammar_type '{grammar_type}', must be in '{','.join(ALL_GRAMMAR_TYPES)}'"
            )

        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
            self.tokenizer, vocab_size=self.config.vocab_size
        )
        compiler = xgr.GrammarCompiler(tokenizer_info)

        with open(grammar_file_path, "r") as f:
            grammar_str = f.read().strip()

        if grammar_type == "gbnf":
            compiled_grammar = compiler.compile_grammar(grammar_str)
        elif grammar_type == "json":
            compiled_grammar = compiler.compile_json_schema(grammar_str)
        elif grammar_type == "regex":
            compiled_grammar = compiler.compile_regex(grammar_str)

        return xgr.contrib.hf.LogitsProcessor(compiled_grammar)

    def _model_generate(
        self,
        context,
        max_length,
        stop,
        grammar_file_path,
        grammar_type,
        **generation_kwargs,
    ):
        # Default temperature to 0.0 if not set. If do_sample is False and
        # temperature == 0.0, drop temperature: do_sample=False already means
        # greedy decoding, and we don't want a warning from HF.
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0,
        # use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        logits_processor = self._get_logits_processor(grammar_file_path, grammar_type)

        return self.model.generate(
            input_ids=context,
            max_length=max_length,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            logits_processor=[logits_processor],
            **generation_kwargs,
        )
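

# Usage sketch (hypothetical task name; the task's generation kwargs are
# assumed to supply grammar_file_path and grammar_type, which _model_generate
# consumes):
#
#   lm_eval --model hf-structured \
#       --model_args pretrained=meta-llama/Llama-3.2-1B \
#       --tasks my_json_schema_task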
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -24,6 +24,8 @@ dependencies = [
"datasets>=2.16.0,<4.0",
"evaluate>=0.4.0",
"jsonlines",
"jsonschema>4.18.0",
"xgrammar",
"numexpr",
"peft>=0.2.0",
"pybind11>=2.6.2",