diff --git a/pyproject.toml b/pyproject.toml
index 526ec91..cfbb2d5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,7 @@ version = "0.0.1"
 description = "An agent that populates and enriches custom schemas"
 authors = [
     { name = "William Fu-Hinthorn", email = "13333726+hinthornw@users.noreply.github.com" },
+    { name = "Lance Martin", email = "lance@langchain.dev" },
 ]
 readme = "README.md"
 license = { text = "MIT" }
@@ -16,6 +17,7 @@ dependencies = [
     "langchain-fireworks>=0.1.7",
     "python-dotenv>=1.0.1",
     "langchain-community>=0.2.13",
+    "tavily-python",
 ]

 [project.optional-dependencies]
diff --git a/src/enrichment_agent/configuration.py b/src/enrichment_agent/configuration.py
index 2780be6..158b102 100644
--- a/src/enrichment_agent/configuration.py
+++ b/src/enrichment_agent/configuration.py
@@ -51,6 +51,13 @@ class Configuration:
         },
     )

+    num_search_queries: int = field(
+        default=2,
+        metadata={
+            "description": "The number of search queries to generate for web search."
+        },
+    )
+
     @classmethod
     def from_runnable_config(
         cls, config: Optional[RunnableConfig] = None
diff --git a/src/enrichment_agent/graph.py b/src/enrichment_agent/graph.py
index 64ef445..47d8312 100644
--- a/src/enrichment_agent/graph.py
+++ b/src/enrichment_agent/graph.py
@@ -15,7 +15,7 @@
 from enrichment_agent import prompts
 from enrichment_agent.configuration import Configuration
 from enrichment_agent.state import InputState, OutputState, State
-from enrichment_agent.tools import scrape_website, search
+from enrichment_agent.tools import perform_web_research
 from enrichment_agent.utils import init_model
@@ -51,7 +51,7 @@ async def call_agent_model(
     # Initialize the raw model with the provided configuration and bind the tools
     raw_model = init_model(config)
-    model = raw_model.bind_tools([scrape_website, search, info_tool], tool_choice="any")
+    model = raw_model.bind_tools([perform_web_research, info_tool], tool_choice="any")
     response = cast(AIMessage, await model.ainvoke(messages))

     # Initialize info to None
@@ -219,7 +219,7 @@ def route_after_checker(
 )
 workflow.add_node(call_agent_model)
 workflow.add_node(reflect)
-workflow.add_node("tools", ToolNode([search, scrape_website]))
+workflow.add_node("tools", ToolNode([perform_web_research]))
 workflow.add_edge("__start__", "call_agent_model")
 workflow.add_conditional_edges("call_agent_model", route_after_agent)
 workflow.add_edge("tools", "call_agent_model")
diff --git a/src/enrichment_agent/state.py b/src/enrichment_agent/state.py
index ada0b8f..8f010fa 100644
--- a/src/enrichment_agent/state.py
+++ b/src/enrichment_agent/state.py
@@ -11,21 +11,15 @@ from langchain_core.messages import BaseMessage
 from langgraph.graph import add_messages

-
 @dataclass(kw_only=True)
 class InputState:
     """Input state defines the interface between the graph and the user (external API)."""
-
-    topic: str
-    "The topic for which the agent is tasked to gather information."
+    companies: list[str]
+    "List of company names to research."

     extraction_schema: dict[str, Any]
     "The json schema defines the information the agent is tasked with filling out."

-    info: Optional[dict[str, Any]] = field(default=None)
-    "The info state tracks the current extracted data for the given topic, conforming to the provided schema. This is primarily populated by the agent."
-
-
 @dataclass(kw_only=True)
 class State(InputState):
     """A graph's State defines three main things.
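With `InputState` now keyed on `companies` rather than `topic`, callers pass a list of company names plus the extraction schema. A minimal invocation sketch, assuming the workflow is compiled and exported as `graph` (as in the upstream template) and that model and Tavily API keys are set in the environment; the schema below is purely illustrative:

```python
import asyncio

from enrichment_agent.graph import graph  # assumes `graph = workflow.compile()` is exported

# Hypothetical extraction schema; any JSON schema works here.
schema = {
    "type": "object",
    "properties": {
        "founding_year": {"type": "integer"},
        "ceo": {"type": "string"},
    },
    "required": ["founding_year", "ceo"],
}

result = asyncio.run(
    graph.ainvoke({"companies": ["LangChain"], "extraction_schema": schema})
)
print(result["info"])
```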
diff --git a/src/enrichment_agent/tools.py b/src/enrichment_agent/tools.py
index f756e31..5a4c033 100644
--- a/src/enrichment_agent/tools.py
+++ b/src/enrichment_agent/tools.py
@@ -5,70 +5,131 @@ Users can edit and extend these tools as needed.
 """

+import asyncio
 import json
-from typing import Any, Optional, cast

-import aiohttp
-from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.runnables import RunnableConfig
 from langchain_core.tools import InjectedToolArg
 from langgraph.prebuilt import InjectedState
-from typing_extensions import Annotated
+from typing_extensions import Annotated, List
+
+from pydantic import BaseModel, Field
+from tavily import AsyncTavilyClient

 from enrichment_agent.configuration import Configuration
 from enrichment_agent.state import State
-from enrichment_agent.utils import init_model
+from enrichment_agent.utils import deduplicate_and_format_sources, init_model


-async def search(
-    query: str, *, config: Annotated[RunnableConfig, InjectedToolArg]
-) -> Optional[list[dict[str, Any]]]:
-    """Query a search engine.
+# Schemas
+class SearchQuery(BaseModel):
+    search_query: str = Field(description="Query for web search.")

-    This function queries the web to fetch comprehensive, accurate, and trusted results. It's particularly useful
-    for answering questions about current events. Provide as much context in the query as needed to ensure high recall.
-    """
-    configuration = Configuration.from_runnable_config(config)
-    wrapped = TavilySearchResults(max_results=configuration.max_search_results)
-    result = await wrapped.ainvoke({"query": query})
-    return cast(list[dict[str, Any]], result)

+class Queries(BaseModel):
+    queries: List[SearchQuery] = Field(
+        description="List of search queries.",
+    )
+
+
+# Instructions
+query_writer_instructions = """You are a search query generator tasked with creating diverse but related search queries based on an initial query.
+
+For the initial query: {initial_query}
+
+Generate distinct search queries that:
+1. Cover different aspects or angles of the topic
+2. Are mutually exclusive to avoid redundant results
+3. Help gather comprehensive information about the subject
+4. Use different phrasings and keywords to maximize coverage
+5. Maintain relevance to the original query intent
+
+Each query should focus on a unique aspect while staying connected to the main topic."""

-_INFO_PROMPT = """You are doing web research on behalf of a user. You are trying to find out this information:
+_INFO_PROMPT = """You are doing web research on behalf of a user. You need to extract specific information based on this schema:

-<info>
+<schema>
 {info}
-</info>
+</schema>

-You just scraped the following website: {url}
+You have just scraped website content. Review the content below and take detailed notes that align with the extraction schema above.

-Based on the website content below, jot down some notes about the website.
+Focus only on information that matches the schema requirements.

-<Website content>
+<content>
 {content}
-</Website content>
-"""
+</content>
+
+Please provide well-structured notes that:
+1. Map directly to the schema fields
+2. Include only relevant information from the content
+3. Maintain the original facts and data points
+4. Note any missing schema fields that weren't found in the content"""


-async def scrape_website(
-    url: str,
+async def perform_web_research(
+    query: str,
     *,
     state: Annotated[State, InjectedState],
     config: Annotated[RunnableConfig, InjectedToolArg],
 ) -> str:
-    """Scrape and summarize content from a given URL.
-
-    Returns:
-        str: A summary of the scraped content, tailored to the extraction schema.
+    """Execute a multi-step web search and information extraction process.
+
+    This function performs the following steps:
+    1. Generates multiple search queries based on the input query
+    2. Executes concurrent web searches using the Tavily API
+    3. Deduplicates and formats the search results
+    4. Extracts structured information based on the provided schema
+
+    Args:
+        query: The initial search query string
+        state: Injected application state containing the extraction schema
+        config: Runtime configuration for the search process
+
+    Returns:
+        str: Structured notes from the search results that are relevant to the
+            extraction schema in state.extraction_schema
+
+    Note:
+        Multiple search queries are executed concurrently to improve performance,
+        and results from the different sources are combined for comprehensive coverage.
     """
-    async with aiohttp.ClientSession() as session:
-        async with session.get(url) as response:
-            content = await response.text()
+    configuration = Configuration.from_runnable_config(config)
+
+    # Initialize a model that emits structured output
+    raw_model = init_model(config)
+    structured_llm = raw_model.with_structured_output(Queries)
+
+    # Format system instructions
+    query_instructions = query_writer_instructions.format(initial_query=query)
+
+    # Generate queries (use the async API so we don't block the event loop)
+    results = await structured_llm.ainvoke(
+        [
+            SystemMessage(content=query_instructions),
+            HumanMessage(
+                content=f"Please generate {configuration.num_search_queries} search queries."
+            ),
+        ]
+    )
+
+    # Search client
+    tavily_async_client = AsyncTavilyClient()
+
+    # Build one Tavily search task per generated query
+    query_list = [q.search_query for q in results.queries]
+    search_tasks = [
+        tavily_async_client.search(
+            search_query,
+            max_results=configuration.max_search_results,
+            include_raw_content=True,
+            topic="general",
+        )
+        for search_query in query_list
+    ]
+
+    # Execute all searches concurrently
+    search_docs = await asyncio.gather(*search_tasks)
+
+    # Deduplicate and format sources
+    source_str = deduplicate_and_format_sources(
+        search_docs, max_tokens_per_source=1000, include_raw_content=True
+    )
+
+    # Generate structured notes relevant to the extraction schema
     p = _INFO_PROMPT.format(
         info=json.dumps(state.extraction_schema, indent=2),
-        url=url,
-        content=content[:40_000],
+        content=source_str,
     )
-    raw_model = init_model(config)
     result = await raw_model.ainvoke(p)
     return str(result.content)
diff --git a/src/enrichment_agent/utils.py b/src/enrichment_agent/utils.py
index 5dba58f..f7df2b5 100644
--- a/src/enrichment_agent/utils.py
+++ b/src/enrichment_agent/utils.py
@@ -32,3 +32,55 @@ def init_model(config: Optional[RunnableConfig] = None) -> BaseChatModel:
         provider = None
         model = fully_specified_name
     return init_chat_model(model, model_provider=provider)
+
+
+def deduplicate_and_format_sources(search_response, max_tokens_per_source, include_raw_content=True):
+    """Take a single search response, or a list of responses, from the Tavily API and format it.
+
+    Limits raw_content to approximately max_tokens_per_source tokens;
+    include_raw_content specifies whether to include the raw_content from Tavily in the formatted string.
+
+    Args:
+        search_response: Either:
+            - A dict with a 'results' key containing a list of search results
+            - A list of dicts, each containing search results
+
+    Returns:
+        str: Formatted string with deduplicated sources
+    """
+    # Convert input to a flat list of results
+    if isinstance(search_response, dict):
+        sources_list = search_response['results']
+    elif isinstance(search_response, list):
+        sources_list = []
+        for response in search_response:
+            if isinstance(response, dict) and 'results' in response:
+                sources_list.extend(response['results'])
+            else:
+                sources_list.extend(response)
+    else:
+        raise ValueError("Input must be either a dict with 'results' or a list of search results")
+
+    # Deduplicate by URL
+    unique_sources = {}
+    for source in sources_list:
+        if source['url'] not in unique_sources:
+            unique_sources[source['url']] = source
+
+    # Format output
+    formatted_text = "Sources:\n\n"
+    for source in unique_sources.values():
+        formatted_text += f"Source {source['title']}:\n===\n"
+        formatted_text += f"URL: {source['url']}\n===\n"
+        formatted_text += f"Most relevant content from source: {source['content']}\n===\n"
+        if include_raw_content:
+            # Rough estimate of 4 characters per token
+            char_limit = max_tokens_per_source * 4
+            # Handle None raw_content
+            raw_content = source.get('raw_content') or ''
+            if len(raw_content) > char_limit:
+                raw_content = raw_content[:char_limit] + "... [truncated]"
+            formatted_text += f"Full source content limited to {max_tokens_per_source} tokens: {raw_content}\n\n"
+
+    return formatted_text.strip()
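To sanity-check `deduplicate_and_format_sources`, here is a small, self-contained sketch using fabricated Tavily-style responses (the URLs and titles are made up); it exercises URL deduplication, the `None` raw_content path, and truncation:

```python
from enrichment_agent.utils import deduplicate_and_format_sources

# Two fabricated responses sharing one URL; the duplicate should be dropped.
responses = [
    {
        "results": [
            {
                "title": "LangGraph Docs",
                "url": "https://example.com/langgraph",
                "content": "Short snippet about LangGraph.",
                "raw_content": "x" * 10_000,  # long enough to trigger truncation
            }
        ]
    },
    {
        "results": [
            {
                "title": "LangGraph Docs",
                "url": "https://example.com/langgraph",  # duplicate URL, skipped
                "content": "Different snippet, same page.",
                "raw_content": None,  # None raw_content is handled gracefully
            },
            {
                "title": "Tavily API",
                "url": "https://example.com/tavily",
                "content": "Snippet about the search API.",
                "raw_content": "Full page text.",
            },
        ]
    },
]

formatted = deduplicate_and_format_sources(
    responses, max_tokens_per_source=100, include_raw_content=True
)
# Expect two unique sources; the first raw_content ends in "... [truncated]"
# because 10,000 characters exceeds the ~400-character (100-token) limit.
print(formatted)
```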
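The query-generation step inside `perform_web_research` can also be exercised on its own. A minimal sketch, assuming an OpenAI model reached through `init_chat_model` (the model name is an assumption; any chat model that supports structured output works):

```python
import asyncio

from langchain.chat_models import init_chat_model
from pydantic import BaseModel, Field


class SearchQuery(BaseModel):
    search_query: str = Field(description="Query for web search.")


class Queries(BaseModel):
    queries: list[SearchQuery] = Field(description="List of search queries.")


async def main() -> None:
    # Model choice is an assumption; swap in any structured-output-capable model.
    llm = init_chat_model("gpt-4o-mini", model_provider="openai")
    structured_llm = llm.with_structured_output(Queries)
    result = await structured_llm.ainvoke(
        "Generate 2 diverse, non-overlapping search queries about LangChain's funding history."
    )
    for q in result.queries:
        print(q.search_query)


asyncio.run(main())
```

Binding the Pydantic `Queries` schema via `with_structured_output` is what lets the tool consume `results.queries` directly, without parsing free-form model text.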