From 98c35d9d6c65894457c4e7c489913cc8543464d3 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 14 Apr 2025 12:50:08 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 747530866 --- google/genai/_api_client.py | 52 ++++++++++--- .../client/test_client_initialization.py | 73 +++++++++++++++++++ .../genai/tests/client/test_http_options.py | 2 + google/genai/types.py | 14 ++++ 4 files changed, 132 insertions(+), 9 deletions(-) diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index f19366bb..6b49ea94 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -442,16 +442,50 @@ def __init__( else: if self._http_options.headers is not None: _append_library_version_headers(self._http_options.headers) - # Initialize the httpx client. - # Unlike requests, the httpx package does not automatically pull in the - # environment variables SSL_CERT_FILE or SSL_CERT_DIR. They need to be - # enabled explicitly. - ctx = ssl.create_default_context( - cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), - capath=os.environ.get('SSL_CERT_DIR'), + + client_args, async_client_args = self._ensure_ssl_ctx(self._http_options) + self._httpx_client = SyncHttpxClient(**client_args) + self._async_httpx_client = AsyncHttpxClient(**async_client_args) + + @staticmethod + def _ensure_ssl_ctx(options: HttpOptions) -> ( + Tuple[dict[str, Any], dict[str, Any]]): + """Ensures the SSL context is present in the client args. + + Create a default SSL context is not provided. + + Args: + options: The http options to update. + + Returns: + A tuple of sync and async httpx client options. + """ + + verify = 'verify' + args = options.client_args + async_args = options.async_client_args + ctx = ( + args.get(verify) if args else None + or async_args.get(verify) if async_args else None + ) + + if not ctx: + ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + def _maybe_set(args: dict[str, Any], ctx: ssl.SSLContext) -> dict[str, Any]: + """Sets the SSL context in the client args if not set by making a copy.""" + if not args or not args.get(verify): + args = (args or {}).copy() + args[verify] = ctx + return args + + return ( + _maybe_set(args, ctx), + _maybe_set(async_args, ctx), ) - self._httpx_client = SyncHttpxClient(verify=ctx) - self._async_httpx_client = AsyncHttpxClient(verify=ctx) def _websocket_base_url(self): url_parts = urlparse(self._http_options.base_url) diff --git a/google/genai/tests/client/test_client_initialization.py b/google/genai/tests/client/test_client_initialization.py index c9ab5bfb..5bac08f1 100644 --- a/google/genai/tests/client/test_client_initialization.py +++ b/google/genai/tests/client/test_client_initialization.py @@ -16,10 +16,13 @@ """Tests for client initialization.""" +import certifi import google.auth from google.auth import credentials import logging +import os import pytest +import ssl from ... import _api_client as api_client from ... import _replay_api_client as replay_api_client @@ -587,3 +590,73 @@ def test_client_logs_to_logger_instance(monkeypatch, caplog): assert 'INFO' in caplog.text assert 'The user provided Vertex AI API key will take precedence' in caplog.text + +def test_client_ssl_context_implicit_initialization(): + client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( + api_client.HttpOptions()) + + assert client_args["verify"] + assert async_client_args["verify"] + assert isinstance(client_args["verify"], ssl.SSLContext) + assert isinstance(async_client_args["verify"], ssl.SSLContext) + +def test_client_ssl_context_explicit_initialization_same_args(): + ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + options = api_client.HttpOptions( + client_args={"verify": ctx}, async_client_args={"verify": ctx}) + client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( + options) + + assert client_args["verify"] == ctx + assert async_client_args["verify"] == ctx + +def test_client_ssl_context_explicit_initialization_separate_args(): + ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + async_ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + options = api_client.HttpOptions( + client_args={"verify": ctx}, async_client_args={"verify": async_ctx}) + client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( + options) + + assert client_args["verify"] == ctx + assert async_client_args["verify"] == async_ctx + +def test_client_ssl_context_explicit_initialization_sync_args(): + ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + options = api_client.HttpOptions( + client_args={"verify": ctx}) + client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( + options) + + assert client_args["verify"] == ctx + assert async_client_args["verify"] == ctx + +def test_client_ssl_context_explicit_initialization_async_args(): + ctx = ssl.create_default_context( + cafile=os.environ.get('SSL_CERT_FILE', certifi.where()), + capath=os.environ.get('SSL_CERT_DIR'), + ) + + options = api_client.HttpOptions( + async_client_args={"verify": ctx}) + client_args, async_client_args = api_client.BaseApiClient._ensure_ssl_ctx( + options) + + assert client_args["verify"] == ctx + assert async_client_args["verify"] == ctx diff --git a/google/genai/tests/client/test_http_options.py b/google/genai/tests/client/test_http_options.py index a943781a..d0f78d80 100644 --- a/google/genai/tests/client/test_http_options.py +++ b/google/genai/tests/client/test_http_options.py @@ -27,6 +27,8 @@ def test_patch_http_options_with_copies_all_fields(): api_version='v1', headers={'X-Custom-Header': 'custom_value'}, timeout=10000, + client_args={'http2': True}, + async_client_args={'http1': True}, ) options = types.HttpOptions() patched = _api_client._patch_http_options(options, patch_options) diff --git a/google/genai/types.py b/google/genai/types.py index 58cae243..aaaa4f54 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -820,6 +820,14 @@ class HttpOptions(_common.BaseModel): timeout: Optional[int] = Field( default=None, description="""Timeout for the request in milliseconds.""" ) + client_args: Optional[dict[str, Any]] = Field( + default=None, + description="""Args passed directly to the sync HTTP client.""", + ) + async_client_args: Optional[dict[str, Any]] = Field( + default=None, + description="""Args passed directly to the async HTTP client.""", + ) class HttpOptionsDict(TypedDict, total=False): @@ -837,6 +845,12 @@ class HttpOptionsDict(TypedDict, total=False): timeout: Optional[int] """Timeout for the request in milliseconds.""" + client_args: Optional[dict[str, Any]] + """Args passed directly to the sync HTTP client.""" + + async_client_args: Optional[dict[str, Any]] + """Args passed directly to the async HTTP client.""" + HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]