From fb14b3b0a6b0469695c0c7174acb3db5c44afac9 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 27 Oct 2025 13:20:44 -0500 Subject: [PATCH 01/12] feat: enable grpc config on agent instantiation Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 52 ++++++++ tests/workflow/test_grpc_config.py | 188 +++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 tests/workflow/test_grpc_config.py diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index 4b4f12cf..d7cd1791 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin): default=300, description="Default timeout duration in seconds for workflow tasks.", ) + grpc_max_send_message_length: Optional[int] = Field( + default=None, + description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).", + ) + grpc_max_receive_message_length: Optional[int] = Field( + default=None, + description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).", + ) # Initialized in model_post_init wf_runtime: Optional[WorkflowRuntime] = Field( @@ -72,6 +80,9 @@ def model_post_init(self, __context: Any) -> None: """ Initialize the Dapr workflow runtime and register tasks & workflows. """ + if self.grpc_max_send_message_length or self.grpc_max_receive_message_length: + self._configure_grpc_channel_options() + # Initialize LLM first if self.llm is None: self.llm = get_default_llm() @@ -92,6 +103,47 @@ def model_post_init(self, __context: Any) -> None: super().model_post_init(__context) + def _configure_grpc_channel_options(self) -> None: + """ + Configure gRPC channel options before workflow runtime initialization. + This patches the durabletask internal channel factory to support custom message size limits. + + This is particularly useful for AI-powered workflows that may need to handle large payloads + such as images, which can exceed the default 4MB gRPC message size limit. + """ + try: + import grpc + from durabletask.internal import shared + + # Store the original get_grpc_channel function + original_get_grpc_channel = shared.get_grpc_channel + + # Create custom options list + options = [] + if self.grpc_max_send_message_length: + options.append(('grpc.max_send_message_length', self.grpc_max_send_message_length)) + logger.debug(f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)") + if self.grpc_max_receive_message_length: + options.append(('grpc.max_receive_message_length', self.grpc_max_receive_message_length)) + logger.debug(f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)") + + # Patch the function to include our custom options + def get_grpc_channel_with_options(address: str): + """Custom gRPC channel factory with configured message size limits.""" + return grpc.insecure_channel(address, options=options) + + # Replace the function + shared.get_grpc_channel = get_grpc_channel_with_options + + logger.debug("Successfully patched durabletask gRPC channel factory with custom options") + + except ImportError as e: + logger.error(f"Failed to import required modules for gRPC configuration: {e}") + raise + except Exception as e: + logger.error(f"Failed to configure gRPC channel options: {e}") + raise + def graceful_shutdown(self) -> None: """ Perform graceful shutdown operations for the WorkflowApp. diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py new file mode 100644 index 00000000..b0b79f38 --- /dev/null +++ b/tests/workflow/test_grpc_config.py @@ -0,0 +1,188 @@ +"""Tests for gRPC configuration in WorkflowApp.""" +import pytest +from unittest.mock import MagicMock, patch, call +from dapr_agents.workflow.base import WorkflowApp + + +@pytest.fixture +def mock_workflow_dependencies(): + """Mock all the dependencies needed for WorkflowApp initialization.""" + with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, \ + patch("dapr_agents.workflow.base.DaprWorkflowClient") as mock_client, \ + patch("dapr_agents.workflow.base.get_default_llm") as mock_llm, \ + patch.object(WorkflowApp, "start_runtime") as mock_start, \ + patch.object(WorkflowApp, "setup_signal_handlers") as mock_handlers: + + mock_runtime_instance = MagicMock() + mock_runtime.return_value = mock_runtime_instance + + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + + mock_llm_instance = MagicMock() + mock_llm.return_value = mock_llm_instance + + yield { + "runtime": mock_runtime, + "runtime_instance": mock_runtime_instance, + "client": mock_client, + "client_instance": mock_client_instance, + "llm": mock_llm, + "llm_instance": mock_llm_instance, + "start_runtime": mock_start, + "signal_handlers": mock_handlers, + } + + +def test_workflow_app_without_grpc_config(mock_workflow_dependencies): + """Test that WorkflowApp initializes without gRPC configuration.""" + # Create WorkflowApp without gRPC config + app = WorkflowApp() + + # Verify the app was created + assert app is not None + assert app.grpc_max_send_message_length is None + assert app.grpc_max_receive_message_length is None + + # Verify runtime and client were initialized + assert app.wf_runtime is not None + assert app.wf_client is not None + + +def test_workflow_app_with_grpc_config(mock_workflow_dependencies): + """Test that WorkflowApp initializes with gRPC configuration.""" + # Mock the grpc module and durabletask shared module + mock_grpc = MagicMock() + mock_shared = MagicMock() + mock_channel = MagicMock() + + # Set up the mock channel + mock_grpc.insecure_channel.return_value = mock_channel + mock_shared.get_grpc_channel = MagicMock() + + with patch.dict('sys.modules', { + 'grpc': mock_grpc, + 'durabletask.internal.shared': mock_shared, + }): + # Create WorkflowApp with gRPC config (16MB) + app = WorkflowApp( + grpc_max_send_message_length=16 * 1024 * 1024, # 16MB + grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB + ) + + # Verify the configuration was set + assert app.grpc_max_send_message_length == 16 * 1024 * 1024 + assert app.grpc_max_receive_message_length == 16 * 1024 * 1024 + + # Verify runtime and client were initialized + assert app.wf_runtime is not None + assert app.wf_client is not None + + +def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies): + """Test that _configure_grpc_channel_options is called when gRPC config is provided.""" + with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + # Create WorkflowApp with gRPC config + app = WorkflowApp( + grpc_max_send_message_length=8 * 1024 * 1024, # 8MB + ) + + # Verify the configuration method was called + mock_configure.assert_called_once() + + # Verify the configuration was set + assert app.grpc_max_send_message_length == 8 * 1024 * 1024 + + +def test_configure_grpc_channel_options_not_called_without_config(mock_workflow_dependencies): + """Test that _configure_grpc_channel_options is not called without gRPC config.""" + with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + # Create WorkflowApp without gRPC config + app = WorkflowApp() + + # Verify the configuration method was NOT called + mock_configure.assert_not_called() + + +def test_grpc_channel_patching(): + """Test that the gRPC channel factory is properly patched with custom options.""" + # Mock the grpc module and durabletask shared module + mock_grpc = MagicMock() + mock_shared = MagicMock() + mock_channel = MagicMock() + + # Set up the mock channel + mock_grpc.insecure_channel.return_value = mock_channel + original_get_grpc_channel = MagicMock() + mock_shared.get_grpc_channel = original_get_grpc_channel + + with patch.dict('sys.modules', { + 'grpc': mock_grpc, + 'durabletask.internal.shared': mock_shared, + }), patch("dapr_agents.workflow.base.WorkflowRuntime"), \ + patch("dapr_agents.workflow.base.DaprWorkflowClient"), \ + patch("dapr_agents.workflow.base.get_default_llm"), \ + patch.object(WorkflowApp, "start_runtime"), \ + patch.object(WorkflowApp, "setup_signal_handlers"): + + # Create WorkflowApp with gRPC config + max_send = 10 * 1024 * 1024 # 10MB + max_recv = 12 * 1024 * 1024 # 12MB + + app = WorkflowApp( + grpc_max_send_message_length=max_send, + grpc_max_receive_message_length=max_recv, + ) + + # Verify the shared.get_grpc_channel was replaced + assert mock_shared.get_grpc_channel != original_get_grpc_channel + + # Call the patched function + test_address = "localhost:50001" + mock_shared.get_grpc_channel(test_address) + + # Verify insecure_channel was called with correct options + mock_grpc.insecure_channel.assert_called_once() + call_args = mock_grpc.insecure_channel.call_args + + # Check that the address was passed + assert call_args[0][0] == test_address + + # Check that options were passed + assert 'options' in call_args[1] + options = call_args[1]['options'] + + # Verify options contain our custom message size limits + assert ('grpc.max_send_message_length', max_send) in options + assert ('grpc.max_receive_message_length', max_recv) in options + + +def test_grpc_config_with_only_send_limit(mock_workflow_dependencies): + """Test gRPC configuration with only send limit set.""" + with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + app = WorkflowApp( + grpc_max_send_message_length=20 * 1024 * 1024, # 20MB + ) + + # Verify configuration was called + mock_configure.assert_called_once() + + # Verify only send limit was set + assert app.grpc_max_send_message_length == 20 * 1024 * 1024 + assert app.grpc_max_receive_message_length is None + + +def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies): + """Test gRPC configuration with only receive limit set.""" + with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + app = WorkflowApp( + grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB + ) + + # Verify configuration was called + mock_configure.assert_called_once() + + # Verify only receive limit was set + assert app.grpc_max_send_message_length is None + assert app.grpc_max_receive_message_length == 24 * 1024 * 1024 + From dd2b375c089850b27ad11007246e1f24b17d8928 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 27 Oct 2025 13:24:30 -0500 Subject: [PATCH 02/12] style: tox -e ruff Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 43 +++++++---- tests/workflow/test_grpc_config.py | 117 ++++++++++++++++------------- 2 files changed, 92 insertions(+), 68 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index d7cd1791..1d14e02c 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -82,7 +82,7 @@ def model_post_init(self, __context: Any) -> None: """ if self.grpc_max_send_message_length or self.grpc_max_receive_message_length: self._configure_grpc_channel_options() - + # Initialize LLM first if self.llm is None: self.llm = get_default_llm() @@ -107,38 +107,53 @@ def _configure_grpc_channel_options(self) -> None: """ Configure gRPC channel options before workflow runtime initialization. This patches the durabletask internal channel factory to support custom message size limits. - + This is particularly useful for AI-powered workflows that may need to handle large payloads such as images, which can exceed the default 4MB gRPC message size limit. """ try: import grpc from durabletask.internal import shared - + # Store the original get_grpc_channel function original_get_grpc_channel = shared.get_grpc_channel - + # Create custom options list options = [] if self.grpc_max_send_message_length: - options.append(('grpc.max_send_message_length', self.grpc_max_send_message_length)) - logger.debug(f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)") + options.append( + ("grpc.max_send_message_length", self.grpc_max_send_message_length) + ) + logger.debug( + f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)" + ) if self.grpc_max_receive_message_length: - options.append(('grpc.max_receive_message_length', self.grpc_max_receive_message_length)) - logger.debug(f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)") - + options.append( + ( + "grpc.max_receive_message_length", + self.grpc_max_receive_message_length, + ) + ) + logger.debug( + f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)" + ) + # Patch the function to include our custom options def get_grpc_channel_with_options(address: str): """Custom gRPC channel factory with configured message size limits.""" return grpc.insecure_channel(address, options=options) - + # Replace the function shared.get_grpc_channel = get_grpc_channel_with_options - - logger.debug("Successfully patched durabletask gRPC channel factory with custom options") - + + logger.debug( + "Successfully patched durabletask gRPC channel factory with custom options" + ) + except ImportError as e: - logger.error(f"Failed to import required modules for gRPC configuration: {e}") + logger.error( + f"Failed to import required modules for gRPC configuration: {e}" + ) raise except Exception as e: logger.error(f"Failed to configure gRPC channel options: {e}") diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index b0b79f38..c7eb2b81 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -7,21 +7,24 @@ @pytest.fixture def mock_workflow_dependencies(): """Mock all the dependencies needed for WorkflowApp initialization.""" - with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, \ - patch("dapr_agents.workflow.base.DaprWorkflowClient") as mock_client, \ - patch("dapr_agents.workflow.base.get_default_llm") as mock_llm, \ - patch.object(WorkflowApp, "start_runtime") as mock_start, \ - patch.object(WorkflowApp, "setup_signal_handlers") as mock_handlers: - + with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, patch( + "dapr_agents.workflow.base.DaprWorkflowClient" + ) as mock_client, patch( + "dapr_agents.workflow.base.get_default_llm" + ) as mock_llm, patch.object( + WorkflowApp, "start_runtime" + ) as mock_start, patch.object( + WorkflowApp, "setup_signal_handlers" + ) as mock_handlers: mock_runtime_instance = MagicMock() mock_runtime.return_value = mock_runtime_instance - + mock_client_instance = MagicMock() mock_client.return_value = mock_client_instance - + mock_llm_instance = MagicMock() mock_llm.return_value = mock_llm_instance - + yield { "runtime": mock_runtime, "runtime_instance": mock_runtime_instance, @@ -38,12 +41,12 @@ def test_workflow_app_without_grpc_config(mock_workflow_dependencies): """Test that WorkflowApp initializes without gRPC configuration.""" # Create WorkflowApp without gRPC config app = WorkflowApp() - + # Verify the app was created assert app is not None assert app.grpc_max_send_message_length is None assert app.grpc_max_receive_message_length is None - + # Verify runtime and client were initialized assert app.wf_runtime is not None assert app.wf_client is not None @@ -55,25 +58,28 @@ def test_workflow_app_with_grpc_config(mock_workflow_dependencies): mock_grpc = MagicMock() mock_shared = MagicMock() mock_channel = MagicMock() - + # Set up the mock channel mock_grpc.insecure_channel.return_value = mock_channel mock_shared.get_grpc_channel = MagicMock() - - with patch.dict('sys.modules', { - 'grpc': mock_grpc, - 'durabletask.internal.shared': mock_shared, - }): + + with patch.dict( + "sys.modules", + { + "grpc": mock_grpc, + "durabletask.internal.shared": mock_shared, + }, + ): # Create WorkflowApp with gRPC config (16MB) app = WorkflowApp( grpc_max_send_message_length=16 * 1024 * 1024, # 16MB grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB ) - + # Verify the configuration was set assert app.grpc_max_send_message_length == 16 * 1024 * 1024 assert app.grpc_max_receive_message_length == 16 * 1024 * 1024 - + # Verify runtime and client were initialized assert app.wf_runtime is not None assert app.wf_client is not None @@ -81,25 +87,27 @@ def test_workflow_app_with_grpc_config(mock_workflow_dependencies): def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies): """Test that _configure_grpc_channel_options is called when gRPC config is provided.""" - with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: # Create WorkflowApp with gRPC config app = WorkflowApp( grpc_max_send_message_length=8 * 1024 * 1024, # 8MB ) - + # Verify the configuration method was called mock_configure.assert_called_once() - + # Verify the configuration was set assert app.grpc_max_send_message_length == 8 * 1024 * 1024 -def test_configure_grpc_channel_options_not_called_without_config(mock_workflow_dependencies): +def test_configure_grpc_channel_options_not_called_without_config( + mock_workflow_dependencies, +): """Test that _configure_grpc_channel_options is not called without gRPC config.""" - with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: # Create WorkflowApp without gRPC config app = WorkflowApp() - + # Verify the configuration method was NOT called mock_configure.assert_not_called() @@ -110,63 +118,65 @@ def test_grpc_channel_patching(): mock_grpc = MagicMock() mock_shared = MagicMock() mock_channel = MagicMock() - + # Set up the mock channel mock_grpc.insecure_channel.return_value = mock_channel original_get_grpc_channel = MagicMock() mock_shared.get_grpc_channel = original_get_grpc_channel - - with patch.dict('sys.modules', { - 'grpc': mock_grpc, - 'durabletask.internal.shared': mock_shared, - }), patch("dapr_agents.workflow.base.WorkflowRuntime"), \ - patch("dapr_agents.workflow.base.DaprWorkflowClient"), \ - patch("dapr_agents.workflow.base.get_default_llm"), \ - patch.object(WorkflowApp, "start_runtime"), \ - patch.object(WorkflowApp, "setup_signal_handlers"): - + + with patch.dict( + "sys.modules", + { + "grpc": mock_grpc, + "durabletask.internal.shared": mock_shared, + }, + ), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch( + "dapr_agents.workflow.base.DaprWorkflowClient" + ), patch("dapr_agents.workflow.base.get_default_llm"), patch.object( + WorkflowApp, "start_runtime" + ), patch.object(WorkflowApp, "setup_signal_handlers"): # Create WorkflowApp with gRPC config max_send = 10 * 1024 * 1024 # 10MB max_recv = 12 * 1024 * 1024 # 12MB - + app = WorkflowApp( grpc_max_send_message_length=max_send, grpc_max_receive_message_length=max_recv, ) - + # Verify the shared.get_grpc_channel was replaced assert mock_shared.get_grpc_channel != original_get_grpc_channel - + # Call the patched function test_address = "localhost:50001" mock_shared.get_grpc_channel(test_address) - + # Verify insecure_channel was called with correct options mock_grpc.insecure_channel.assert_called_once() call_args = mock_grpc.insecure_channel.call_args - + # Check that the address was passed assert call_args[0][0] == test_address - + # Check that options were passed - assert 'options' in call_args[1] - options = call_args[1]['options'] - + assert "options" in call_args[1] + options = call_args[1]["options"] + # Verify options contain our custom message size limits - assert ('grpc.max_send_message_length', max_send) in options - assert ('grpc.max_receive_message_length', max_recv) in options + assert ("grpc.max_send_message_length", max_send) in options + assert ("grpc.max_receive_message_length", max_recv) in options def test_grpc_config_with_only_send_limit(mock_workflow_dependencies): """Test gRPC configuration with only send limit set.""" - with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: app = WorkflowApp( grpc_max_send_message_length=20 * 1024 * 1024, # 20MB ) - + # Verify configuration was called mock_configure.assert_called_once() - + # Verify only send limit was set assert app.grpc_max_send_message_length == 20 * 1024 * 1024 assert app.grpc_max_receive_message_length is None @@ -174,15 +184,14 @@ def test_grpc_config_with_only_send_limit(mock_workflow_dependencies): def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies): """Test gRPC configuration with only receive limit set.""" - with patch.object(WorkflowApp, '_configure_grpc_channel_options') as mock_configure: + with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: app = WorkflowApp( grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB ) - + # Verify configuration was called mock_configure.assert_called_once() - + # Verify only receive limit was set assert app.grpc_max_send_message_length is None assert app.grpc_max_receive_message_length == 24 * 1024 * 1024 - From d16947bde0bec578dbbb66d2a726d9cd61a9d6c5 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 27 Oct 2025 13:32:35 -0500 Subject: [PATCH 03/12] style: tox -e flake8 Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 3 --- tests/workflow/test_grpc_config.py | 9 +++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index 1d14e02c..09062403 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -115,9 +115,6 @@ def _configure_grpc_channel_options(self) -> None: import grpc from durabletask.internal import shared - # Store the original get_grpc_channel function - original_get_grpc_channel = shared.get_grpc_channel - # Create custom options list options = [] if self.grpc_max_send_message_length: diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index c7eb2b81..8c42a184 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -89,16 +89,13 @@ def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies): """Test that _configure_grpc_channel_options is called when gRPC config is provided.""" with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: # Create WorkflowApp with gRPC config - app = WorkflowApp( + WorkflowApp( grpc_max_send_message_length=8 * 1024 * 1024, # 8MB ) # Verify the configuration method was called mock_configure.assert_called_once() - # Verify the configuration was set - assert app.grpc_max_send_message_length == 8 * 1024 * 1024 - def test_configure_grpc_channel_options_not_called_without_config( mock_workflow_dependencies, @@ -106,7 +103,7 @@ def test_configure_grpc_channel_options_not_called_without_config( """Test that _configure_grpc_channel_options is not called without gRPC config.""" with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure: # Create WorkflowApp without gRPC config - app = WorkflowApp() + WorkflowApp() # Verify the configuration method was NOT called mock_configure.assert_not_called() @@ -139,7 +136,7 @@ def test_grpc_channel_patching(): max_send = 10 * 1024 * 1024 # 10MB max_recv = 12 * 1024 * 1024 # 12MB - app = WorkflowApp( + WorkflowApp( grpc_max_send_message_length=max_send, grpc_max_receive_message_length=max_recv, ) From 71d1a42e994c2491d3a6b5c95ba2a96556b71b78 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Mon, 27 Oct 2025 16:14:19 -0500 Subject: [PATCH 04/12] fix: correct test assertion Signed-off-by: Samantha Coyle --- tests/workflow/test_grpc_config.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 8c42a184..8cad0b18 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -118,7 +118,9 @@ def test_grpc_channel_patching(): # Set up the mock channel mock_grpc.insecure_channel.return_value = mock_channel - original_get_grpc_channel = MagicMock() + + # Keep original reference + original_get_grpc_channel = lambda *_, **__: "original" mock_shared.get_grpc_channel = original_get_grpc_channel with patch.dict( @@ -136,12 +138,13 @@ def test_grpc_channel_patching(): max_send = 10 * 1024 * 1024 # 10MB max_recv = 12 * 1024 * 1024 # 12MB - WorkflowApp( + app = WorkflowApp( grpc_max_send_message_length=max_send, grpc_max_receive_message_length=max_recv, ) - # Verify the shared.get_grpc_channel was replaced + # Confirm get_grpc_channel was overridden + assert callable(mock_shared.get_grpc_channel) assert mock_shared.get_grpc_channel != original_get_grpc_channel # Call the patched function @@ -156,8 +159,8 @@ def test_grpc_channel_patching(): assert call_args[0][0] == test_address # Check that options were passed - assert "options" in call_args[1] - options = call_args[1]["options"] + assert "options" in call_args.kwargs + options = call_args.kwargs["options"] # Verify options contain our custom message size limits assert ("grpc.max_send_message_length", max_send) in options From 80030518e8bd023fed8b776dc30a1e162b71c2e6 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 29 Oct 2025 14:47:11 -0500 Subject: [PATCH 05/12] fix: add validations and correct func params Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 58 ++++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 5 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index 09062403..a3000b43 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -7,7 +7,7 @@ import sys import uuid from datetime import datetime, timezone -from typing import Any, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence from dapr.ext.workflow import ( DaprWorkflowClient, @@ -16,7 +16,7 @@ ) from dapr.ext.workflow.workflow_state import WorkflowState from durabletask import task as dtask -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from dapr_agents.agents.base import ChatClientBase from dapr_agents.llm.utils.defaults import get_default_llm @@ -76,6 +76,22 @@ class WorkflowApp(BaseModel, SignalHandlingMixin): model_config = ConfigDict(arbitrary_types_allowed=True) + @model_validator(mode="before") + def validate_grpc_chanell_options(cls, values: Any): + # Ensure we only operate on dict inputs and always return the original values + if not isinstance(values, dict): + return values + + if values.get("grpc_max_send_message_length") is not None: + if values["grpc_max_send_message_length"] < 0: + raise ValueError("grpc_max_send_message_length must be greater than 0") + + if values.get("grpc_max_receive_message_length") is not None: + if values["grpc_max_receive_message_length"] < 0: + raise ValueError("grpc_max_receive_message_length must be greater than 0") + + return values + def model_post_init(self, __context: Any) -> None: """ Initialize the Dapr workflow runtime and register tasks & workflows. @@ -136,9 +152,41 @@ def _configure_grpc_channel_options(self) -> None: ) # Patch the function to include our custom options - def get_grpc_channel_with_options(address: str): - """Custom gRPC channel factory with configured message size limits.""" - return grpc.insecure_channel(address, options=options) + def get_grpc_channel_with_options(host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None): + # This is a copy of the original get_grpc_channel function in durabletask.internal.shared at + # https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19 + # but with my option overrides applied above. + if host_address is None: + host_address = shared.get_default_host_address() + + for protocol in getattr(shared, "SECURE_PROTOCOLS", []): + if host_address.lower().startswith(protocol): + secure_channel = True + # remove the protocol from the host name + host_address = host_address[len(protocol):] + break + + for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): + if host_address.lower().startswith(protocol): + secure_channel = False + # remove the protocol from the host name + host_address = host_address[len(protocol):] + break + + # Create the base channel + if secure_channel: + credentials = grpc.ssl_channel_credentials() + channel = grpc.secure_channel(host_address, credentials, options=options) + else: + channel = grpc.insecure_channel(host_address, options=options) + + # Apply interceptors ONLY if they exist + if interceptors: + channel = grpc.intercept_channel(channel, *interceptors) + + return channel # Replace the function shared.get_grpc_channel = get_grpc_channel_with_options From 1ada4cd6c67e6330256486be27acf016f8e26994 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 29 Oct 2025 14:49:37 -0500 Subject: [PATCH 06/12] style: lint Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index a3000b43..8ba1159a 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -85,10 +85,12 @@ def validate_grpc_chanell_options(cls, values: Any): if values.get("grpc_max_send_message_length") is not None: if values["grpc_max_send_message_length"] < 0: raise ValueError("grpc_max_send_message_length must be greater than 0") - + if values.get("grpc_max_receive_message_length") is not None: if values["grpc_max_receive_message_length"] < 0: - raise ValueError("grpc_max_receive_message_length must be greater than 0") + raise ValueError( + "grpc_max_receive_message_length must be greater than 0" + ) return values @@ -152,9 +154,11 @@ def _configure_grpc_channel_options(self) -> None: ) # Patch the function to include our custom options - def get_grpc_channel_with_options(host_address: Optional[str], - secure_channel: bool = False, - interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None): + def get_grpc_channel_with_options( + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None, + ): # This is a copy of the original get_grpc_channel function in durabletask.internal.shared at # https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19 # but with my option overrides applied above. @@ -165,20 +169,22 @@ def get_grpc_channel_with_options(host_address: Optional[str], if host_address.lower().startswith(protocol): secure_channel = True # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): if host_address.lower().startswith(protocol): secure_channel = False # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break # Create the base channel if secure_channel: credentials = grpc.ssl_channel_credentials() - channel = grpc.secure_channel(host_address, credentials, options=options) + channel = grpc.secure_channel( + host_address, credentials, options=options + ) else: channel = grpc.insecure_channel(host_address, options=options) From a84b2c00480aec70fdc8e9b6eb7219e1dddabae3 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 29 Oct 2025 15:15:25 -0500 Subject: [PATCH 07/12] fix: tox -e flake8 Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 5 ++--- tests/workflow/test_grpc_config.py | 5 +++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index 8ba1159a..be081755 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -78,7 +78,6 @@ class WorkflowApp(BaseModel, SignalHandlingMixin): @model_validator(mode="before") def validate_grpc_chanell_options(cls, values: Any): - # Ensure we only operate on dict inputs and always return the original values if not isinstance(values, dict): return values @@ -169,14 +168,14 @@ def get_grpc_channel_with_options( if host_address.lower().startswith(protocol): secure_channel = True # remove the protocol from the host name - host_address = host_address[len(protocol) :] + host_address = host_address[len(protocol):] break for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): if host_address.lower().startswith(protocol): secure_channel = False # remove the protocol from the host name - host_address = host_address[len(protocol) :] + host_address = host_address[len(protocol):] break # Create the base channel diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 8cad0b18..23b29123 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -120,7 +120,8 @@ def test_grpc_channel_patching(): mock_grpc.insecure_channel.return_value = mock_channel # Keep original reference - original_get_grpc_channel = lambda *_, **__: "original" + def original_get_grpc_channel(*_, **__): + return "original" mock_shared.get_grpc_channel = original_get_grpc_channel with patch.dict( @@ -138,7 +139,7 @@ def test_grpc_channel_patching(): max_send = 10 * 1024 * 1024 # 10MB max_recv = 12 * 1024 * 1024 # 12MB - app = WorkflowApp( + WorkflowApp( grpc_max_send_message_length=max_send, grpc_max_receive_message_length=max_recv, ) From 66507be3335d0ce6ea013021c28d4c72072ad73e Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 29 Oct 2025 15:38:19 -0500 Subject: [PATCH 08/12] fix: style again Signed-off-by: Samantha Coyle --- dapr_agents/workflow/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dapr_agents/workflow/base.py b/dapr_agents/workflow/base.py index be081755..f71feae5 100644 --- a/dapr_agents/workflow/base.py +++ b/dapr_agents/workflow/base.py @@ -168,14 +168,14 @@ def get_grpc_channel_with_options( if host_address.lower().startswith(protocol): secure_channel = True # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): if host_address.lower().startswith(protocol): secure_channel = False # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break # Create the base channel From 3142fa0bb72b3da63a519ea4b2c9328af2d5f3a1 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Wed, 29 Oct 2025 15:38:40 -0500 Subject: [PATCH 09/12] style: tox -e ruff Signed-off-by: Samantha Coyle --- tests/workflow/test_grpc_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 23b29123..25072196 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -122,6 +122,7 @@ def test_grpc_channel_patching(): # Keep original reference def original_get_grpc_channel(*_, **__): return "original" + mock_shared.get_grpc_channel = original_get_grpc_channel with patch.dict( From 8d1cc6ac59e032d2df3cc6bdf5de5fd1a5496418 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 30 Oct 2025 09:20:47 -0500 Subject: [PATCH 10/12] fix: update for test to be happy Signed-off-by: Samantha Coyle --- tests/workflow/test_grpc_config.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 25072196..6dd6097d 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -147,7 +147,8 @@ def original_get_grpc_channel(*_, **__): # Confirm get_grpc_channel was overridden assert callable(mock_shared.get_grpc_channel) - assert mock_shared.get_grpc_channel != original_get_grpc_channel + assert mock_shared.get_grpc_channel is not original_get_grpc_channel + assert getattr(mock_shared.get_grpc_channel, "__name__", "") == "get_grpc_channel_with_options" # Call the patched function test_address = "localhost:50001" From d8a545458ecc541644f6d643cd5af483433313d0 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 30 Oct 2025 11:18:08 -0500 Subject: [PATCH 11/12] style: tox -e ruff Signed-off-by: Samantha Coyle --- tests/workflow/test_grpc_config.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 6dd6097d..5f505460 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -148,7 +148,10 @@ def original_get_grpc_channel(*_, **__): # Confirm get_grpc_channel was overridden assert callable(mock_shared.get_grpc_channel) assert mock_shared.get_grpc_channel is not original_get_grpc_channel - assert getattr(mock_shared.get_grpc_channel, "__name__", "") == "get_grpc_channel_with_options" + assert ( + getattr(mock_shared.get_grpc_channel, "__name__", "") + == "get_grpc_channel_with_options" + ) # Call the patched function test_address = "localhost:50001" From f6106d0dee635d301869906497cade3b24afaa75 Mon Sep 17 00:00:00 2001 From: Samantha Coyle Date: Thu, 30 Oct 2025 11:28:50 -0500 Subject: [PATCH 12/12] fix: update for test Signed-off-by: Samantha Coyle --- tests/workflow/test_grpc_config.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/workflow/test_grpc_config.py b/tests/workflow/test_grpc_config.py index 5f505460..c8a67cf4 100644 --- a/tests/workflow/test_grpc_config.py +++ b/tests/workflow/test_grpc_config.py @@ -1,6 +1,7 @@ """Tests for gRPC configuration in WorkflowApp.""" import pytest from unittest.mock import MagicMock, patch, call +import types from dapr_agents.workflow.base import WorkflowApp @@ -125,10 +126,18 @@ def original_get_grpc_channel(*_, **__): mock_shared.get_grpc_channel = original_get_grpc_channel + # Create dummy package/module structure so 'from durabletask.internal import shared' works + durabletask_module = types.ModuleType("durabletask") + internal_module = types.ModuleType("durabletask.internal") + setattr(durabletask_module, "internal", internal_module) + setattr(internal_module, "shared", mock_shared) + with patch.dict( "sys.modules", { "grpc": mock_grpc, + "durabletask": durabletask_module, + "durabletask.internal": internal_module, "durabletask.internal.shared": mock_shared, }, ), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch(