Update agent #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -4,6 +4,7 @@ version = "0.0.1"
description = "An agent that populates and enriches custom schemas"
authors = [
{ name = "William Fu-Hinthorn", email = "[email protected]" },
{ name = "Lance Martin", email = "[email protected]" },
]
readme = "README.md"
license = { text = "MIT" }
@@ -16,6 +17,7 @@ dependencies = [
"langchain-fireworks>=0.1.7",
"python-dotenv>=1.0.1",
"langchain-community>=0.2.13",
"tavily-python",
]

[project.optional-dependencies]
7 changes: 7 additions & 0 deletions src/enrichment_agent/configuration.py
@@ -51,6 +51,13 @@ class Configuration:
},
)

num_search_queries: int = field(
default=2,
metadata={
"description": "The number of search queries to generate for web search."
},
)

@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
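As a quick illustration of how the new field is consumed, the sketch below assumes the existing pattern where from_runnable_config reads overrides from config["configurable"]; the override value is invented.

from enrichment_agent.configuration import Configuration

# Hedged sketch: overriding num_search_queries at invocation time. Assumes
# from_runnable_config reads overrides from config["configurable"]; the
# value 4 is illustrative.
config = {"configurable": {"num_search_queries": 4}}
configuration = Configuration.from_runnable_config(config)
assert configuration.num_search_queries == 4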
6 changes: 3 additions & 3 deletions src/enrichment_agent/graph.py
@@ -15,7 +15,7 @@
from enrichment_agent import prompts
from enrichment_agent.configuration import Configuration
from enrichment_agent.state import InputState, OutputState, State
-from enrichment_agent.tools import scrape_website, search
+from enrichment_agent.tools import perform_web_research
from enrichment_agent.utils import init_model


@@ -51,7 +51,7 @@ async def call_agent_model(

# Initialize the raw model with the provided configuration and bind the tools
raw_model = init_model(config)
-    model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any")
+    model = raw_model.bind_tools([perform_web_research, info_tool], tool_choice="any")
response = cast(AIMessage, await model.ainvoke(messages))

# Initialize info to None
@@ -219,7 +219,7 @@ def route_after_checker(
)
workflow.add_node(call_agent_model)
workflow.add_node(reflect)
workflow.add_node("tools", ToolNode([search, scrape_website]))
workflow.add_node("tools", ToolNode([perform_web_research]))
workflow.add_edge("__start__", "call_agent_model")
workflow.add_conditional_edges("call_agent_model", route_after_agent)
workflow.add_edge("tools", "call_agent_model")
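One note on the wiring above: tool_choice="any" is what guarantees the agent emits at least one tool call per turn, which route_after_agent relies on. A minimal, self-contained sketch of that binding pattern (the model name and stub tool are illustrative, not from this PR):

from langchain.chat_models import init_chat_model
from langchain_core.tools import tool

@tool
def web_research_stub(query: str) -> str:
    """Stand-in for perform_web_research; the real tool also injects state/config."""
    return "notes about " + query

raw_model = init_chat_model("gpt-4o", model_provider="openai")
model = raw_model.bind_tools([web_research_stub], tool_choice="any")
# Every response now carries at least one tool call, so route_after_agent can
# rely on response.tool_calls being non-empty when routing to the "tools" node.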
10 changes: 2 additions & 8 deletions src/enrichment_agent/state.py
@@ -11,21 +11,15 @@
from langchain_core.messages import BaseMessage
from langgraph.graph import add_messages


@dataclass(kw_only=True)
class InputState:
"""Input state defines the interface between the graph and the user (external API)."""

-    topic: str
-    "The topic for which the agent is tasked to gather information."
+    companies: list[str]
+    "List of company names to research."

extraction_schema: dict[str, Any]
"The json schema defines the information the agent is tasked with filling out."

info: Optional[dict[str, Any]] = field(default=None)
"The info state tracks the current extracted data for the given topic, conforming to the provided schema. This is primarily populated by the agent."


@dataclass(kw_only=True)
class State(InputState):
"""A graph's State defines three main things.
Expand Down
137 changes: 99 additions & 38 deletions src/enrichment_agent/tools.py
@@ -5,70 +5,131 @@
Users can edit and extend these tools as needed.
"""

+import asyncio
import json
from typing import Any, Optional, cast

-import aiohttp
-from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import InjectedToolArg
from langgraph.prebuilt import InjectedState
-from typing_extensions import Annotated
+from typing_extensions import Annotated, List
+
+from pydantic import BaseModel, Field

from enrichment_agent.configuration import Configuration
from enrichment_agent.state import State
-from enrichment_agent.utils import init_model
+from enrichment_agent.utils import init_model, deduplicate_and_format_sources
+
+from tavily import AsyncTavilyClient

-async def search(
-    query: str, *, config: Annotated[RunnableConfig, InjectedToolArg]
-) -> Optional[list[dict[str, Any]]]:
-    """Query a search engine.
-
-    This function queries the web to fetch comprehensive, accurate, and trusted results. It's particularly useful
-    for answering questions about current events. Provide as much context in the query as needed to ensure high recall.
-    """
-    configuration = Configuration.from_runnable_config(config)
-    wrapped = TavilySearchResults(max_results=configuration.max_search_results)
-    result = await wrapped.ainvoke({"query": query})
-    return cast(list[dict[str, Any]], result)

+# Schemas
+class SearchQuery(BaseModel):
+    search_query: str = Field(description="Query for web search.")
+
+class Queries(BaseModel):
+    queries: List[SearchQuery] = Field(
+        description="List of search queries.",
+    )
+
+# Instructions
+query_writer_instructions = """You are a search query generator tasked with creating diverse but related search queries based on an initial query.
+
+For the initial query: {initial_query}
+
+Generate distinct search queries that:
+1. Cover different aspects or angles of the topic
+2. Are mutually exclusive to avoid redundant results
+3. Help gather comprehensive information about the subject
+4. Use different phrasings and keywords to maximize coverage
+5. Maintain relevance to the original query intent
+
+Each query should focus on a unique aspect while staying connected to the main topic."""

-_INFO_PROMPT = """You are doing web research on behalf of a user. You are trying to find out this information:
+_INFO_PROMPT = """You are doing web research on behalf of a user. You need to extract specific information based on this schema:

-<info>
+<schema>
{info}
-</info>
+</schema>

-You just scraped the following website: {url}
+You have just scraped website content. Review the content below and take detailed notes that align with the extraction schema above.

-Based on the website content below, jot down some notes about the website.
+Focus only on information that matches the schema requirements.

-<Website content>
+<Website contents>
{content}
-</Website content>"""
+</Website contents>
+
+Please provide well structured notes that:
+1. Map directly to the schema fields
+2. Include only relevant information from the content
+3. Maintain the original facts and data points
+4. Note any missing schema fields that weren't found in the content"""
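To see roughly what the extraction model receives, here is a hedged sketch of rendering this prompt; the schema and content are invented, and in the agent they come from state.extraction_schema and the formatted Tavily results.

import json

from enrichment_agent.tools import _INFO_PROMPT

# Hedged sketch: rendering _INFO_PROMPT with invented inputs.
schema = {"type": "object", "properties": {"founder": {"type": "string"}}}
prompt = _INFO_PROMPT.format(
    info=json.dumps(schema, indent=2),
    content="Sources:\n\nSource Acme - About:\n===\nURL: https://acme.test/about\n===\n...",
)
print(prompt)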

-async def scrape_website(
-    url: str,
+async def perform_web_research(
+    query: str,
    *,
    state: Annotated[State, InjectedState],
-    config: Annotated[RunnableConfig, InjectedToolArg],
+    config: Annotated[RunnableConfig, InjectedToolArg]
) -> str:
-    """Scrape and summarize content from a given URL.
-
-    Returns:
-        str: A summary of the scraped content, tailored to the extraction schema.
-    """
"""Execute a multi-step web search and information extraction process.

This function performs the following steps:
1. Generates multiple search queries based on the input query
2. Executes concurrent web searches using the Tavily API
3. Deduplicates and formats the search results
4. Extracts structured information based on the provided schema

Args:
query: The initial search query string
state: Injected application state containing the extraction schema
config: Runtime configuration for the search process

Returns:
str: Structured notes from the search results that are
relevant to the extraction schema in state.extraction_schema

Note:
The function uses concurrent execution for multiple search queries to improve
performance and combines results from various sources for comprehensive coverage.
"""
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            content = await response.text()
    configuration = Configuration.from_runnable_config(config)

# Generate search queries
raw_model = init_model(config)
structured_llm = raw_model.with_structured_output(Queries)

# Format system instructions
query_instructions = query_writer_instructions.format(initial_query=query)

    # Generate queries (use the async API so the event loop is not blocked)
    results = await structured_llm.ainvoke(
        [SystemMessage(content=query_instructions)]
        + [HumanMessage(content=f"Please generate {configuration.num_search_queries} search queries.")]
    )

# Search client
tavily_async_client = AsyncTavilyClient()

# Web search
    query_list = [q.search_query for q in results.queries]
    search_tasks = []
    for search_query in query_list:
        search_tasks.append(
            tavily_async_client.search(
                search_query,
                max_results=configuration.max_search_results,
                include_raw_content=True,
                topic="general",
            )
        )

[Review comment] The only downside is that it's harder to switch to a different search tool that doesn't return raw content, but this is fine for now.

# Execute all searches concurrently
search_docs = await asyncio.gather(*search_tasks)
[Review comment] nice!


# Deduplicate and format sources
source_str = deduplicate_and_format_sources(search_docs, max_tokens_per_source=1000, include_raw_content=True)
[Review comment] Probably not needed for v0, but it would also be great to propagate sources for each extracted datapoint.

[Review reply] Perhaps it can just be included in the extraction schema...


# Generate structured notes relevant to the extraction schema
    p = _INFO_PROMPT.format(
        info=json.dumps(state.extraction_schema, indent=2),
-        url=url,
-        content=content[:40_000],
+        content=source_str,
    )
raw_model = init_model(config)
result = await raw_model.ainvoke(p)
    return str(result.content)
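For an end-to-end feel, the sketch below calls the new tool directly; in the graph, ToolNode injects state and config automatically, so the values here are invented stand-ins, and running it needs TAVILY_API_KEY plus a model-provider key in the environment.

import asyncio

from enrichment_agent.state import State
from enrichment_agent.tools import perform_web_research

async def main() -> None:
    # Hedged sketch: invented state and config; normally injected by ToolNode.
    state = State(
        companies=["Acme Corp"],
        extraction_schema={"type": "object", "properties": {"ceo": {"type": "string"}}},
        messages=[],
    )
    config = {"configurable": {"num_search_queries": 2, "max_search_results": 3}}
    notes = await perform_web_research("Acme Corp leadership", state=state, config=config)
    print(notes)

asyncio.run(main())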
52 changes: 52 additions & 0 deletions src/enrichment_agent/utils.py
@@ -32,3 +32,55 @@ def init_model(config: Optional[RunnableConfig] = None) -> BaseChatModel:
provider = None
model = fully_specified_name
return init_chat_model(model, model_provider=provider)

def deduplicate_and_format_sources(search_response, max_tokens_per_source, include_raw_content=True):
    """Format Tavily search responses into a single deduplicated string of sources.

    Accepts either a single search response or a list of responses from the Tavily API,
    deduplicates results by URL, and truncates each raw_content to approximately
    max_tokens_per_source tokens.

    Args:
        search_response: Either:
            - A dict with a 'results' key containing a list of search results
            - A list of dicts, each containing search results
        max_tokens_per_source: Approximate token budget per source for raw content.
        include_raw_content: Whether to include the raw_content from Tavily in the
            formatted string.

    Returns:
        str: Formatted string with deduplicated sources
    """
# Convert input to list of results
if isinstance(search_response, dict):
sources_list = search_response['results']
elif isinstance(search_response, list):
sources_list = []
for response in search_response:
if isinstance(response, dict) and 'results' in response:
sources_list.extend(response['results'])
else:
sources_list.extend(response)
else:
raise ValueError("Input must be either a dict with 'results' or a list of search results")

# Deduplicate by URL
unique_sources = {}
for source in sources_list:
if source['url'] not in unique_sources:
unique_sources[source['url']] = source

# Format output
formatted_text = "Sources:\n\n"
for i, source in enumerate(unique_sources.values(), 1):
formatted_text += f"Source {source['title']}:\n===\n"
formatted_text += f"URL: {source['url']}\n===\n"
formatted_text += f"Most relevant content from source: {source['content']}\n===\n"
if include_raw_content:
# Using rough estimate of 4 characters per token
char_limit = max_tokens_per_source * 4
# Handle None raw_content
raw_content = source.get('raw_content', '')
if raw_content is None:
raw_content = ''
if len(raw_content) > char_limit:
raw_content = raw_content[:char_limit] + "... [truncated]"
formatted_text += f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"

return formatted_text.strip()
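A quick usage sketch with fabricated Tavily-style responses; note the duplicate URL across the two responses is kept only once, and the long raw_content is truncated to the rough 4-characters-per-token budget.

from enrichment_agent.utils import deduplicate_and_format_sources

# Hedged sketch: fabricated responses shaped like Tavily's output.
responses = [
    {"results": [
        {"title": "Acme - About", "url": "https://acme.test/about",
         "content": "Acme was founded in 1999.",
         "raw_content": "Acme was founded in 1999. " * 100},
    ]},
    {"results": [
        {"title": "Acme - About", "url": "https://acme.test/about",
         "content": "Duplicate hit, dropped by URL dedup.", "raw_content": None},
        {"title": "Acme - Team", "url": "https://acme.test/team",
         "content": "The CEO is Jane Doe.", "raw_content": None},
    ]},
]

print(deduplicate_and_format_sources(responses, max_tokens_per_source=100))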