8 changes: 3 additions & 5 deletions config.yaml.example
@@ -14,16 +14,14 @@ llm:
 search:
   tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com)
   tavily_api_base_url: "https://api.tavily.com" # Tavily API URL
-  max_results: 10 # Max search results
-  max_pages: 5 # Max pages to scrape
+  max_searches: 4 # Max search operations
+  max_results: 10 # Max results in search query
   content_limit: 1500 # Content char limit per source

 # Execution Settings
 execution:
-  max_steps: 6 # Max execution steps
   max_clarifications: 3 # Max clarification requests
-  max_iterations: 10 # Max iterations per step
-  max_searches: 4 # Max search operations
+  max_iterations: 10 # Max agent iterations
   mcp_context_limit: 15000 # Max context length from MCP server response
   logs_dir: "logs" # Directory for saving agent execution logs
   reports_dir: "reports" # Directory for saving agent reports
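Reviewer note: a quick way to sanity-check the relocation (a sketch; it assumes this example file has been copied to `config.yaml` and that PyYAML is installed, since the project's real loader is not part of this diff):

```python
# Minimal check that the limit keys now resolve under their new sections.
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

assert "max_searches" in cfg["search"]         # moved into search
assert "max_searches" not in cfg["execution"]  # moved out of execution
assert "max_pages" not in cfg["search"]        # dropped entirely
print(cfg["search"]["max_searches"], cfg["execution"]["max_iterations"])
```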
10 changes: 5 additions & 5 deletions sgr_deep_research/api/endpoints.py
@@ -61,7 +61,7 @@ async def get_agents_list():

 @router.get("/v1/models")
 async def get_available_models():
-    """Get list of available agent models."""
+    """Get a list of available agent models."""
     models_data = [
         {
             "id": agent_def.name,
@@ -94,7 +94,7 @@ async def provide_clarification(agent_id: str, request: ClarificationRequest):
     await agent.provide_clarification(request.clarifications)
     return StreamingResponse(
         agent.streaming_generator.stream(),
-        media_type="text/plain",
+        media_type="text/event-stream",
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
@@ -108,8 +108,8 @@ async def provide_clarification(agent_id: str, request: ClarificationRequest):


 def _is_agent_id(model_str: str) -> bool:
-    """Check if model string is an agent ID (contains underscore and UUID-like
-    format)."""
+    """Check if the model string is an agent ID (contains underscore and UUID-
+    like format)."""
     return "_" in model_str and len(model_str) > 20


@@ -148,7 +148,7 @@ async def create_chat_completion(request: ChatCompletionRequest):
     _ = asyncio.create_task(agent.execute())
     return StreamingResponse(
         agent.streaming_generator.stream(),
-        media_type="text/plain",
+        media_type="text/event-stream",
         headers={
             "Cache-Control": "no-cache",
             "Connection": "keep-alive",
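Reviewer note: the `media_type` switch is user visible, since SSE-aware clients key off the `Content-Type` header. A consumption sketch (the host, port, and request body are assumptions; only the header change comes from this diff):

```python
# Stream a chat completion and confirm the new SSE content type.
import httpx

payload = {
    "model": "sgr_agent",
    "messages": [{"role": "user", "content": "hello"}],
    "stream": True,
}
with httpx.stream("POST", "http://localhost:8000/v1/chat/completions", json=payload, timeout=None) as resp:
    assert resp.headers["content-type"].startswith("text/event-stream")
    for line in resp.iter_lines():
        if line:  # SSE frames arrive as "data: ..." lines separated by blank lines
            print(line)
```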
6 changes: 2 additions & 4 deletions sgr_deep_research/core/agent_definition.py
@@ -27,9 +27,9 @@ class SearchConfig(BaseModel):
     tavily_api_key: str | None = Field(default=None, description="Tavily API key")
     tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL")

+    max_searches: int = Field(default=4, ge=0, description="Maximum number of searches")
     max_results: int = Field(default=10, ge=1, description="Maximum number of search results")
-    max_pages: int = Field(default=5, gt=0, description="Maximum pages to scrape")
-    content_limit: int = Field(default=1500, gt=0, description="Content character limit per source")
+    content_limit: int = Field(default=3500, gt=0, description="Content character limit per source")


 class PromptsConfig(BaseModel):
@@ -99,10 +99,8 @@ class ExecutionConfig(BaseModel, extra="allow"):
     You can add any additional fields as needed.
     """

-    max_steps: int = Field(default=6, gt=0, description="Maximum number of execution steps")
     max_clarifications: int = Field(default=3, ge=0, description="Maximum number of clarifications")
     max_iterations: int = Field(default=10, gt=0, description="Maximum number of iterations")
-    max_searches: int = Field(default=4, ge=0, description="Maximum number of searches")
     mcp_context_limit: int = Field(default=15000, gt=0, description="Maximum context length from MCP server response")

     logs_dir: str = Field(default="logs", description="Directory for saving bot logs")
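Reviewer note: `AgentConfig` itself is not part of this diff, so the following composition is inferred from the access paths the agents now use (`config.llm.*`, `config.search.*`, `config.execution.*`); treat it as a sketch, not the real definition:

```python
# Inferred shape of the consolidated config object. SearchConfig and
# ExecutionConfig defaults follow the changes above; LLMConfig values are
# placeholders.
from pydantic import BaseModel, Field


class LLMConfig(BaseModel):
    model: str = "gpt-4o-mini"  # placeholder
    max_tokens: int = 8000      # placeholder
    temperature: float = 0.0    # placeholder


class SearchConfig(BaseModel):
    max_searches: int = Field(default=4, ge=0)
    max_results: int = Field(default=10, ge=1)
    content_limit: int = Field(default=3500, gt=0)


class ExecutionConfig(BaseModel, extra="allow"):
    max_clarifications: int = Field(default=3, ge=0)
    max_iterations: int = Field(default=10, gt=0)
    mcp_context_limit: int = Field(default=15000, gt=0)


class AgentConfig(BaseModel):
    llm: LLMConfig = LLMConfig()
    search: SearchConfig = SearchConfig()
    execution: ExecutionConfig = ExecutionConfig()


cfg = AgentConfig()
assert cfg.search.max_searches == 4  # the gate the agents below now check
```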
7 changes: 3 additions & 4 deletions sgr_deep_research/core/agent_factory.py
@@ -67,7 +67,7 @@ async def create(cls, agent_def: AgentDefinition, task: str) -> Agent:
             )
             logger.error(error_msg)
             raise ValueError(error_msg)
-        mcp_tools = await MCP2ToolConverter.build_tools_from_mcp(agent_def.mcp)
+        mcp_tools: list = await MCP2ToolConverter.build_tools_from_mcp(agent_def.mcp)

         tools = [*mcp_tools]
         for tool in agent_def.tools:
@@ -89,11 +89,10 @@ async def create(cls, agent_def: AgentDefinition, task: str) -> Agent:
         try:
             agent = BaseClass(
                 task=task,
+                def_name=agent_def.name,
                 toolkit=tools,
                 openai_client=cls._create_client(agent_def.llm),
-                llm_config=agent_def.llm,
-                execution_config=agent_def.execution,
-                prompts_config=agent_def.prompts,
+                agent_config=agent_def,
             )
             logger.info(
                 f"Created agent '{agent_def.name}' "
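Reviewer note: the constructor contract change, distilled (simplified stand-in classes, not the project's real ones; the point is that one `agent_config` object plus `def_name` replaces the three per-section kwargs):

```python
# The factory now passes the whole definition once; agents expose it as
# self.config instead of holding llm/prompts/execution separately.
class BaseAgentSketch:
    def __init__(self, task, agent_config, toolkit, def_name=None, **kwargs):
        self.task = task
        self.config = agent_config  # read as self.config.llm / .search / .execution
        self.toolkit = toolkit
        self.def_name = def_name


class FakeDefinition:
    name = "sgr_agent"


defn = FakeDefinition()
agent = BaseAgentSketch(task="demo", agent_config=defn, toolkit=[], def_name=defn.name)
assert agent.config is defn
```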
36 changes: 19 additions & 17 deletions sgr_deep_research/core/agents/sgr_agent.py
@@ -2,7 +2,7 @@

 from openai import AsyncOpenAI

-from sgr_deep_research.core.agent_definition import ExecutionConfig, LLMConfig, PromptsConfig
+from sgr_deep_research.core.agent_definition import AgentConfig
 from sgr_deep_research.core.base_agent import BaseAgent
 from sgr_deep_research.core.tools import (
     BaseTool,
@@ -16,61 +16,63 @@


 class SGRAgent(BaseAgent):
-    """Agent for deep research tasks using SGR framework."""
+    """Agent for deep research tasks using an SGR framework."""

     name: str = "sgr_agent"

     def __init__(
         self,
         task: str,
         openai_client: AsyncOpenAI,
-        llm_config: LLMConfig,
-        prompts_config: PromptsConfig,
-        execution_config: ExecutionConfig,
-        toolkit: list[Type[BaseTool]] | None = None,
+        agent_config: AgentConfig,
+        toolkit: list[Type[BaseTool]],
+        def_name: str | None = None,
+        **kwargs: dict,
     ):
         super().__init__(
             task=task,
             openai_client=openai_client,
-            llm_config=llm_config,
-            prompts_config=prompts_config,
-            execution_config=execution_config,
+            agent_config=agent_config,
             toolkit=toolkit,
+            def_name=def_name,
+            **kwargs,
         )
-        self.max_searches = execution_config.max_searches

     async def _prepare_tools(self) -> Type[NextStepToolStub]:
         """Prepare tool classes with current context limits."""
         tools = set(self.toolkit)
-        if self._context.iteration >= self.max_iterations:
+        if self._context.iteration >= self.config.execution.max_iterations:
             tools = {
                 CreateReportTool,
                 FinalAnswerTool,
             }
-        if self._context.clarifications_used >= self.max_clarifications:
+        if self._context.clarifications_used >= self.config.execution.max_clarifications:
             tools -= {
                 ClarificationTool,
             }
-        if self._context.searches_used >= self.max_searches:
+        if self._context.searches_used >= self.config.search.max_searches:
             tools -= {
                 WebSearchTool,
             }
         return NextStepToolsBuilder.build_NextStepTools(list(tools))

     async def _reasoning_phase(self) -> NextStepToolStub:
         async with self.openai_client.chat.completions.stream(
-            model=self.llm_config.model,
+            model=self.config.llm.model,
             response_format=await self._prepare_tools(),
             messages=await self._prepare_context(),
-            max_tokens=self.llm_config.max_tokens,
-            temperature=self.llm_config.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            temperature=self.config.llm.temperature,
         ) as stream:
             async for event in stream:
                 if event.type == "chunk":
                     self.streaming_generator.add_chunk(event.chunk)
         reasoning: NextStepToolStub = (await stream.get_final_completion()).choices[0].message.parsed  # type: ignore
         # we are not fully sure if it should be in conversation or not. Looks like not necessary data
         # self.conversation.append({"role": "assistant", "content": reasoning.model_dump_json(exclude={"function"})})
+        self.streaming_generator.add_tool_call(
+            f"{self._context.iteration}-reasoning", reasoning.tool_name, reasoning.model_dump_json(exclude={"function"})
+        )
         self._log_reasoning(reasoning)
         return reasoning

@@ -100,7 +102,7 @@ async def _select_action_phase(self, reasoning: NextStepToolStub) -> BaseTool:
         return tool

     async def _action_phase(self, tool: BaseTool) -> str:
-        result = await tool(self._context)
+        result = await tool(self._context, self.config)
         self.conversation.append(
             {"role": "tool", "content": result, "tool_call_id": f"{self._context.iteration}-action"}
         )
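Reviewer note: the `_action_phase` change at the bottom of this file widens the tool-call signature; tools now receive the agent config alongside the context. A conforming tool, sketched from the call site `await tool(self._context, self.config)` (the real `BaseTool` interface may carry more):

```python
# Hypothetical tool honoring the new two-argument __call__.
class TruncatingEchoTool:
    tool_name = "truncating_echo"

    def __init__(self, text: str):
        self.text = text

    async def __call__(self, context, config) -> str:
        limit = config.search.content_limit  # per-agent limit, now reachable inside tools
        return self.text[:limit]
```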
35 changes: 19 additions & 16 deletions sgr_deep_research/core/agents/sgr_so_tool_calling_agent.py
@@ -3,34 +3,34 @@

 from openai import AsyncOpenAI

-from sgr_deep_research.core.agent_definition import ExecutionConfig, LLMConfig, PromptsConfig
+from sgr_deep_research.core.agent_definition import AgentConfig
 from sgr_deep_research.core.agents.sgr_tool_calling_agent import SGRToolCallingAgent
 from sgr_deep_research.core.base_tool import BaseTool
 from sgr_deep_research.core.tools import ReasoningTool


 class SGRSOToolCallingAgent(SGRToolCallingAgent):
     """Agent that uses OpenAI native function calling to select and execute
-    tools based on SGR like reasoning scheme."""
+    tools based on SGR like a reasoning scheme."""

     name: str = "sgr_so_tool_calling_agent"

     def __init__(
         self,
         task: str,
         openai_client: AsyncOpenAI,
-        llm_config: LLMConfig,
-        prompts_config: PromptsConfig,
-        execution_config: ExecutionConfig,
-        toolkit: list[Type[BaseTool]] | None = None,
+        agent_config: AgentConfig,
+        toolkit: list[Type[BaseTool]],
+        def_name: str | None = None,
+        **kwargs: dict,
     ):
         super().__init__(
             task=task,
             openai_client=openai_client,
-            llm_config=llm_config,
-            prompts_config=prompts_config,
-            execution_config=execution_config,
+            agent_config=agent_config,
             toolkit=toolkit,
+            def_name=def_name,
+            **kwargs,
         )
         warn(
             "SGRSOToolCallingAgent is deprecated and will be removed in the future. "
@@ -40,10 +40,10 @@ def __init__(

     async def _reasoning_phase(self) -> ReasoningTool:
         async with self.openai_client.chat.completions.stream(
-            model=self.llm_config.model,
+            model=self.config.llm.model,
             messages=await self._prepare_context(),
-            max_tokens=self.llm_config.max_tokens,
-            temperature=self.llm_config.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            temperature=self.config.llm.temperature,
             tools=await self._prepare_tools(),
             tool_choice={"type": "function", "function": {"name": ReasoningTool.tool_name}},
         ) as stream:
@@ -54,11 +54,11 @@ async def _reasoning_phase(self) -> ReasoningTool:
             (await stream.get_final_completion()).choices[0].message.tool_calls[0].function.parsed_arguments  #
         )
         async with self.openai_client.chat.completions.stream(
-            model=self.llm_config.model,
+            model=self.config.llm.model,
             response_format=ReasoningTool,
             messages=await self._prepare_context(),
-            max_tokens=self.llm_config.max_tokens,
-            temperature=self.llm_config.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            temperature=self.config.llm.temperature,
         ) as stream:
             async for event in stream:
                 if event.type == "chunk":
@@ -75,12 +75,15 @@ async def _reasoning_phase(self) -> ReasoningTool:
                         "id": f"{self._context.iteration}-reasoning",
                         "function": {
                             "name": reasoning.tool_name,
-                            "arguments": "{}",
+                            "arguments": reasoning.model_dump_json(),
                         },
                     }
                 ],
             }
         )
+        self.streaming_generator.add_tool_call(
+            f"{self._context.iteration}-reasoning", reasoning.tool_name, reasoning.model_dump_json()
+        )
         self.conversation.append(
             {"role": "tool", "content": tool_call_result, "tool_call_id": f"{self._context.iteration}-reasoning"}
         )
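Reviewer note: the `arguments` fix matters for conversation replay. OpenAI-style histories expect `function.arguments` to hold the JSON the model produced, and the old hardcoded `"{}"` dropped the reasoning payload from every later turn. Shape of the repaired assistant entry (values invented for illustration):

```python
# What the conversation entry now carries. All values are invented; the
# name comes from ReasoningTool.tool_name at runtime.
assistant_entry = {
    "role": "assistant",
    "tool_calls": [
        {
            "type": "function",
            "id": "2-reasoning",
            "function": {
                "name": "reasoningtool",  # placeholder for ReasoningTool.tool_name
                # previously the hardcoded "{}" erased this payload:
                "arguments": '{"current_situation": "...", "plan_status": "..."}',
            },
        }
    ],
}
```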
42 changes: 23 additions & 19 deletions sgr_deep_research/core/agents/sgr_tool_calling_agent.py
@@ -3,7 +3,7 @@

 from openai import AsyncOpenAI, pydantic_function_tool
 from openai.types.chat import ChatCompletionFunctionToolParam

-from sgr_deep_research.core.agent_definition import ExecutionConfig, LLMConfig, PromptsConfig
+from sgr_deep_research.core.agent_definition import AgentConfig
 from sgr_deep_research.core.agents.sgr_agent import SGRAgent
 from sgr_deep_research.core.models import AgentStatesEnum
 from sgr_deep_research.core.tools import (
@@ -26,47 +26,47 @@ def __init__(
         self,
         task: str,
         openai_client: AsyncOpenAI,
-        llm_config: LLMConfig,
-        prompts_config: PromptsConfig,
-        execution_config: ExecutionConfig,
-        toolkit: list[Type[BaseTool]] | None = None,
+        agent_config: AgentConfig,
+        toolkit: list[Type[BaseTool]],
+        def_name: str | None = None,
+        **kwargs: dict,
     ):
         super().__init__(
             task=task,
             openai_client=openai_client,
-            llm_config=llm_config,
-            prompts_config=prompts_config,
-            execution_config=execution_config,
+            agent_config=agent_config,
             toolkit=toolkit,
+            def_name=def_name,
+            **kwargs,
         )
         self.toolkit.append(ReasoningTool)
         self.tool_choice: Literal["required"] = "required"

     async def _prepare_tools(self) -> list[ChatCompletionFunctionToolParam]:
-        """Prepare available tools for current agent state and progress."""
+        """Prepare available tools for the current agent state and progress."""
         tools = set(self.toolkit)
-        if self._context.iteration >= self.max_iterations:
+        if self._context.iteration >= self.config.execution.max_iterations:
             tools = {
                 ReasoningTool,
                 CreateReportTool,
                 FinalAnswerTool,
             }
-        if self._context.clarifications_used >= self.max_clarifications:
+        if self._context.clarifications_used >= self.config.execution.max_clarifications:
             tools -= {
                 ClarificationTool,
             }
-        if self._context.searches_used >= self.max_searches:
+        if self._context.searches_used >= self.config.search.max_searches:
             tools -= {
                 WebSearchTool,
             }
         return [pydantic_function_tool(tool, name=tool.tool_name, description="") for tool in tools]

     async def _reasoning_phase(self) -> ReasoningTool:
         async with self.openai_client.chat.completions.stream(
-            model=self.llm_config.model,
+            model=self.config.llm.model,
             messages=await self._prepare_context(),
-            max_tokens=self.llm_config.max_tokens,
-            temperature=self.llm_config.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            temperature=self.config.llm.temperature,
             tools=await self._prepare_tools(),
             tool_choice={"type": "function", "function": {"name": ReasoningTool.tool_name}},
         ) as stream:
@@ -93,6 +93,9 @@ async def _reasoning_phase(self) -> ReasoningTool:
             }
         )
         tool_call_result = await reasoning(self._context)
+        self.streaming_generator.add_tool_call(
+            f"{self._context.iteration}-reasoning", reasoning.tool_name, tool_call_result
+        )
         self.conversation.append(
             {"role": "tool", "content": tool_call_result, "tool_call_id": f"{self._context.iteration}-reasoning"}
         )
@@ -101,10 +104,10 @@ async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool:

     async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool:
         async with self.openai_client.chat.completions.stream(
-            model=self.llm_config.model,
+            model=self.config.llm.model,
             messages=await self._prepare_context(),
-            max_tokens=self.llm_config.max_tokens,
-            temperature=self.llm_config.temperature,
+            max_tokens=self.config.llm.max_tokens,
+            temperature=self.config.llm.temperature,
             tools=await self._prepare_tools(),
             tool_choice=self.tool_choice,
         ) as stream:
@@ -121,7 +124,8 @@ async def _select_action_phase(self, reasoning: ReasoningTool) -> BaseTool:
             final_content = completion.choices[0].message.content or "Task completed successfully"
             tool = FinalAnswerTool(
                 reasoning="Agent decided to complete the task",
-                completed_steps=[final_content],
+                completed_steps=[],
+                answer=final_content,
                 status=AgentStatesEnum.COMPLETED,
             )
         if not isinstance(tool, BaseTool):
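Reviewer note: the fallback branch no longer smuggles the model's closing text through `completed_steps`; it gets an explicit `answer` field. A pydantic sketch of what the call site implies (the real `FinalAnswerTool` lives in `sgr_deep_research.core.tools` and is not shown in this diff):

```python
# Assumed field layout, inferred from the FinalAnswerTool(...) call above.
from enum import Enum

from pydantic import BaseModel


class AgentStatesEnum(str, Enum):  # stand-in for sgr_deep_research.core.models
    COMPLETED = "completed"


class FinalAnswerSketch(BaseModel):
    reasoning: str
    completed_steps: list[str]  # step summaries only, no longer doubling as the answer
    answer: str                 # user-facing result, now first-class
    status: AgentStatesEnum


tool = FinalAnswerSketch(
    reasoning="Agent decided to complete the task",
    completed_steps=[],
    answer="Task completed successfully",
    status=AgentStatesEnum.COMPLETED,
)
print(tool.answer)
```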