Skip to content

Introduce row_limit param #607

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: sea-migration
Choose a base branch
from
2 changes: 2 additions & 0 deletions src/databricks/sql/backend/databricks_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def execute_command(
parameters: List,
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
"""
Executes a SQL command or query within the specified session.
Expand All @@ -103,6 +104,7 @@ def execute_command(
parameters: List of parameters to bind to the query
async_op: Whether to execute the command asynchronously
enforce_embedded_schema_correctness: Whether to enforce schema correctness
row_limit: Maximum number of rows to fetch overall. Only supported for SEA protocol.

Returns:
If async_op is False, returns a ResultSet object containing the
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def execute_command(
parameters: List[Dict[str, Any]],
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
"""
Execute a SQL command using the SEA backend.
Expand Down Expand Up @@ -462,7 +463,7 @@ def execute_command(
format=format,
wait_timeout=(WaitTimeout.ASYNC if async_op else WaitTimeout.SYNC).value,
on_wait_timeout="CONTINUE",
row_limit=max_rows,
row_limit=row_limit,
parameters=sea_parameters if sea_parameters else None,
result_compression=result_compression,
)
Expand Down
3 changes: 2 additions & 1 deletion src/databricks/sql/backend/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import math
import time
import threading
from typing import List, Union, Any, TYPE_CHECKING
from typing import List, Optional, Union, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Cursor
Expand Down Expand Up @@ -929,6 +929,7 @@ def execute_command(
parameters=[],
async_op=False,
enforce_embedded_schema_correctness=False,
row_limit: Optional[int] = None,
) -> Union["ResultSet", None]:
thrift_handle = session_id.to_thrift_handle()
if not thrift_handle:
Expand Down
26 changes: 18 additions & 8 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def cursor(
self,
arraysize: int = DEFAULT_ARRAY_SIZE,
buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
row_limit: Optional[int] = None,
) -> "Cursor":
"""
Return a new Cursor object using the connection.
Expand All @@ -355,6 +356,7 @@ def cursor(
self.session.backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
row_limit=row_limit,
)
self._cursors.append(cursor)
return cursor
Expand Down Expand Up @@ -388,6 +390,7 @@ def __init__(
backend: DatabricksClient,
result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES,
arraysize: int = DEFAULT_ARRAY_SIZE,
row_limit: Optional[int] = None,
) -> None:
"""
These objects represent a database cursor, which is used to manage the context of a fetch
Expand All @@ -397,16 +400,23 @@ def __init__(
visible by other cursors or connections.
"""

self.connection = connection
self.rowcount = -1 # Return -1 as this is not supported
self.buffer_size_bytes = result_buffer_size_bytes
self.connection: Connection = connection

if not connection.session.use_sea and row_limit is not None:
logger.warning(
"Row limit is only supported for SEA protocol. Ignoring row_limit."
)

self.rowcount: int = -1 # Return -1 as this is not supported
self.buffer_size_bytes: int = result_buffer_size_bytes
self.active_result_set: Union[ResultSet, None] = None
self.arraysize = arraysize
self.arraysize: int = arraysize
self.row_limit: Optional[int] = row_limit
# Note that Cursor closed => active result set closed, but not vice versa
self.open = True
self.executing_command_id = None
self.backend = backend
self.active_command_id = None
self.open: bool = True
self.executing_command_id: Optional[CommandId] = None
self.backend: DatabricksClient = backend
self.active_command_id: Optional[CommandId] = None
self.escaper = ParamEscaper()
self.lastrowid = None

Expand Down
5 changes: 3 additions & 2 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def __init__(
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.use_sea = kwargs.get("use_sea", False)
self.backend = self._create_backend(
self.use_sea,
server_hostname,
http_path,
all_headers,
Expand All @@ -89,6 +91,7 @@ def __init__(

def _create_backend(
self,
use_sea: bool,
server_hostname: str,
http_path: str,
all_headers: List[Tuple[str, str]],
Expand All @@ -97,8 +100,6 @@ def _create_backend(
kwargs: dict,
) -> DatabricksClient:
"""Create and return the appropriate backend client."""
use_sea = kwargs.get("use_sea", False)

databricks_client_class: Type[DatabricksClient]
if use_sea:
logger.debug("Creating SEA backend client")
Expand Down
Loading