
Commit cb29748

Update configuration
1 parent 630b0ea commit cb29748

5 files changed: +22 -33 lines changed


pyproject.toml
Lines changed: 2 additions & 2 deletions

@@ -15,8 +15,7 @@ dependencies = [
     "langchain>=0.2.14",
     "langchain-fireworks>=0.1.7",
     "python-dotenv>=1.0.1",
-    "langchain-community>=0.2.17",
-    "tavily-python>=0.4.0",
+    "langchain-tavily>=0.1",
 ]

@@ -64,4 +63,5 @@ convention = "google"
 [dependency-groups]
 dev = [
     "langgraph-cli[inmem]>=0.1.71",
+    "pytest>=8.3.5",
 ]
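A quick sanity check for the dependency swap, as a sketch: it assumes the environment has been re-synced so langchain-tavily is installed, and it only exercises the import because constructing TavilySearch is expected to need a TAVILY_API_KEY.

# Sketch: confirm the import path that replaces
# langchain_community.tools.tavily_search.TavilySearchResults.
from langchain_tavily import TavilySearch

print(TavilySearch.__name__)  # prints "TavilySearch" if the package resolved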

src/react_agent/configuration.py
Lines changed: 8 additions & 5 deletions

@@ -3,9 +3,10 @@
 from __future__ import annotations

 from dataclasses import dataclass, field, fields
-from typing import Annotated, Optional
+from typing import Annotated

-from langchain_core.runnables import RunnableConfig, ensure_config
+from langchain_core.runnables import ensure_config
+from langgraph.config import get_config

 from react_agent import prompts

@@ -38,10 +39,12 @@ class Configuration:
     )

     @classmethod
-    def from_runnable_config(
-        cls, config: Optional[RunnableConfig] = None
-    ) -> Configuration:
+    def from_context(cls) -> Configuration:
         """Create a Configuration instance from a RunnableConfig object."""
+        try:
+            config = get_config()
+        except RuntimeError:
+            config = None
         config = ensure_config(config)
         configurable = config.get("configurable") or {}
         _fields = {f.name for f in fields(cls) if f.init}
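What the new classmethod does at its two call sites, as a sketch; the model string in the comment below is illustrative and not part of the commit.

from react_agent.configuration import Configuration

# Outside a graph run, get_config() raises RuntimeError, the except branch sets
# config = None, and ensure_config(None) produces an empty config, so the
# dataclass defaults are used.
cfg = Configuration.from_context()
print(cfg.max_search_results)

# Inside a node, get_config() returns the RunnableConfig passed at invocation,
# e.g. graph.ainvoke(inputs, config={"configurable": {"model": "openai/gpt-4o"}}),
# so from_context() would pick that model up instead of the default.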

src/react_agent/graph.py
Lines changed: 6 additions & 14 deletions

@@ -3,11 +3,10 @@
 Works with a chat model with tool calling support.
 """

-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Dict, List, Literal, cast

 from langchain_core.messages import AIMessage
-from langchain_core.runnables import RunnableConfig
 from langgraph.graph import StateGraph
 from langgraph.prebuilt import ToolNode

@@ -19,9 +18,7 @@
 # Define the function that calls the model


-async def call_model(
-    state: State, config: RunnableConfig
-) -> Dict[str, List[AIMessage]]:
+async def call_model(state: State) -> Dict[str, List[AIMessage]]:
     """Call the LLM powering our "agent".

     This function prepares the prompt, initializes the model, and processes the response.

@@ -33,21 +30,21 @@ async def call_model(
     Returns:
         dict: A dictionary containing the model's response message.
     """
-    configuration = Configuration.from_runnable_config(config)
+    configuration = Configuration.from_context()

     # Initialize the model with tool binding. Change the model or add more tools here.
     model = load_chat_model(configuration.model).bind_tools(TOOLS)

     # Format the system prompt. Customize this to change the agent's behavior.
     system_message = configuration.system_prompt.format(
-        system_time=datetime.now(tz=timezone.utc).isoformat()
+        system_time=datetime.now(tz=UTC).isoformat()
     )

     # Get the model's response
     response = cast(
         AIMessage,
         await model.ainvoke(
-            [{"role": "system", "content": system_message}, *state.messages], config
+            [{"role": "system", "content": system_message}, *state.messages]
         ),
     )

@@ -115,9 +112,4 @@ def route_model_output(state: State) -> Literal["__end__", "tools"]:
 builder.add_edge("tools", "call_model")

 # Compile the builder into an executable graph
-# You can customize this by adding interrupt points for state updates
-graph = builder.compile(
-    interrupt_before=[],  # Add node names here to update state before they're called
-    interrupt_after=[],  # Add node names here to update state after they're called
-)
-graph.name = "ReAct Agent"  # This customizes the name in LangSmith
+graph = builder.compile(name="ReAct Agent")
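With call_model no longer taking a config argument, per-run settings still reach the node through LangGraph's run context. A minimal invocation sketch, assuming model-provider and Tavily credentials are set via environment variables and that State exposes the template's messages channel:

import asyncio

from react_agent.graph import graph


async def main() -> None:
    # Keys under "configurable" are what Configuration.from_context() reads
    # inside call_model; max_search_results is an existing Configuration field.
    result = await graph.ainvoke(
        {"messages": [{"role": "user", "content": "What's new in LangGraph?"}]},
        config={"configurable": {"max_search_results": 3}},
    )
    print(result["messages"][-1].content)


asyncio.run(main())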

src/react_agent/tools.py
Lines changed: 5 additions & 11 deletions

@@ -8,27 +8,21 @@

 from typing import Any, Callable, List, Optional, cast

-from langchain_community.tools.tavily_search import TavilySearchResults
-from langchain_core.runnables import RunnableConfig
-from langchain_core.tools import InjectedToolArg
-from typing_extensions import Annotated
+from langchain_tavily import TavilySearch  # type: ignore[import-not-found]

 from react_agent.configuration import Configuration


-async def search(
-    query: str, *, config: Annotated[RunnableConfig, InjectedToolArg]
-) -> Optional[list[dict[str, Any]]]:
+async def search(query: str) -> Optional[dict[str, Any]]:
     """Search for general web results.

     This function performs a search using the Tavily search engine, which is designed
     to provide comprehensive, accurate, and trusted results. It's particularly useful
     for answering questions about current events.
     """
-    configuration = Configuration.from_runnable_config(config)
-    wrapped = TavilySearchResults(max_results=configuration.max_search_results)
-    result = await wrapped.ainvoke({"query": query})
-    return cast(list[dict[str, Any]], result)
+    configuration = Configuration.from_context()
+    wrapped = TavilySearch(max_results=configuration.max_search_results)
+    return cast(dict[str, Any], await wrapped.ainvoke({"query": query}))


 TOOLS: List[Callable[..., Any]] = [search]
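Because the tool no longer needs an injected RunnableConfig, it can also be exercised on its own. A sketch, assuming TAVILY_API_KEY is set in the environment; outside a graph run Configuration.from_context() falls back to its defaults:

import asyncio

from react_agent.tools import search


async def main() -> None:
    # TavilySearch returns a dict payload (hence the changed return annotation).
    results = await search("current weather in San Francisco")
    print(results)


asyncio.run(main())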

tests/unit_tests/test_configuration.py
Lines changed: 1 addition & 1 deletion

@@ -2,4 +2,4 @@


 def test_configuration_empty() -> None:
-    Configuration.from_runnable_config({})
+    Configuration.from_context()
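The updated test exercises the new no-argument path. A possible companion check, sketched here and not part of this commit, pinning down the fallback behaviour outside a graph run:

from react_agent.configuration import Configuration


def test_configuration_defaults_outside_run() -> None:
    # Outside a run context, get_config() raises RuntimeError, so from_context()
    # should fall back to ensure_config(None) and return default field values.
    cfg = Configuration.from_context()
    assert isinstance(cfg, Configuration)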
