diff --git a/src/databricks/sql/backend/databricks_client.py b/src/databricks/sql/backend/databricks_client.py index 8fda71e1e..973c2932e 100644 --- a/src/databricks/sql/backend/databricks_client.py +++ b/src/databricks/sql/backend/databricks_client.py @@ -82,7 +82,7 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List, async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 97d25a058..76903ccd2 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -1,23 +1,43 @@ import logging +import time import re -from typing import Dict, Tuple, List, Optional, TYPE_CHECKING, Set +from typing import Any, Dict, Tuple, List, Optional, Union, TYPE_CHECKING, Set + +from databricks.sql.backend.sea.models.base import ResultManifest +from databricks.sql.backend.sea.utils.constants import ( + ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, + ResultFormat, + ResultDisposition, + ResultCompression, + WaitTimeout, +) if TYPE_CHECKING: from databricks.sql.client import Cursor + from databricks.sql.result_set import ResultSet from databricks.sql.backend.databricks_client import DatabricksClient -from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType -from databricks.sql.exc import ServerOperationError -from databricks.sql.backend.sea.utils.http_client import SeaHttpClient -from databricks.sql.backend.sea.utils.constants import ( - ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP, +from databricks.sql.backend.types import ( + SessionId, + CommandId, + CommandState, + BackendType, + ExecuteResponse, ) -from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.exc import DatabaseError, ServerOperationError +from databricks.sql.backend.sea.utils.http_client import SeaHttpClient from databricks.sql.types import SSLOptions from databricks.sql.backend.sea.models import ( + ExecuteStatementRequest, + GetStatementRequest, + CancelStatementRequest, + CloseStatementRequest, CreateSessionRequest, DeleteSessionRequest, + StatementParameter, + ExecuteStatementResponse, + GetStatementResponse, CreateSessionResponse, ) @@ -65,6 +85,9 @@ class SeaDatabricksClient(DatabricksClient): STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}" CANCEL_STATEMENT_PATH_WITH_ID = STATEMENT_PATH + "/{}/cancel" + # SEA constants + POLL_INTERVAL_SECONDS = 0.2 + def __init__( self, server_hostname: str, @@ -262,8 +285,113 @@ def get_allowed_session_configurations() -> List[str]: """ return list(ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP.keys()) - # == Not Implemented Operations == - # These methods will be implemented in future iterations + def _extract_description_from_manifest( + self, manifest: ResultManifest + ) -> Optional[List]: + """ + Extract column description from a manifest object, in the format defined by + the spec: https://peps.python.org/pep-0249/#description + + Args: + manifest: The ResultManifest object containing schema information + + Returns: + Optional[List]: A list of column tuples or None if no columns are found + """ + + schema_data = manifest.schema + columns_data = schema_data.get("columns", []) + + if not columns_data: + return None + + columns = [] + for col_data in columns_data: + # Format: (name, type_code, display_size, internal_size, precision, scale, null_ok) + columns.append( + ( + col_data.get("name", ""), # name + col_data.get("type_name", ""), # type_code + None, # display_size (not provided by SEA) + None, # internal_size (not provided by SEA) + col_data.get("precision"), # precision + col_data.get("scale"), # scale + col_data.get("nullable", True), # null_ok + ) + ) + + return columns if columns else None + + def _results_message_to_execute_response( + self, response: GetStatementResponse + ) -> ExecuteResponse: + """ + Convert a SEA response to an ExecuteResponse and extract result data. + + Args: + sea_response: The response from the SEA API + command_id: The command ID + + Returns: + ExecuteResponse: The normalized execute response + """ + + # Extract description from manifest schema + description = self._extract_description_from_manifest(response.manifest) + + # Check for compression + lz4_compressed = ( + response.manifest.result_compression == ResultCompression.LZ4_FRAME + ) + + execute_response = ExecuteResponse( + command_id=CommandId.from_sea_statement_id(response.statement_id), + status=response.status.state, + description=description, + has_been_closed_server_side=False, + lz4_compressed=lz4_compressed, + is_staging_operation=False, + arrow_schema_bytes=None, + result_format=response.manifest.format, + ) + + return execute_response + + def _check_command_not_in_failed_or_closed_state( + self, state: CommandState, command_id: CommandId + ) -> None: + if state == CommandState.CLOSED: + raise DatabaseError( + "Command {} unexpectedly closed server side".format(command_id), + { + "operation-id": command_id, + }, + ) + if state == CommandState.FAILED: + raise ServerOperationError( + "Command {} failed".format(command_id), + { + "operation-id": command_id, + }, + ) + + def _wait_until_command_done( + self, response: ExecuteStatementResponse + ) -> CommandState: + """ + Wait until a command is done. + """ + + state = response.status.state + command_id = CommandId.from_sea_statement_id(response.statement_id) + + while state in [CommandState.PENDING, CommandState.RUNNING]: + time.sleep(self.POLL_INTERVAL_SECONDS) + state = self.get_query_state(command_id) + + self._check_command_not_in_failed_or_closed_state(state, command_id) + + return state def execute_command( self, @@ -274,41 +402,221 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[ttypes.TSparkParameter], + parameters: List[Dict[str, Any]], async_op: bool, enforce_embedded_schema_correctness: bool, - ): - """Not implemented yet.""" - raise NotImplementedError( - "execute_command is not yet implemented for SEA backend" + ) -> Union["ResultSet", None]: + """ + Execute a SQL command using the SEA backend. + + Args: + operation: SQL command to execute + session_id: Session identifier + max_rows: Maximum number of rows to fetch + max_bytes: Maximum number of bytes to fetch + lz4_compression: Whether to use LZ4 compression + cursor: Cursor executing the command + use_cloud_fetch: Whether to use cloud fetch + parameters: SQL parameters + async_op: Whether to execute asynchronously + enforce_embedded_schema_correctness: Whether to enforce schema correctness + + Returns: + ResultSet: A SeaResultSet instance for the executed command + """ + + if session_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA session ID") + + sea_session_id = session_id.to_sea_session_id() + + # Convert parameters to StatementParameter objects + sea_parameters = [] + if parameters: + for param in parameters: + sea_parameters.append( + StatementParameter( + name=param["name"], + value=param["value"], + type=param["type"] if "type" in param else None, + ) + ) + + format = ( + ResultFormat.ARROW_STREAM if use_cloud_fetch else ResultFormat.JSON_ARRAY + ).value + disposition = ( + ResultDisposition.EXTERNAL_LINKS + if use_cloud_fetch + else ResultDisposition.INLINE + ).value + result_compression = ( + ResultCompression.LZ4_FRAME if lz4_compression else ResultCompression.NONE + ).value + + request = ExecuteStatementRequest( + warehouse_id=self.warehouse_id, + session_id=sea_session_id, + statement=operation, + disposition=disposition, + format=format, + wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value, + on_wait_timeout="CONTINUE", + row_limit=max_rows, + parameters=sea_parameters if sea_parameters else None, + result_compression=result_compression, + ) + + response_data = self.http_client._make_request( + method="POST", path=self.STATEMENT_PATH, data=request.to_dict() ) + response = ExecuteStatementResponse.from_dict(response_data) + statement_id = response.statement_id + if not statement_id: + raise ServerOperationError( + "Failed to execute command: No statement ID returned", + { + "operation-id": None, + "diagnostic-info": None, + }, + ) + + command_id = CommandId.from_sea_statement_id(statement_id) + + # Store the command ID in the cursor + cursor.active_command_id = command_id + + # If async operation, return and let the client poll for results + if async_op: + return None + + self._wait_until_command_done(response) + return self.get_execution_result(command_id, cursor) def cancel_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "cancel_command is not yet implemented for SEA backend" + """ + Cancel a running command. + + Args: + command_id: Command identifier to cancel + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CancelStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="POST", + path=self.CANCEL_STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def close_command(self, command_id: CommandId) -> None: - """Not implemented yet.""" - raise NotImplementedError( - "close_command is not yet implemented for SEA backend" + """ + Close a command and release resources. + + Args: + command_id: Command identifier to close + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = CloseStatementRequest(statement_id=sea_statement_id) + self.http_client._make_request( + method="DELETE", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) def get_query_state(self, command_id: CommandId) -> CommandState: - """Not implemented yet.""" - raise NotImplementedError( - "get_query_state is not yet implemented for SEA backend" + """ + Get the state of a running query. + + Args: + command_id: Command identifier + + Returns: + CommandState: The current state of the command + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + request = GetStatementRequest(statement_id=sea_statement_id) + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), ) + # Parse the response + response = GetStatementResponse.from_dict(response_data) + return response.status.state + def get_execution_result( self, command_id: CommandId, cursor: "Cursor", - ): - """Not implemented yet.""" - raise NotImplementedError( - "get_execution_result is not yet implemented for SEA backend" + ) -> "ResultSet": + """ + Get the result of a command execution. + + Args: + command_id: Command identifier + cursor: Cursor executing the command + + Returns: + ResultSet: A SeaResultSet instance with the execution results + + Raises: + ValueError: If the command ID is invalid + """ + + if command_id.backend_type != BackendType.SEA: + raise ValueError("Not a valid SEA command ID") + + sea_statement_id = command_id.to_sea_statement_id() + + # Create the request model + request = GetStatementRequest(statement_id=sea_statement_id) + + # Get the statement result + response_data = self.http_client._make_request( + method="GET", + path=self.STATEMENT_PATH_WITH_ID.format(sea_statement_id), + data=request.to_dict(), + ) + response = GetStatementResponse.from_dict(response_data) + + # Create and return a SeaResultSet + from databricks.sql.result_set import SeaResultSet + + execute_response = self._results_message_to_execute_response(response) + + return SeaResultSet( + connection=cursor.connection, + execute_response=execute_response, + sea_client=self, + buffer_size_bytes=cursor.buffer_size_bytes, + arraysize=cursor.arraysize, + result_data=response.result, + manifest=response.manifest, ) # == Metadata Operations == diff --git a/src/databricks/sql/backend/sea/utils/constants.py b/src/databricks/sql/backend/sea/utils/constants.py index 9160ef6ad..7481a90db 100644 --- a/src/databricks/sql/backend/sea/utils/constants.py +++ b/src/databricks/sql/backend/sea/utils/constants.py @@ -3,6 +3,7 @@ """ from typing import Dict +from enum import Enum # from https://docs.databricks.com/aws/en/sql/language-manual/sql-ref-parameters ALLOWED_SESSION_CONF_TO_DEFAULT_VALUES_MAP: Dict[str, str] = { @@ -15,3 +16,32 @@ "TIMEZONE": "UTC", "USE_CACHED_RESULT": "true", } + + +class ResultFormat(Enum): + """Enum for result format values.""" + + ARROW_STREAM = "ARROW_STREAM" + JSON_ARRAY = "JSON_ARRAY" + + +class ResultDisposition(Enum): + """Enum for result disposition values.""" + + # TODO: add support for hybrid disposition + EXTERNAL_LINKS = "EXTERNAL_LINKS" + INLINE = "INLINE" + + +class ResultCompression(Enum): + """Enum for result compression values.""" + + LZ4_FRAME = "LZ4_FRAME" + NONE = None + + +class WaitTimeout(Enum): + """Enum for wait timeout values.""" + + ASYNC = "0s" + SYNC = "10s" diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index bc2688a68..f30c92ed0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -1,11 +1,26 @@ +""" +Tests for the SEA (Statement Execution API) backend implementation. + +This module contains tests for the SeaDatabricksClient class, which implements +the Databricks SQL connector's SEA backend functionality. +""" + import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock -from databricks.sql.backend.sea.backend import SeaDatabricksClient -from databricks.sql.backend.types import SessionId, BackendType +from databricks.sql.backend.sea.backend import ( + SeaDatabricksClient, + _filter_session_configuration, +) +from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider -from databricks.sql.exc import Error +from databricks.sql.exc import ( + Error, + NotSupportedError, + ServerOperationError, + DatabaseError, +) class TestSeaBackend: @@ -41,8 +56,43 @@ def sea_client(self, mock_http_client): return client - def test_init_extracts_warehouse_id(self, mock_http_client): - """Test that the constructor properly extracts the warehouse ID from the HTTP path.""" + @pytest.fixture + def sea_session_id(self): + """Create a SEA session ID.""" + return SessionId.from_sea_session_id("test-session-123") + + @pytest.fixture + def sea_command_id(self): + """Create a SEA command ID.""" + return CommandId.from_sea_statement_id("test-statement-123") + + @pytest.fixture + def mock_cursor(self): + """Create a mock cursor.""" + cursor = Mock() + cursor.active_command_id = None + cursor.buffer_size_bytes = 1000 + cursor.arraysize = 100 + return cursor + + @pytest.fixture + def thrift_session_id(self): + """Create a Thrift session ID (not SEA).""" + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + return SessionId.from_thrift_handle(mock_thrift_handle) + + @pytest.fixture + def thrift_command_id(self): + """Create a Thrift command ID (not SEA).""" + mock_thrift_operation_handle = MagicMock() + mock_thrift_operation_handle.operationId.guid = b"guid" + mock_thrift_operation_handle.operationId.secret = b"secret" + return CommandId.from_thrift_handle(mock_thrift_operation_handle) + + def test_initialization(self, mock_http_client): + """Test client initialization and warehouse ID extraction.""" # Test with warehouses format client1 = SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -53,6 +103,7 @@ def test_init_extracts_warehouse_id(self, mock_http_client): ssl_options=SSLOptions(), ) assert client1.warehouse_id == "abc123" + assert client1.max_download_threads == 10 # Default value # Test with endpoints format client2 = SeaDatabricksClient( @@ -65,8 +116,19 @@ def test_init_extracts_warehouse_id(self, mock_http_client): ) assert client2.warehouse_id == "def456" - def test_init_raises_error_for_invalid_http_path(self, mock_http_client): - """Test that the constructor raises an error for invalid HTTP paths.""" + # Test with custom max_download_threads + client3 = SeaDatabricksClient( + server_hostname="test-server.databricks.com", + port=443, + http_path="/sql/warehouses/abc123", + http_headers=[], + auth_provider=AuthProvider(), + ssl_options=SSLOptions(), + max_download_threads=5, + ) + assert client3.max_download_threads == 5 + + # Test with invalid HTTP path with pytest.raises(ValueError) as excinfo: SeaDatabricksClient( server_hostname="test-server.databricks.com", @@ -78,30 +140,21 @@ def test_init_raises_error_for_invalid_http_path(self, mock_http_client): ) assert "Could not extract warehouse ID" in str(excinfo.value) - def test_open_session_basic(self, sea_client, mock_http_client): - """Test the open_session method with minimal parameters.""" - # Set up mock response + def test_session_management(self, sea_client, mock_http_client, thrift_session_id): + """Test session management methods.""" + # Test open_session with minimal parameters mock_http_client._make_request.return_value = {"session_id": "test-session-123"} - - # Call the method session_id = sea_client.open_session(None, None, None) - - # Verify the result assert isinstance(session_id, SessionId) assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-123" - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data={"warehouse_id": "abc123"} ) - def test_open_session_with_all_parameters(self, sea_client, mock_http_client): - """Test the open_session method with all parameters.""" - # Set up mock response + # Test open_session with all parameters + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {"session_id": "test-session-456"} - - # Call the method with all parameters, including both supported and unsupported configurations session_config = { "ANSI_MODE": "FALSE", # Supported parameter "STATEMENT_TIMEOUT": "3600", # Supported parameter @@ -109,16 +162,8 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): } catalog = "test_catalog" schema = "test_schema" - session_id = sea_client.open_session(session_config, catalog, schema) - - # Verify the result - assert isinstance(session_id, SessionId) - assert session_id.backend_type == BackendType.SEA assert session_id.guid == "test-session-456" - - # Verify the HTTP request - only supported parameters should be included - # and keys should be in lowercase expected_data = { "warehouse_id": "abc123", "session_confs": { @@ -128,156 +173,513 @@ def test_open_session_with_all_parameters(self, sea_client, mock_http_client): "catalog": catalog, "schema": schema, } - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="POST", path=sea_client.SESSION_PATH, data=expected_data ) - def test_open_session_error_handling(self, sea_client, mock_http_client): - """Test error handling in the open_session method.""" - # Set up mock response without session_id + # Test open_session error handling + mock_http_client.reset_mock() mock_http_client._make_request.return_value = {} - - # Call the method and expect an error with pytest.raises(Error) as excinfo: sea_client.open_session(None, None, None) - assert "Failed to create session" in str(excinfo.value) - def test_close_session_valid_id(self, sea_client, mock_http_client): - """Test closing a session with a valid session ID.""" - # Create a valid SEA session ID + # Test close_session with valid ID + mock_http_client.reset_mock() session_id = SessionId.from_sea_session_id("test-session-789") - - # Set up mock response - mock_http_client._make_request.return_value = {} - - # Call the method sea_client.close_session(session_id) - - # Verify the HTTP request - mock_http_client._make_request.assert_called_once_with( + mock_http_client._make_request.assert_called_with( method="DELETE", path=sea_client.SESSION_PATH_WITH_ID.format("test-session-789"), data={"session_id": "test-session-789", "warehouse_id": "abc123"}, ) - def test_close_session_invalid_id_type(self, sea_client): - """Test closing a session with an invalid session ID type.""" - # Create a Thrift session ID (not SEA) - mock_thrift_handle = MagicMock() - mock_thrift_handle.sessionId.guid = b"guid" - mock_thrift_handle.sessionId.secret = b"secret" - session_id = SessionId.from_thrift_handle(mock_thrift_handle) + # Test close_session with invalid ID type + with pytest.raises(ValueError) as excinfo: + sea_client.close_session(thrift_session_id) + assert "Not a valid SEA session ID" in str(excinfo.value) + + def test_command_execution_sync( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test synchronous command execution.""" + # Test synchronous execution + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "schema": [ + { + "name": "col1", + "type_name": "STRING", + "type_text": "string", + "nullable": True, + } + ], + "total_row_count": 1, + "total_byte_count": 100, + }, + "result": {"data": [["value1"]]}, + } + mock_http_client._make_request.return_value = execute_response + + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result == "mock_result_set" + cmd_id_arg = mock_get_result.call_args[0][0] + assert isinstance(cmd_id_arg, CommandId) + assert cmd_id_arg.guid == "test-statement-123" - # Call the method and expect an error + # Test with invalid session ID with pytest.raises(ValueError) as excinfo: - sea_client.close_session(session_id) + mock_thrift_handle = MagicMock() + mock_thrift_handle.sessionId.guid = b"guid" + mock_thrift_handle.sessionId.secret = b"secret" + thrift_session_id = SessionId.from_thrift_handle(mock_thrift_handle) + sea_client.execute_command( + operation="SELECT 1", + session_id=thrift_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) assert "Not a valid SEA session ID" in str(excinfo.value) - def test_session_configuration_helpers(self): - """Test the session configuration helper methods.""" - # Test getting default value for a supported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "ANSI_MODE" + def test_command_execution_async( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test asynchronous command execution.""" + # Test asynchronous execution + execute_response = { + "statement_id": "test-statement-456", + "status": {"state": "PENDING"}, + } + mock_http_client._make_request.return_value = execute_response + + result = sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, + enforce_embedded_schema_correctness=False, ) - assert default_value == "true" - - # Test getting default value for an unsupported parameter - default_value = SeaDatabricksClient.get_default_session_configuration_value( - "UNSUPPORTED_PARAM" + assert result is None + assert isinstance(mock_cursor.active_command_id, CommandId) + assert mock_cursor.active_command_id.guid == "test-statement-456" + + # Test async with missing statement ID + mock_http_client.reset_mock() + mock_http_client._make_request.return_value = {"status": {"state": "PENDING"}} + with pytest.raises(ServerOperationError) as excinfo: + sea_client.execute_command( + operation="SELECT 1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=True, + enforce_embedded_schema_correctness=False, + ) + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value ) - assert default_value is None - # Test getting the list of allowed configurations - allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() - - expected_keys = { - "ANSI_MODE", - "ENABLE_PHOTON", - "LEGACY_TIME_PARSER_POLICY", - "MAX_FILE_PARTITION_BYTES", - "READ_ONLY_EXTERNAL_METASTORE", - "STATEMENT_TIMEOUT", - "TIMEZONE", - "USE_CACHED_RESULT", + def test_command_execution_advanced( + self, sea_client, mock_http_client, mock_cursor, sea_session_id + ): + """Test advanced command execution scenarios.""" + # Test with polling + initial_response = { + "statement_id": "test-statement-789", + "status": {"state": "RUNNING"}, } - assert set(allowed_configs) == expected_keys - - def test_unimplemented_methods(self, sea_client): - """Test that unimplemented methods raise NotImplementedError.""" - # Create dummy parameters for testing - session_id = SessionId.from_sea_session_id("test-session") - command_id = MagicMock() - cursor = MagicMock() + poll_response = { + "statement_id": "test-statement-789", + "status": {"state": "SUCCEEDED"}, + "manifest": {"schema": [], "total_row_count": 0, "total_byte_count": 0}, + "result": {"data": []}, + } + mock_http_client._make_request.side_effect = [initial_response, poll_response] + + with patch.object( + sea_client, "get_execution_result", return_value="mock_result_set" + ) as mock_get_result: + with patch("time.sleep"): + result = sea_client.execute_command( + operation="SELECT * FROM large_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert result == "mock_result_set" + + # Test with parameters + mock_http_client.reset_mock() + mock_http_client._make_request.side_effect = None # Reset side_effect + execute_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + } + mock_http_client._make_request.return_value = execute_response + param = {"name": "param1", "value": "value1", "type": "STRING"} - # Test execute_command - with pytest.raises(NotImplementedError) as excinfo: + with patch.object(sea_client, "get_execution_result"): + sea_client.execute_command( + operation="SELECT * FROM table WHERE col = :param1", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[param], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + args, kwargs = mock_http_client._make_request.call_args + assert "parameters" in kwargs["data"] + assert len(kwargs["data"]["parameters"]) == 1 + assert kwargs["data"]["parameters"][0]["name"] == "param1" + assert kwargs["data"]["parameters"][0]["value"] == "value1" + assert kwargs["data"]["parameters"][0]["type"] == "STRING" + + # Test execution failure + mock_http_client.reset_mock() + error_response = { + "statement_id": "test-statement-123", + "status": { + "state": "FAILED", + "error": { + "message": "Syntax error in SQL", + "error_code": "SYNTAX_ERROR", + }, + }, + } + mock_http_client._make_request.return_value = error_response + + with patch("time.sleep"): + with patch.object( + sea_client, "get_query_state", return_value=CommandState.FAILED + ): + with pytest.raises(Error) as excinfo: + sea_client.execute_command( + operation="SELECT * FROM nonexistent_table", + session_id=sea_session_id, + max_rows=100, + max_bytes=1000, + lz4_compression=False, + cursor=mock_cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + assert "Command test-statement-123 failed" in str(excinfo.value) + + # Test missing statement ID + mock_http_client.reset_mock() + mock_http_client._make_request.return_value = {"status": {"state": "SUCCEEDED"}} + with pytest.raises(ServerOperationError) as excinfo: sea_client.execute_command( operation="SELECT 1", - session_id=session_id, + session_id=sea_session_id, max_rows=100, max_bytes=1000, lz4_compression=False, - cursor=cursor, + cursor=mock_cursor, use_cloud_fetch=False, parameters=[], async_op=False, enforce_embedded_schema_correctness=False, ) - assert "execute_command is not yet implemented" in str(excinfo.value) + assert "Failed to execute command: No statement ID returned" in str( + excinfo.value + ) + def test_command_management( + self, + sea_client, + mock_http_client, + sea_command_id, + thrift_command_id, + mock_cursor, + ): + """Test command management methods.""" # Test cancel_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.cancel_command(command_id) - assert "cancel_command is not yet implemented" in str(excinfo.value) + mock_http_client._make_request.return_value = {} + sea_client.cancel_command(sea_command_id) + mock_http_client._make_request.assert_called_with( + method="POST", + path=sea_client.CANCEL_STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, + ) + + # Test cancel_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.cancel_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) # Test close_command - with pytest.raises(NotImplementedError) as excinfo: - sea_client.close_command(command_id) - assert "close_command is not yet implemented" in str(excinfo.value) + mock_http_client.reset_mock() + sea_client.close_command(sea_command_id) + mock_http_client._make_request.assert_called_with( + method="DELETE", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, + ) + + # Test close_command with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.close_command(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) # Test get_query_state - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_query_state(command_id) - assert "get_query_state is not yet implemented" in str(excinfo.value) + mock_http_client.reset_mock() + mock_http_client._make_request.return_value = { + "statement_id": "test-statement-123", + "status": {"state": "RUNNING"}, + } + state = sea_client.get_query_state(sea_command_id) + assert state == CommandState.RUNNING + mock_http_client._make_request.assert_called_with( + method="GET", + path=sea_client.STATEMENT_PATH_WITH_ID.format("test-statement-123"), + data={"statement_id": "test-statement-123"}, + ) + + # Test get_query_state with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.get_query_state(thrift_command_id) + assert "Not a valid SEA command ID" in str(excinfo.value) # Test get_execution_result - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_execution_result(command_id, cursor) - assert "get_execution_result is not yet implemented" in str(excinfo.value) + mock_http_client.reset_mock() + sea_response = { + "statement_id": "test-statement-123", + "status": {"state": "SUCCEEDED"}, + "manifest": { + "format": "JSON_ARRAY", + "schema": { + "column_count": 1, + "columns": [ + { + "name": "test_value", + "type_text": "INT", + "type_name": "INT", + "position": 0, + } + ], + }, + "total_chunk_count": 1, + "chunks": [{"chunk_index": 0, "row_offset": 0, "row_count": 1}], + "total_row_count": 1, + "truncated": False, + }, + "result": { + "chunk_index": 0, + "row_offset": 0, + "row_count": 1, + "data_array": [["1"]], + }, + } + mock_http_client._make_request.return_value = sea_response + result = sea_client.get_execution_result(sea_command_id, mock_cursor) + assert result.command_id.to_sea_statement_id() == "test-statement-123" + assert result.status == CommandState.SUCCEEDED - # Test metadata operations - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_catalogs(session_id, 100, 1000, cursor) - assert "get_catalogs is not yet implemented" in str(excinfo.value) + # Test get_execution_result with invalid ID + with pytest.raises(ValueError) as excinfo: + sea_client.get_execution_result(thrift_command_id, mock_cursor) + assert "Not a valid SEA command ID" in str(excinfo.value) + + def test_check_command_state(self, sea_client, sea_command_id): + """Test _check_command_not_in_failed_or_closed_state method.""" + # Test with RUNNING state (should not raise) + sea_client._check_command_not_in_failed_or_closed_state( + CommandState.RUNNING, sea_command_id + ) - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_schemas(session_id, 100, 1000, cursor) - assert "get_schemas is not yet implemented" in str(excinfo.value) + # Test with SUCCEEDED state (should not raise) + sea_client._check_command_not_in_failed_or_closed_state( + CommandState.SUCCEEDED, sea_command_id + ) - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_tables(session_id, 100, 1000, cursor) - assert "get_tables is not yet implemented" in str(excinfo.value) + # Test with CLOSED state (should raise DatabaseError) + with pytest.raises(DatabaseError) as excinfo: + sea_client._check_command_not_in_failed_or_closed_state( + CommandState.CLOSED, sea_command_id + ) + assert "Command test-statement-123 unexpectedly closed server side" in str( + excinfo.value + ) - with pytest.raises(NotImplementedError) as excinfo: - sea_client.get_columns(session_id, 100, 1000, cursor) - assert "get_columns is not yet implemented" in str(excinfo.value) + # Test with FAILED state (should raise ServerOperationError) + with pytest.raises(ServerOperationError) as excinfo: + sea_client._check_command_not_in_failed_or_closed_state( + CommandState.FAILED, sea_command_id + ) + assert "Command test-statement-123 failed" in str(excinfo.value) - def test_max_download_threads_property(self, sea_client): - """Test the max_download_threads property.""" - assert sea_client.max_download_threads == 10 + def test_utility_methods(self, sea_client): + """Test utility methods.""" + # Test get_default_session_configuration_value + value = SeaDatabricksClient.get_default_session_configuration_value("ANSI_MODE") + assert value == "true" - # Create a client with a custom value - custom_client = SeaDatabricksClient( - server_hostname="test-server.databricks.com", - port=443, - http_path="/sql/warehouses/abc123", - http_headers=[], - auth_provider=AuthProvider(), - ssl_options=SSLOptions(), - max_download_threads=20, + # Test with unsupported configuration parameter + value = SeaDatabricksClient.get_default_session_configuration_value( + "UNSUPPORTED_PARAM" + ) + assert value is None + + # Test with case-insensitive parameter name + value = SeaDatabricksClient.get_default_session_configuration_value("ansi_mode") + assert value == "true" + + # Test get_allowed_session_configurations + configs = SeaDatabricksClient.get_allowed_session_configurations() + assert isinstance(configs, list) + assert len(configs) > 0 + assert "ANSI_MODE" in configs + + # Test getting the list of allowed configurations with specific keys + allowed_configs = SeaDatabricksClient.get_allowed_session_configurations() + expected_keys = { + "ANSI_MODE", + "ENABLE_PHOTON", + "LEGACY_TIME_PARSER_POLICY", + "MAX_FILE_PARTITION_BYTES", + "READ_ONLY_EXTERNAL_METASTORE", + "STATEMENT_TIMEOUT", + "TIMEZONE", + "USE_CACHED_RESULT", + } + assert set(allowed_configs) == expected_keys + + # Test _extract_description_from_manifest + manifest_obj = MagicMock() + manifest_obj.schema = { + "columns": [ + { + "name": "col1", + "type_name": "STRING", + "precision": 10, + "scale": 2, + "nullable": True, + }, + { + "name": "col2", + "type_name": "INT", + "nullable": False, + }, + ] + } + + description = sea_client._extract_description_from_manifest(manifest_obj) + assert description is not None + assert len(description) == 2 + assert description[0][0] == "col1" # name + assert description[0][1] == "STRING" # type_code + assert description[0][4] == 10 # precision + assert description[0][5] == 2 # scale + assert description[0][6] is True # null_ok + assert description[1][0] == "col2" # name + assert description[1][1] == "INT" # type_code + assert description[1][6] is False # null_ok + + # Test _extract_description_from_manifest with empty columns + empty_manifest = MagicMock() + empty_manifest.schema = {"columns": []} + assert sea_client._extract_description_from_manifest(empty_manifest) is None + + # Test _extract_description_from_manifest with no columns key + no_columns_manifest = MagicMock() + no_columns_manifest.schema = {} + assert ( + sea_client._extract_description_from_manifest(no_columns_manifest) is None ) - # Verify the custom value is returned - assert custom_client.max_download_threads == 20 + def test_unimplemented_metadata_methods( + self, sea_client, sea_session_id, mock_cursor + ): + """Test that metadata methods raise NotImplementedError.""" + # Test get_catalogs + with pytest.raises(NotImplementedError): + sea_client.get_catalogs(sea_session_id, 100, 1000, mock_cursor) + + # Test get_schemas + with pytest.raises(NotImplementedError): + sea_client.get_schemas(sea_session_id, 100, 1000, mock_cursor) + + # Test get_schemas with optional parameters + with pytest.raises(NotImplementedError): + sea_client.get_schemas( + sea_session_id, 100, 1000, mock_cursor, "catalog", "schema" + ) + + # Test get_tables + with pytest.raises(NotImplementedError): + sea_client.get_tables(sea_session_id, 100, 1000, mock_cursor) + + # Test get_tables with optional parameters + with pytest.raises(NotImplementedError): + sea_client.get_tables( + sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + table_types=["TABLE", "VIEW"], + ) + + # Test get_columns + with pytest.raises(NotImplementedError): + sea_client.get_columns(sea_session_id, 100, 1000, mock_cursor) + + # Test get_columns with optional parameters + with pytest.raises(NotImplementedError): + sea_client.get_columns( + sea_session_id, + 100, + 1000, + mock_cursor, + catalog_name="catalog", + schema_name="schema", + table_name="table", + column_name="column", + )