diff --git a/src/cleanlab_codex/validator.py b/src/cleanlab_codex/validator.py
index 2710157..8921d8e 100644
--- a/src/cleanlab_codex/validator.py
+++ b/src/cleanlab_codex/validator.py
@@ -4,7 +4,8 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Callable, Optional, cast
+import asyncio
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, cast
 
 from cleanlab_tlm import TrustworthyRAG
 from pydantic import BaseModel, Field, field_validator
@@ -17,6 +18,7 @@ from cleanlab_codex.project import Project
 
 if TYPE_CHECKING:
+    from cleanlab_codex.types.entry import Entry
     from cleanlab_codex.types.validator import ThresholdedTrustworthyRAGScore
@@ -94,6 +96,8 @@ def validate(
         query: str,
         context: str,
         response: str,
+        *,
+        run_async: bool = False,
         prompt: Optional[str] = None,
         form_prompt: Optional[Callable[[str, str], str]] = None,
     ) -> dict[str, Any]:
@@ -104,6 +108,7 @@
             query (str): The user query that was used to generate the response.
             context (str): The context that was retrieved from the RAG Knowledge Base and used to generate the response.
             response (str): A reponse from your LLM/RAG system.
+            run_async (bool): If True, runs detect asynchronously
             prompt (str, optional): Optional prompt representing the actual inputs (combining query, context, and system instructions into one string) to the LLM that generated the response.
             form_prompt (Callable[[str, str], str], optional): Optional function to format the prompt based on query and context. Cannot be provided together with prompt, provide one or the other. This function should take query and context as parameters and return a formatted prompt string. If not provided, a default prompt formatter will be used. To include a system prompt or any other special instructions for your LLM, incorporate them directly in your custom form_prompt() function definition.
@@ -113,10 +118,32 @@
             - 'is_bad_response': True if the response is flagged as potentially bad, False otherwise. When True, a Codex lookup is performed, which logs this query into the Codex Project for SMEs to answer.
             - Additional keys from a [`ThresholdedTrustworthyRAGScore`](/codex/api/python/types.validator/#class-thresholdedtrustworthyragscore) dictionary: each corresponds to a [TrustworthyRAG](/tlm/api/python/utils.rag/#class-trustworthyrag) evaluation metric, and points to the score for this evaluation as well as a boolean `is_bad` flagging whether the score falls below the corresponding threshold.
         """
-        scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
-        expert_answer = None
-        if is_bad_response:
-            expert_answer = self._remediate(query)
+        if run_async:
+            try:
+                loop = asyncio.get_running_loop()
+            except RuntimeError:  # No running loop
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+            expert_task = loop.create_task(self.remediate_async(query))
+            detect_task = loop.run_in_executor(None, self.detect, query, context, response, prompt, form_prompt)
+            expert_answer, maybe_entry = loop.run_until_complete(expert_task)
+            scores, is_bad_response = loop.run_until_complete(detect_task)
+            loop.close()
+            if is_bad_response:
+                if expert_answer is None:
+                    # TODO: Make this async as well
+                    project_id = self._project._id  # noqa: SLF001
+                    self._project._sdk_client.projects.entries.add_question(  # noqa: SLF001
+                        project_id,
+                        question=query,
+                    ).model_dump()
+            else:
+                expert_answer = None
+        else:
+            scores, is_bad_response = self.detect(query, context, response, prompt, form_prompt)
+            expert_answer = None
+            if is_bad_response:
+                expert_answer = self._remediate(query)
 
         return {
             "expert_answer": expert_answer,
@@ -181,6 +208,10 @@ def _remediate(self, query: str) -> str | None:
         codex_answer, _ = self._project.query(question=query)
         return codex_answer
 
+    async def remediate_async(self, query: str) -> Tuple[Optional[str], Optional[Entry]]:
+        codex_answer, entry = self._project.query(question=query, read_only=True)
+        return codex_answer, entry
+
 
 class BadResponseThresholds(BaseModel):
     """Config for determining if a response is bad.