diff --git a/pyproject.toml b/pyproject.toml
index 526ec91..cfbb2d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,7 @@ version = "0.0.1"
description = "An agent that populates and enriches custom schemas"
authors = [
{ name = "William Fu-Hinthorn", email = "13333726+hinthornw@users.noreply.github.com" },
+ { name = "Lance Martin", email = "lance@langchain.dev" },
]
readme = "README.md"
license = { text = "MIT" }
@@ -16,6 +17,8 @@ dependencies = [
"langchain-fireworks>=0.1.7",
"python-dotenv>=1.0.1",
"langchain-community>=0.2.13",
+ "tavily-python",
]
[project.optional-dependencies]
diff --git a/src/enrichment_agent/configuration.py b/src/enrichment_agent/configuration.py
index 2780be6..158b102 100644
--- a/src/enrichment_agent/configuration.py
+++ b/src/enrichment_agent/configuration.py
@@ -51,6 +51,16 @@
        },
    )

+    num_search_queries: int = field(
+        default=2,
+        metadata={
+            "description": "The number of search queries to generate for web search."
+        },
+    )
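+    # Note: with the defaults, each perform_web_research call issues two Tavily
+    # searches, each returning up to `max_search_results` results (a separate
+    # Configuration field) before deduplication.
+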
@classmethod
def from_runnable_config(
cls, config: Optional[RunnableConfig] = None
diff --git a/src/enrichment_agent/graph.py b/src/enrichment_agent/graph.py
index 64ef445..47d8312 100644
--- a/src/enrichment_agent/graph.py
+++ b/src/enrichment_agent/graph.py
@@ -15,7 +15,7 @@
from enrichment_agent import prompts
from enrichment_agent.configuration import Configuration
from enrichment_agent.state import InputState, OutputState, State
-from enrichment_agent.tools import scrape_website, search
+from enrichment_agent.tools import perform_web_research
from enrichment_agent.utils import init_model
@@ -51,7 +51,9 @@ async def call_agent_model(
# Initialize the raw model with the provided configuration and bind the tools
raw_model = init_model(config)
- model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any")
+ model = raw_model.bind_tools([perform_web_research, info_tool], tool_choice="any")
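+    # tool_choice="any" requires the model to call at least one of the bound tools
+    # (perform_web_research or info_tool) on every step.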
response = cast(AIMessage, await model.ainvoke(messages))
# Initialize info to None
@@ -219,7 +219,9 @@ def route_after_checker(
)
workflow.add_node(call_agent_model)
workflow.add_node(reflect)
-workflow.add_node("tools", ToolNode([search, scrape_website]))
+workflow.add_node("tools", ToolNode([perform_web_research]))
workflow.add_edge("__start__", "call_agent_model")
workflow.add_conditional_edges("call_agent_model", route_after_agent)
workflow.add_edge("tools", "call_agent_model")
diff --git a/src/enrichment_agent/state.py b/src/enrichment_agent/state.py
index ada0b8f..8f010fa 100644
--- a/src/enrichment_agent/state.py
+++ b/src/enrichment_agent/state.py
@@ -11,21 +11,18 @@
from langchain_core.messages import BaseMessage
from langgraph.graph import add_messages


@dataclass(kw_only=True)
class InputState:
"""Input state defines the interface between the graph and the user (external API)."""
-
- topic: str
- "The topic for which the agent is tasked to gather information."
+ companies: list[str]
+ "List of company names to research."
extraction_schema: dict[str, Any]
"The json schema defines the information the agent is tasked with filling out."
- info: Optional[dict[str, Any]] = field(default=None)
- "The info state tracks the current extracted data for the given topic, conforming to the provided schema. This is primarily populated by the agent."
-
-
@dataclass(kw_only=True)
class State(InputState):
"""A graph's State defines three main things.
diff --git a/src/enrichment_agent/tools.py b/src/enrichment_agent/tools.py
index f756e31..5a4c033 100644
--- a/src/enrichment_agent/tools.py
+++ b/src/enrichment_agent/tools.py
@@ -5,70 +5,131 @@
Users can edit and extend these tools as needed.
"""
+import asyncio
import json
-from typing import Any, Optional, cast
-import aiohttp
-from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import InjectedToolArg
from langgraph.prebuilt import InjectedState
from typing_extensions import Annotated
+
+from pydantic import BaseModel, Field
+from tavily import AsyncTavilyClient

from enrichment_agent.configuration import Configuration
from enrichment_agent.state import State
-from enrichment_agent.utils import init_model
+from enrichment_agent.utils import deduplicate_and_format_sources, init_model
-async def search(
-    query: str, *, config: Annotated[RunnableConfig, InjectedToolArg]
-) -> Optional[list[dict[str, Any]]]:
-    """Query a search engine.
-
-    This function queries the web to fetch comprehensive, accurate, and trusted results. It's particularly useful
-    for answering questions about current events. Provide as much context in the query as needed to ensure high recall.
-    """
-    configuration = Configuration.from_runnable_config(config)
-    wrapped = TavilySearchResults(max_results=configuration.max_search_results)
-    result = await wrapped.ainvoke({"query": query})
-    return cast(list[dict[str, Any]], result)
+# Schemas
+class SearchQuery(BaseModel):
+    search_query: str = Field(description="Query for web search.")
+
+
+class Queries(BaseModel):
+    queries: list[SearchQuery] = Field(
+        description="List of search queries.",
+    )
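+
+# Example of the structured output requested from the model (illustrative values):
+#   Queries(queries=[SearchQuery(search_query="Acme Corp founding year"),
+#                    SearchQuery(search_query="Acme Corp headquarters location")])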
+
+# Instructions
+query_writer_instructions = """You are a search query generator tasked with creating diverse but related search queries based on an initial query.
+
+For the initial query: {initial_query}
+Generate distinct search queries that:
+1. Cover different aspects or angles of the topic
+2. Are mutually exclusive to avoid redundant results
+3. Help gather comprehensive information about the subject
+4. Use different phrasings and keywords to maximize coverage
+5. Maintain relevance to the original query intent
-_INFO_PROMPT = """You are doing web research on behalf of a user. You are trying to find out this information:
+Each query should focus on a unique aspect while staying connected to the main topic."""
-
+_INFO_PROMPT = """You are doing web research on behalf of a user. You need to extract specific information based on this schema:
+
+
{info}
-
+
-You just scraped the following website: {url}
+You have just scraped website content. Review the content below and take detailed notes that align with the extraction schema above.
-Based on the website content below, jot down some notes about the website.
+Focus only on information that matches the schema requirements.
-
+
{content}
-"""
+
+Please provide well structured notes that:
+1. Map directly to the schema fields
+2. Include only relevant information from the content
+3. Maintain the original facts and data points
+4. Note any missing schema fields that weren't found in the content"""
-async def scrape_website(
-    url: str,
-    *,
+async def perform_web_research(
+    query: str,
+    *,
    state: Annotated[State, InjectedState],
    config: Annotated[RunnableConfig, InjectedToolArg],
) -> str:
- """Scrape and summarize content from a given URL.
-
- Returns:
- str: A summary of the scraped content, tailored to the extraction schema.
+ """Execute a multi-step web search and information extraction process.
+
+ This function performs the following steps:
+ 1. Generates multiple search queries based on the input query
+ 2. Executes concurrent web searches using the Tavily API
+ 3. Deduplicates and formats the search results
+ 4. Extracts structured information based on the provided schema
+
+ Args:
+ query: The initial search query string
+ state: Injected application state containing the extraction schema
+ config: Runtime configuration for the search process
+
+ Returns:
+ str: Structured notes from the search results that are
+ relevant to the extraction schema in state.extraction_schema
+
+ Note:
+ The function uses concurrent execution for multiple search queries to improve
+ performance and combines results from various sources for comprehensive coverage.
"""
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            content = await response.text()
+    configuration = Configuration.from_runnable_config(config)
+
+    # Generate diverse search queries with a structured-output model
+    raw_model = init_model(config)
+    structured_llm = raw_model.with_structured_output(Queries)
+    query_instructions = query_writer_instructions.format(initial_query=query)
+    results = await structured_llm.ainvoke(
+        [
+            SystemMessage(content=query_instructions),
+            HumanMessage(
+                content=f"Please generate {configuration.num_search_queries} search queries."
+            ),
+        ]
+    )
+
+    # Execute all Tavily searches concurrently
+    tavily_async_client = AsyncTavilyClient()
+    search_tasks = [
+        tavily_async_client.search(
+            generated.search_query,
+            max_results=configuration.max_search_results,
+            include_raw_content=True,
+            topic="general",
+        )
+        for generated in results.queries
+    ]
+    search_docs = await asyncio.gather(*search_tasks)
+
+    # Deduplicate sources and merge them into a single formatted string
+    source_str = deduplicate_and_format_sources(
+        search_docs, max_tokens_per_source=1000, include_raw_content=True
+    )
+
+    # Generate structured notes relevant to the extraction schema
    p = _INFO_PROMPT.format(
        info=json.dumps(state.extraction_schema, indent=2),
-        url=url,
-        content=content[:40_000],
+        content=source_str,
    )
-    raw_model = init_model(config)
    result = await raw_model.ainvoke(p)
    return str(result.content)
diff --git a/src/enrichment_agent/utils.py b/src/enrichment_agent/utils.py
index 5dba58f..f7df2b5 100644
--- a/src/enrichment_agent/utils.py
+++ b/src/enrichment_agent/utils.py
@@ -32,3 +32,66 @@ def init_model(config: Optional[RunnableConfig] = None) -> BaseChatModel:
provider = None
model = fully_specified_name
return init_chat_model(model, model_provider=provider)
+
+
+def deduplicate_and_format_sources(
+    search_response, max_tokens_per_source, include_raw_content=True
+):
+    """Deduplicate Tavily search results by URL and format them into a single string.
+
+    Each source's raw_content is limited to approximately max_tokens_per_source tokens.
+
+    Args:
+        search_response: Either a dict with a 'results' key containing a list of
+            search results, or a list of such dicts from multiple searches.
+        max_tokens_per_source: Approximate per-source token budget for raw content.
+        include_raw_content: Whether to include the raw_content from Tavily in the
+            formatted string.
+
+    Returns:
+        str: Formatted string with deduplicated sources.
+    """
+    # Convert input to a list of results
+    if isinstance(search_response, dict):
+        sources_list = search_response['results']
+    elif isinstance(search_response, list):
+        sources_list = []
+        for response in search_response:
+            if isinstance(response, dict) and 'results' in response:
+                sources_list.extend(response['results'])
+            else:
+                sources_list.extend(response)
+    else:
+        raise ValueError(
+            "Input must be either a dict with 'results' or a list of search results"
+        )
+
+    # Deduplicate by URL
+    unique_sources = {}
+    for source in sources_list:
+        if source['url'] not in unique_sources:
+            unique_sources[source['url']] = source
+
+    # Format output
+    formatted_text = "Sources:\n\n"
+    for source in unique_sources.values():
+        formatted_text += f"Source {source['title']}:\n===\n"
+        formatted_text += f"URL: {source['url']}\n===\n"
+        formatted_text += f"Most relevant content from source: {source['content']}\n===\n"
+        if include_raw_content:
+            # Using a rough estimate of 4 characters per token
+            char_limit = max_tokens_per_source * 4
+            # Handle None raw_content
+            raw_content = source.get('raw_content', '')
+            if raw_content is None:
+                raw_content = ''
+            if len(raw_content) > char_limit:
+                raw_content = raw_content[:char_limit] + "... [truncated]"
+            formatted_text += f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"
+
+    return formatted_text.strip()