Skip to content

Dynamic authentication handling in MCPToolset from Adk session.state #1198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions contributing/samples/mcp_stdio_user_auth_passing_sample/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Sample: Passing User Token from Agent State to MCP via ContextToEnvMapperCallback

This sample demonstrates how to use the `context_to_env_mapper_callback` feature in ADK to pass a user token from the agent's session state to an MCP process (using stdio transport). This is useful when your MCP server (built by your organization) requires the same user token for internal API calls.

## How it works
- The agent is initialized with a `MCPToolset` using `StdioServerParameters`.
- The `context_to_env_mapper_callback` is set to a function that extracts the `user_token` from the agent's state and maps it to the `USER_TOKEN` environment variable.
- When the agent calls the MCP, the token is injected into the MCP process environment, allowing the MCP to use it for internal authentication.

## Directory Structure
```
contributing/samples/stdio_mcp_user_auth_passing_sample/
├── agent.py # Basic agent setup
├── main.py # Complete runnable example
└── README.md
```

## How to Run

### Option 1: Run the complete example
```bash
cd /home/sanjay-dev/Workspace/adk-python
python -m contributing.samples.stdio_mcp_user_auth_passing_sample.main
```

### Option 2: Use the agent in your own code
```python
from contributing.samples.stdio_mcp_user_auth_passing_sample.agent import create_agent
from google.adk.sessions import Session

agent = create_agent()
session = Session(
id="your_session_id",
app_name="your_app_name",
user_id="your_user_id"
)

# Set user token in session state
session.state['user_token'] = 'YOUR_ACTUAL_TOKEN_HERE'
session.state['api_endpoint'] = 'https://your-internal-api.com'

# Then use the agent in your workflow...
```

## Flow Diagram

```mermaid
graph TD
subgraph "User Application"
U[User]
end

subgraph "Agent Process"
A[Agent Instance<br/>per user-app-agentid]
S[Session State<br/>user_token, api_endpoint]
C[ContextToEnvMapperCallback]
end

subgraph "MCP Process"
M[MCP Server<br/>stdio transport]
E[Environment Variables<br/>USER_TOKEN, API_ENDPOINT]
API[Internal API Calls]
end

U -->|Sends request| A
A -->|Reads state| S
S -->|Extracts tokens| C
C -->|Maps to env vars| E
A -->|Spawns with env| M
M -->|Uses env vars| API
API -->|Response| M
M -->|Tool result| A
A -->|Response| U
```

## Context
- Each agent instance is initiated per user-app-agentid.
- The agent receives a user context (with token) and calls the MCP using stdio transport.
- The MCP, built by the same organization, uses the token for internal API calls.
- The ADK's context-to-env mapping feature makes this seamless.
Empty file.
60 changes: 60 additions & 0 deletions contributing/samples/mcp_stdio_user_auth_passing_sample/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
Sample: Using ContextToEnvMapperCallback to pass user token from agent state to MCP via stdio transport.
"""

import os
import tempfile
from typing import Any
from typing import Dict

from google.adk.agents.llm_agent import LlmAgent
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset
from mcp import StdioServerParameters

_allowed_path = os.path.dirname(os.path.abspath(__file__))


def user_token_env_mapper(state: Dict[str, Any]) -> Dict[str, str]:
"""Extracts USER_TOKEN from agent state and maps to MCP env."""
env = {}
if "user_token" in state:
env["USER_TOKEN"] = state["user_token"]
if "api_endpoint" in state:
env["API_ENDPOINT"] = state["api_endpoint"]

print(f"Environment variables being passed to MCP: {env}")
return env


def create_agent() -> LlmAgent:
"""Create the agent with context to env mapper callback."""
# Create a temporary directory for the filesystem server
temp_dir = tempfile.mkdtemp()

return LlmAgent(
model="gemini-2.0-flash",
name="user_token_agent",
instruction=f"""
You are an agent that calls an internal MCP server which requires a user token for internal API calls.
The user token is available in your session state and must be passed to the MCP process as an environment variable.
Test directory: {temp_dir}
""",
tools=[
MCPToolset(
connection_params=StdioConnectionParams(
server_params=StdioServerParameters(
command="npx",
args=[
"-y", # Arguments for the command
"@modelcontextprotocol/server-filesystem",
_allowed_path,
],
),
timeout=5,
),
context_to_env_mapper_callback=user_token_env_mapper,
tool_filter=["read_file", "list_directory"],
)
],
)
95 changes: 95 additions & 0 deletions contributing/samples/mcp_stdio_user_auth_passing_sample/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Sample: Using ContextToEnvMapperCallback to pass user token from agent state to MCP via stdio transport.
"""

import asyncio

from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session

from .agent import create_agent


async def main():
"""Example of how to set up and run the agent with user token."""
print("=== STDIO MCP User Auth Passing Sample ===")
print()

# Create the agent
agent = create_agent()
print(f"✓ Created agent: {agent.name}")

# Create session service and session
session_service = InMemorySessionService()
session = Session(
id="sample_session",
app_name="stdio_mcp_user_auth_passing_sample",
user_id="sample_user",
)
print(f"✓ Created session: {session.id}")

# Set user token in session state
session.state["user_token"] = "sample_user_token_123"
session.state["api_endpoint"] = "https://internal-api.company.com"
print(f"✓ Set session state with user_token: {session.state['user_token']}")

# Create invocation context
invocation_context = InvocationContext(
invocation_id="sample_invocation",
agent=agent,
session=session,
session_service=session_service,
)

# Create readonly context
readonly_context = ReadonlyContext(invocation_context)
print(f"✓ Created readonly context")

print()
print("=== Demonstrating User Auth Token Passing to MCP ===")
print(
"Note: This sample shows how the callback extracts environment variables."
)
print("In a real scenario, these would be passed to an actual MCP server.")
print()

# Access the MCP toolset to demonstrate the callback
mcp_toolset = agent.tools[0]
mcp_session_manager = mcp_toolset._mcp_session_manager

# Extract environment variables using the callback (without connecting to MCP)
if mcp_session_manager._context_to_env_mapper_callback:
print("✓ Context-to-env mapper callback is configured")

# Simulate what happens during MCP session creation
env_vars = mcp_session_manager._extract_env_from_context(readonly_context)

print(f"✓ Extracted environment variables:")
for key, value in env_vars.items():
print(f" {key}={value}")
print()

print(
"✓ These environment variables would be injected into the MCP process"
)
print("✓ The MCP server can then use them for internal API calls")
else:
print("✗ No context-to-env mapper callback configured")

print()
print("=== Sample completed successfully! ===")
print()
print("Key points demonstrated:")
print("1. Session state holds user tokens and configuration")
print(
"2. Context-to-env mapper callback extracts these as environment"
" variables"
)
print("3. Environment variables would be passed to MCP server processes")
print("4. MCP servers can use these for authenticated API calls")


if __name__ == "__main__":
asyncio.run(main())
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"google-cloud-storage>=2.18.0, <3.0.0", # For GCS Artifact service
"google-genai>=1.17.0", # Google GenAI SDK
"graphviz>=0.20.2", # Graphviz for graph rendering
"mcp>=1.8.0;python_version>='3.10'", # For MCP Toolset
"mcp>=1.9.2;python_version>='3.10'", # For MCP Toolset
"opentelemetry-api>=1.31.0", # OpenTelemetry
"opentelemetry-exporter-gcp-trace>=1.9.0",
"opentelemetry-sdk>=1.31.0",
Expand Down
97 changes: 90 additions & 7 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
import sys
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import TextIO
from typing import Union
Expand Down Expand Up @@ -47,6 +49,10 @@
logger = logging.getLogger('google_adk.' + __name__)


# Type definition for environment variable transformation callback
ContextToEnvMapperCallback = Callable[[Dict[str, Any]], Dict[str, str]]


class StdioConnectionParams(BaseModel):
"""Parameters for the MCP Stdio connection.

Expand Down Expand Up @@ -179,6 +185,9 @@ def __init__(
StreamableHTTPConnectionParams,
],
errlog: TextIO = sys.stderr,
context_to_env_mapper_callback: Optional[
ContextToEnvMapperCallback
] = None,
):
"""Initializes the MCP session manager.

Expand All @@ -188,6 +197,10 @@ def __init__(
parameters but it's not configurable for now.
errlog: (Optional) TextIO stream for error logging. Use only for
initializing a local stdio MCP session.
context_to_env_mapper_callback: Optional callback function to extract environment
variables from session state. Takes a dictionary of session state and
returns a dictionary of environment variables to be injected into the
MCP connection.
"""
if isinstance(connection_params, StdioServerParameters):
# So far timeout is not configurable. Given MCP is still evolving, we
Expand All @@ -204,13 +217,20 @@ def __init__(
else:
self._connection_params = connection_params
self._errlog = errlog
self._context_to_env_mapper_callback = context_to_env_mapper_callback
# Each session manager maintains its own exit stack for proper cleanup
self._exit_stack: Optional[AsyncExitStack] = None
self._session: Optional[ClientSession] = None

async def create_session(self) -> ClientSession:
async def create_session(
self, readonly_context: Optional[Any] = None
) -> ClientSession:
"""Creates and initializes an MCP client session.

Args:
readonly_context: Optional readonly context containing session state
that can be used to extract environment variables.

Returns:
ClientSession: The initialized MCP client session.
"""
Expand All @@ -222,10 +242,17 @@ async def create_session(self) -> ClientSession:

try:
if isinstance(self._connection_params, StdioConnectionParams):
client = stdio_client(
server=self._connection_params.server_params,
errlog=self._errlog,
)
# Use original connection params as starting point
connection_params = self._connection_params.server_params

# Extract and inject environment variables for StdioServerParameters only
env_vars = self._extract_env_from_context(readonly_context)
connection_params = self._inject_env_vars(env_vars)
# So far timeout is not configurable. Given MCP is still evolving, we
# would expect stdio_client to evolve to accept timeout parameter like
# other client.
client = stdio_client(server=connection_params, errlog=self._errlog)

elif isinstance(self._connection_params, SseConnectionParams):
client = sse_client(
url=self._connection_params.url,
Expand Down Expand Up @@ -292,7 +319,63 @@ async def close(self):
self._exit_stack = None
self._session = None

def _extract_env_from_context(
self, readonly_context: Optional[Any]
) -> Dict[str, str]:
"""Extracts environment variables from readonly context using callback.

SseServerParams = SseConnectionParams
Args:
readonly_context: The readonly context containing state information.

StreamableHTTPServerParams = StreamableHTTPConnectionParams
Returns:
Dictionary of environment variables to inject.
"""
if not self._context_to_env_mapper_callback or not readonly_context:
return {}

try:
# Get state from readonly context if available
if hasattr(readonly_context, 'state') and readonly_context.state:
state_dict = dict(readonly_context.state)
return self._context_to_env_mapper_callback(state_dict)
else:
return {}
except Exception as e:
logger.warning(f'Context to env mapper callback failed: {e}')
return {}

def _inject_env_vars(self, env_vars: Dict[str, str]) -> StdioServerParameters:
"""Injects environment variables into StdioServerParameters.

Args:
env_vars: Dictionary of environment variables to inject.

Returns:
Updated StdioServerParameters with injected environment variables.
"""
if not env_vars:
return self._connection_params.server_params

# Get existing env vars from connection params
existing_env = (
getattr(self._connection_params.server_params, 'env', None) or {}
)

# Merge existing and new env vars (new ones take precedence)
merged_env = {**existing_env, **env_vars}

# Create new connection params with merged environment variables
return StdioServerParameters(
command=self._connection_params.server_params.command,
args=self._connection_params.server_params.args,
env=merged_env,
cwd=getattr(self._connection_params.server_params, 'cwd', None),
encoding=getattr(
self._connection_params.server_params, 'encoding', None
),
encoding_error_handler=getattr(
self._connection_params.server_params,
'encoding_error_handler',
None,
),
)
Loading