Fix issue #2044: Improve output parsing and timeout for local LLMs #2046

Open · wants to merge 1 commit into main
17 changes: 15 additions & 2 deletions src/ragas/exceptions.py
@@ -26,8 +26,21 @@ class RagasOutputParserException(RagasException):
     Exception raised when the output parser fails to parse the output.
     """
 
-    def __init__(self):
-        msg = "The output parser failed to parse the output including retries."
+    def __init__(self, details: str = None):
+        base_msg = "The output parser failed to parse the output including retries."
+        if details:
+            msg = f"{base_msg} Details: {details}"
+        else:
+            msg = base_msg
+
+        # Add suggestions for local LLMs
+        msg += "\nFor local LLMs, consider the following:\n" \
+            "1. Increase the timeout in RunConfig (default is now 300 seconds)\n" \
+            "2. Use a more capable local model that can better follow JSON formatting instructions\n" \
+            "3. Reduce batch size to process fewer examples at once\n" \
+            "4. For metrics that require structured output (context_recall, faithfulness, context_precision), " \
+            "consider using simpler metrics like answer_correctness and answer_similarity if the issues persist"
 
         super().__init__(msg)


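For illustration, a minimal sketch of how the enriched exception reads to callers once this branch is installed; the details string below is invented, and only RagasOutputParserException and its new details argument come from this diff:

from ragas.exceptions import RagasOutputParserException

try:
    # Simulate a parsing failure with some context about what went wrong
    raise RagasOutputParserException(details="model returned prose instead of JSON")
except RagasOutputParserException as exc:
    # Prints the base message, the details, and the appended local-LLM suggestions
    print(exc)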
92 changes: 84 additions & 8 deletions src/ragas/prompt/pydantic_prompt.py
@@ -395,19 +395,33 @@ async def parse_output_string(
         prompt_value: PromptValue,
         llm: BaseRagasLLM,
         callbacks: Callbacks,
-        retries_left: int = 1,
+        retries_left: int = 3,  # Increased default retries from 1 to 3
     ) -> OutputModel:
+        import json
+        import logging
+
+        logger = logging.getLogger(__name__)
         callbacks = callbacks or []
+
         try:
+            # First attempt: Extract and parse JSON directly
             jsonstr = extract_json(output_string)
-            result = super().parse(jsonstr)
-        except OutputParserException:
-            if retries_left != 0:
+            try:
+                result = super().parse(jsonstr)
+                return result
+            except (OutputParserException, json.JSONDecodeError) as e:
+                logger.debug(f"Initial parsing failed: {str(e)}")
+                # Continue to retry logic
+
+            # If we're here, the first attempt failed
+            if retries_left > 0:
                 retry_rm, retry_cb = new_group(
                     name="fix_output_format",
                     inputs={"output_string": output_string},
                     callbacks=callbacks,
                 )
+
+                # Use the fix_output_format prompt to try to fix the output
                 fixed_output_string = await fix_output_format_prompt.generate(
                     llm=llm,
                     data=OutputStringAndPrompt(
@@ -418,10 +432,72 @@ async def parse_output_string(
                     retries_left=retries_left - 1,
                 )
                 retry_rm.on_chain_end({"fixed_output_string": fixed_output_string})
-                result = super().parse(fixed_output_string.text)
-            else:
-                raise RagasOutputParserException()
-        return result
+
+                # Try to parse the fixed output
+                try:
+                    fixed_jsonstr = extract_json(fixed_output_string.text)
+                    result = super().parse(fixed_jsonstr)
+                    return result
+                except (OutputParserException, json.JSONDecodeError) as e:
+                    logger.debug(f"Parsing fixed output failed: {str(e)}")
+
+                # Last resort: Try to manually construct a valid JSON
+                # This is especially helpful for local LLMs that might not follow the exact format
+                try:
+                    # Get the expected schema
+                    schema = self.pydantic_object.model_json_schema()
+                    required_fields = schema.get("required", [])
+                    properties = schema.get("properties", {})
+
+                    # Create a minimal valid JSON with default values
+                    minimal_json = {}
+                    for field in required_fields:
+                        field_type = properties.get(field, {}).get("type")
+                        if field_type == "string":
+                            minimal_json[field] = "Unable to parse"
+                        elif field_type == "integer":
+                            minimal_json[field] = 0
+                        elif field_type == "number":
+                            minimal_json[field] = 0.0
+                        elif field_type == "boolean":
+                            minimal_json[field] = False
+                        elif field_type == "array":
+                            minimal_json[field] = []
+                        elif field_type == "object":
+                            minimal_json[field] = {}
+
+                    # Try to parse with this minimal JSON
+                    if minimal_json:
+                        logger.warning(f"Using fallback minimal JSON: {minimal_json}")
+                        result = self.pydantic_object.model_validate(minimal_json)
+                        return result
+                except Exception as e:
+                    logger.debug(f"Minimal JSON fallback failed: {str(e)}")
+
+                # If we've exhausted all retries and approaches
+                if retries_left > 1:
+                    # Recursive call with one less retry
+                    return await self.parse_output_string(
+                        output_string=fixed_output_string.text,
+                        prompt_value=prompt_value,
+                        llm=llm,
+                        callbacks=callbacks,
+                        retries_left=retries_left - 1,
+                    )
+
+            # If all attempts fail
+            raise RagasOutputParserException(
+                details=f"Failed after {3 - retries_left + 1} attempts with output: {output_string[:100]}..."
+            )
+        except RagasOutputParserException:
+            # Let parser failures propagate instead of re-wrapping them below
+            raise
+        except Exception as e:
+            if retries_left > 0:
+                logger.warning(f"Unexpected error during parsing: {str(e)}. Retrying...")
+                return await self.parse_output_string(
+                    output_string=output_string,
+                    prompt_value=prompt_value,
+                    llm=llm,
+                    callbacks=callbacks,
+                    retries_left=retries_left - 1,
+                )
+            raise RagasOutputParserException(details=f"Unexpected error: {str(e)}")
 
 
 # Ragas Adaptation
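To make the last-resort fallback easier to review in isolation, here is a self-contained sketch of the same idea: derive placeholder values for every required field from a Pydantic model's JSON schema. The Verdict model is an illustrative stand-in, not one of the ragas output models.

from pydantic import BaseModel, Field


class Verdict(BaseModel):
    # Illustrative stand-in for a metric's OutputModel
    statement: str = Field(description="claim being judged")
    verdict: int = Field(description="1 if supported, 0 otherwise")


def minimal_instance(model_cls):
    """Fill every required field with a neutral default derived from its schema type."""
    schema = model_cls.model_json_schema()
    properties = schema.get("properties", {})
    data = {}
    for field in schema.get("required", []):
        field_type = properties.get(field, {}).get("type")
        if field_type == "string":
            data[field] = "Unable to parse"
        elif field_type == "integer":
            data[field] = 0
        elif field_type == "number":
            data[field] = 0.0
        elif field_type == "boolean":
            data[field] = False
        elif field_type == "array":
            data[field] = []
        elif field_type == "object":
            data[field] = {}
    return model_cls.model_validate(data)


print(minimal_instance(Verdict))  # statement='Unable to parse' verdict=0

The trade-off mirrors the diff itself: objects built this way carry fabricated defaults rather than model judgments, which is why the warning logged before using the fallback matters.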
32 changes: 26 additions & 6 deletions src/ragas/prompt/utils.py
@@ -68,13 +68,22 @@ def replace_string(s: str) -> str:
 
 def extract_json(text: str) -> str:
     """Identify json from a text blob by matching '[]' or '{}'.
+    Enhanced to handle various LLM output formats, including markdown code blocks,
+    single quotes, and other common formatting issues.
 
     Warning: This will identify the first json structure!"""
 
-    # check for markdown indicator; if present, start there
-    md_json_idx = text.find("```json")
-    if md_json_idx != -1:
-        text = text[md_json_idx:]
+    import re
+
+    # Check for any markdown code block (not just json-specific)
+    code_block_pattern = r"```(?:json)?\s*([\s\S]*?)```"
+    code_blocks = re.findall(code_block_pattern, text)
+
+    if code_blocks:
+        # Use the first code block that contains valid JSON markers
+        for block in code_blocks:
+            if ("{" in block and "}" in block) or ("[" in block and "]" in block):
+                text = block
+                break
 
     # search for json delimiter pairs
     left_bracket_idx = text.find("[")
@@ -101,6 +110,17 @@ def extract_json(text: str) -> str:
 
             # When count returns to zero, we've found a complete structure
             if count == 0:
-                return text[start_idx : i + 1]
+                json_str = text[start_idx : i + 1]
+
+                # Clean up common issues with JSON formatting from LLMs
+                # Replace single quotes with double quotes if needed
+                if "'" in json_str and '"' not in json_str:
+                    json_str = json_str.replace("'", '"')
+
+                # Fix trailing commas in lists or objects which are invalid in JSON
+                json_str = re.sub(r',\s*}', '}', json_str)
+                json_str = re.sub(r',\s*]', ']', json_str)
+
+                return json_str
 
     return text  # In case of unbalanced JSON, return the original text
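A quick, self-contained check of the kinds of output the enhanced extract_json is expected to clean up when run against this branch; the sample strings are invented local-LLM responses:

import json

from ragas.prompt.utils import extract_json

samples = [
    'Sure, here it is:\n```json\n{"verdict": 1}\n```',  # markdown fence
    "{'verdict': 1}",                                    # single quotes only
    '{"verdict": 1,}',                                   # trailing comma
]
for raw in samples:
    print(json.loads(extract_json(raw)))  # each prints {'verdict': 1}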
5 changes: 3 additions & 2 deletions src/ragas/run_config.py
@@ -23,7 +23,8 @@ class RunConfig:
     Parameters
     ----------
     timeout : int, optional
-        Maximum time (in seconds) to wait for a single operation, by default 180.
+        Maximum time (in seconds) to wait for a single operation, by default 300.
+        For local LLMs, a higher timeout may be needed.
     max_retries : int, optional
         Maximum number of retry attempts, by default 10.
     max_wait : int, optional
@@ -48,7 +49,7 @@ class RunConfig:
         number generator using the specified seed.
     """
 
-    timeout: int = 180
+    timeout: int = 300  # Increased from 180 to 300 to accommodate slower local LLMs
     max_retries: int = 10
     max_wait: int = 60
     max_workers: int = 16
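For context, this is the knob users would raise further for very slow local models; a minimal sketch, with an arbitrary 600-second value, of the config that can then be handed to evaluate via its run_config argument:

from ragas.run_config import RunConfig

# Give each LLM call up to 10 minutes and lower concurrency for a local model
run_config = RunConfig(timeout=600, max_retries=5, max_workers=4)
print(run_config.timeout)  # 600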
106 changes: 106 additions & 0 deletions tests/unit/prompt/test_output_parsing.py
@@ -0,0 +1,106 @@
import pytest
from pydantic import BaseModel, Field
import json

from ragas.prompt.utils import extract_json
from ragas.prompt.pydantic_prompt import RagasOutputParser
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompt_values import StringPromptValue


class TestModel(BaseModel):
    """Test model for output parsing tests."""
    field1: str = Field(description="A string field")
    field2: int = Field(description="An integer field")


class TestExtractJson:
    """Test the extract_json function with various input formats."""

    def test_standard_json(self):
        """Test with standard JSON format."""
        text = '{"field1": "value1", "field2": 42}'
        result = extract_json(text)
        assert result == text
        # Verify it's valid JSON
        parsed = json.loads(result)
        assert parsed["field1"] == "value1"
        assert parsed["field2"] == 42

    def test_json_in_markdown(self):
        """Test with JSON in markdown code block."""
        text = """
        Here's the JSON:
        ```json
        {"field1": "value1", "field2": 42}
        ```
        """
        result = extract_json(text)
        # Verify it's valid JSON
        parsed = json.loads(result)
        assert parsed["field1"] == "value1"
        assert parsed["field2"] == 42

    def test_json_with_single_quotes(self):
        """Test with JSON using single quotes instead of double quotes."""
        text = "{'field1': 'value1', 'field2': 42}"
        result = extract_json(text)
        # Verify it's valid JSON after conversion
        parsed = json.loads(result)
        assert parsed["field1"] == "value1"
        assert parsed["field2"] == 42

    def test_json_with_trailing_comma(self):
        """Test with JSON containing trailing commas (invalid JSON but common in LLM outputs)."""
        text = '{"field1": "value1", "field2": 42,}'
        result = extract_json(text)
        # Verify it's valid JSON after fixing
        parsed = json.loads(result)
        assert parsed["field1"] == "value1"
        assert parsed["field2"] == 42

    def test_json_in_text(self):
        """Test with JSON embedded in text."""
        text = """
        The answer is as follows:
        {"field1": "value1", "field2": 42}
        Hope this helps!
        """
        result = extract_json(text)
        # Verify it's valid JSON
        parsed = json.loads(result)
        assert parsed["field1"] == "value1"
        assert parsed["field2"] == 42


@pytest.mark.asyncio
async def test_ragas_output_parser_fallback(mocker):
    """Test that RagasOutputParser can handle malformed JSON with fallback mechanism."""
    # Create a parser
    parser = RagasOutputParser(pydantic_object=TestModel)

    # Mock the LLM to avoid actual calls
    mock_llm = mocker.MagicMock()
    mock_llm.generate = mocker.AsyncMock()

    # Test with malformed JSON that should trigger the fallback
    malformed_json = "This is not JSON at all"

    # Mock the fix_output_format_prompt.generate to return something that's still not valid
    mocker.patch(
        "ragas.prompt.pydantic_prompt.fix_output_format_prompt.generate",
        return_value=mocker.MagicMock(text="Still not valid JSON"),
    )

    # The parser should use the fallback mechanism and return a valid model
    result = await parser.parse_output_string(
        output_string=malformed_json,
        prompt_value=StringPromptValue(text="test prompt"),
        llm=mock_llm,
        callbacks=None,
    )

    # Verify we got a valid model with default values
    assert isinstance(result, TestModel)
    assert result.field1 == "Unable to parse"
    assert result.field2 == 0
61 changes: 61 additions & 0 deletions tests/unit/test_timeout_config.py
@@ -0,0 +1,61 @@
import pytest
import asyncio
from ragas.run_config import RunConfig
from ragas.metrics.base import Metric, SingleTurnMetric
from ragas.dataset_schema import SingleTurnSample
from typing import Optional, List


class TestTimeoutConfig:
    """Test the timeout configuration in RunConfig."""

    def test_default_timeout(self):
        """Test that the default timeout is set to 300 seconds."""
        config = RunConfig()
        assert config.timeout == 300, "Default timeout should be 300 seconds"

    def test_custom_timeout(self):
        """Test that a custom timeout can be set."""
        config = RunConfig(timeout=500)
        assert config.timeout == 500, "Custom timeout should be respected"


class SlowMetric(SingleTurnMetric):
    """A test metric that simulates slow processing."""

    name = "slow_metric"

    def __init__(self, sleep_time: float = 0.1):
        super().__init__()
        self.sleep_time = sleep_time

    def init(self):
        """Initialize the metric."""
        pass

    async def _single_turn_ascore(self, sample: SingleTurnSample, callbacks=None) -> float:
        """Simulate slow processing by sleeping."""
        await asyncio.sleep(self.sleep_time)
        return 1.0


@pytest.mark.asyncio
async def test_metric_timeout():
    """Test that the timeout is applied to metric scoring."""
    # Create a sample
    sample = SingleTurnSample(
        question="Test question",
        answer="Test answer",
        contexts=["Test context"],
    )

    # Create a slow metric
    slow_metric = SlowMetric(sleep_time=0.2)

    # Test with sufficient timeout
    score = await slow_metric.single_turn_ascore(sample, timeout=0.5)
    assert score == 1.0, "Metric should complete with sufficient timeout"

    # Test with insufficient timeout
    with pytest.raises(asyncio.TimeoutError):
        await slow_metric.single_turn_ascore(sample, timeout=0.1)