diff --git a/agent_baselines/solvers/sqa/formatted_perplexity.py b/agent_baselines/solvers/sqa/formatted_perplexity.py index e82e9e4..d6cc075 100644 --- a/agent_baselines/solvers/sqa/formatted_perplexity.py +++ b/agent_baselines/solvers/sqa/formatted_perplexity.py @@ -1,6 +1,7 @@ import logging from inspect_ai.solver import Solver, chain, solver +from inspect_ai.model import ChatMessageAssistant from agent_baselines.solvers.sqa.format_solver import format_solver from agent_baselines.solvers.sqa.perplexity_base import perplexity_solver @@ -8,10 +9,35 @@ logger = logging.getLogger(__name__) +@solver +def add_perplexity_references() -> Solver: + + async def solve(state, generate): + m = state.messages[-1] + if not isinstance(m, ChatMessageAssistant): + raise ValueError("The last message must be from the assistant.") + + citations = m.content[-1].citations + + references_str = '\n\n## References\n\nEach reference below is in the format "[citation ID] {title} ({url})". Use `title` as the `excerpt`' + for idx, citation in enumerate(citations): + title, url = citation.title, citation.url + references_str += f"\n[{idx + 1}] {title} ({url})" + + state.messages.append(ChatMessageAssistant(content=m.text + references_str)) + return state + + return solve + + @solver def formatted_solver( system_prompt: str | None = None, search_context_size: str | None = None, + reasoning_effort: str = "", + search_mode: str = "", + require_snippets: bool = True, + scorer_model: str = "google/gemini-2.5-flash-preview-05-20", ) -> Solver: chainlist = [ perplexity_solver( @@ -19,7 +45,10 @@ def formatted_solver( prompt_template="{prompt_without_formatting_instructions}", system_message=system_prompt, search_context_size=search_context_size, + reasoning_effort=reasoning_effort, + search_mode=search_mode, ), - format_solver("google/gemini-2.5-flash-preview-05-20"), + add_perplexity_references(), + format_solver(scorer_model, require_snippets=require_snippets), ] return chain(chainlist) diff --git a/agent_baselines/solvers/sqa/perplexity_base.py b/agent_baselines/solvers/sqa/perplexity_base.py index 7d53cfd..ee10bdb 100644 --- a/agent_baselines/solvers/sqa/perplexity_base.py +++ b/agent_baselines/solvers/sqa/perplexity_base.py @@ -51,9 +51,10 @@ def perplexity_solver( system_message: str | None = None, use_structured_decoding: bool = False, search_context_size: str | None = None, - search_mode: str = "academic", + search_mode: str = "", # Date format can be flexible (e.g., '3/1/2025', 'March 1, 2025'). search_before_date_filter: str | None = None, + reasoning_effort: str = "high", ) -> Solver: # Verify that we have a PerplexityAPI model model = get_model() @@ -82,7 +83,10 @@ def perplexity_solver( "search_context_size": search_context_size } - extra_body["search_mode"] = search_mode + if search_mode: + extra_body["search_mode"] = search_mode + if reasoning_effort: + extra_body["reasoning_effort"] = reasoning_effort if search_before_date_filter is not None: extra_body["search_before_date_filter"] = search_before_date_filter