Skip to content

feat: add optional per-agent Vertex AI project and location configuration #1431

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/google/adk/models/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,13 @@ class Claude(BaseLlm):

Attributes:
model: The name of the Claude model.
project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable.
location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable.
"""

model: str = "claude-3-5-sonnet-v2@20241022"
project_id: Optional[str] = None
location: Optional[str] = None

@staticmethod
@override
Expand Down Expand Up @@ -250,16 +254,16 @@ async def generate_content_async(

@cached_property
def _anthropic_client(self) -> AnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
):
project = self.project_id or os.environ.get("GOOGLE_CLOUD_PROJECT")
location = self.location or os.environ.get("GOOGLE_CLOUD_LOCATION")

if not project or not location:
raise ValueError(
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
" Anthropic on Vertex."
)

return AnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
project_id=project,
region=location,
)
22 changes: 18 additions & 4 deletions src/google/adk/models/google_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
from typing import AsyncGenerator
from typing import cast
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

Expand Down Expand Up @@ -52,9 +53,14 @@ class Gemini(BaseLlm):

Attributes:
model: The name of the Gemini model.
project_id: Optional Google Cloud project ID. If not provided, uses GOOGLE_CLOUD_PROJECT environment variable.
location: Optional Google Cloud location. If not provided, uses GOOGLE_CLOUD_LOCATION environment variable.

"""

model: str = 'gemini-1.5-flash'
project_id: Optional[str] = None
location: Optional[str] = None

@staticmethod
@override
Expand Down Expand Up @@ -177,14 +183,22 @@ async def generate_content_async(

@cached_property
def api_client(self) -> Client:
"""Provides the api client.
"""Provides the api client with per-instance configuration support.

Returns:
The api client.
"""
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)
if self.project_id or self.location:
return Client(
vertexai=True,
project=self.project_id,
location=self.location,
http_options=types.HttpOptions(headers=self._tracking_headers),
)
else:
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)

@cached_property
def _api_backend(self) -> GoogleLLMVariant:
Expand Down
102 changes: 102 additions & 0 deletions tests/unittests/models/test_vertex_per_agent_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

from src.google.adk.models.anthropic_llm import Claude
from src.google.adk.models.google_llm import Gemini


def test_claude_custom_config():
claude = Claude(project_id="test-project-claude", location="us-central1")

assert claude.project_id == "test-project-claude"
assert claude.location == "us-central1"


def test_gemini_custom_config():
gemini = Gemini(project_id="test-project-gemini", location="europe-west1")

assert gemini.project_id == "test-project-gemini"
assert gemini.location == "europe-west1"


def test_claude_per_instance_configuration():
claude1 = Claude(project_id="project-1", location="us-central1")
claude2 = Claude(project_id="project-2", location="europe-west1")
claude3 = Claude()

assert claude1.project_id == "project-1"
assert claude1.location == "us-central1"

assert claude2.project_id == "project-2"
assert claude2.location == "europe-west1"

assert claude3.project_id is None
assert claude3.location is None


def test_gemini_per_instance_configuration():
gemini1 = Gemini(project_id="project-1", location="us-central1")
gemini2 = Gemini(project_id="project-2", location="europe-west1")
gemini3 = Gemini()

assert gemini1.project_id == "project-1"
assert gemini1.location == "us-central1"

assert gemini2.project_id == "project-2"
assert gemini2.location == "europe-west1"

assert gemini3.project_id is None
assert gemini3.location is None


def test_backward_compatibility():
claude = Claude()
gemini = Gemini()

assert claude.project_id is None
assert claude.location is None
assert gemini.project_id is None
assert gemini.location is None


@patch.dict(
"os.environ",
{
"GOOGLE_CLOUD_PROJECT": "env-project",
"GOOGLE_CLOUD_LOCATION": "env-location",
},
)
def test_claude_fallback_to_env_vars():
claude = Claude()

cache_key = f"{claude.project_id or 'default'}:{claude.location or 'default'}"
assert cache_key == "default:default"


def test_mixed_configuration():
claude_custom = Claude(project_id="custom-project", location="us-west1")
claude_default = Claude()

key_custom = (
f"{claude_custom.project_id or 'default'}:{claude_custom.location or 'default'}"
)
key_default = (
f"{claude_default.project_id or 'default'}:{claude_default.location or 'default'}"
)

assert key_custom != key_default
assert key_custom == "custom-project:us-west1"
assert key_default == "default:default"