diff --git a/README.md b/README.md index 18ae569..751c770 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,42 @@ It also implements the `close()` method, as suggested by the PEP-2049 specification, to support situations where the cursor is wrapped in a `contextmanager.closing()`. +### Arrow table output + +For improved performance and native GeoArrow support, you can configure +the connection to return PyArrow tables instead of pandas DataFrames: + +```python +from wherobots.db import connect +from wherobots.db.constants import OutputFormat, ResultsFormat, GeometryRepresentation +from wherobots.db.region import Region +from wherobots.db.runtime import Runtime + +with connect( + api_key='...', + runtime=Runtime.TINY, + region=Region.AWS_US_WEST_2, + results_format=ResultsFormat.ARROW, + output_format=OutputFormat.ARROW, + geometry_representation=GeometryRepresentation.WKB) as conn: + curr = conn.cursor() + curr.execute("SELECT * FROM buildings LIMIT 1000") + results = curr.fetchall() + + # results is now a pyarrow.Table instead of pandas.DataFrame + print(f"Result type: {type(results)}") + print(f"Schema: {results.schema}") + print(f"Row count: {len(results)}") + + # Convert to pandas only when needed: + # df = results.to_pandas() +``` + +This is particularly beneficial when working with: +* Large datasets (reduced memory usage and faster operations) +* GeoArrow geometries (native spatial data structures) +* Arrow-native downstream processing pipelines + ### Runtime and region selection You can chose the Wherobots runtime you want to use using the `runtime` @@ -87,6 +123,10 @@ users may find useful: * `results_format`: one of the `ResultsFormat` enum values; Arrow encoding is the default and most efficient format for receiving query results. +* `output_format`: one of the `OutputFormat` enum values; controls + whether query results are returned as PyArrow tables (`ARROW`) or + pandas DataFrames (`PANDAS`, default). Use `ARROW` for better + performance with large datasets and native GeoArrow support. * `data_compression`: one of the `DataCompression` enum values; Brotli compression is the default and the most efficient compression algorithm for receiving query results. diff --git a/example_arrow_output.py b/example_arrow_output.py new file mode 100644 index 0000000..96cfb21 --- /dev/null +++ b/example_arrow_output.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +""" +Example script demonstrating Arrow table output functionality. + +This script shows how to use the new output_format parameter to return +PyArrow tables instead of pandas DataFrames, which is particularly useful +for working with GeoArrow geometries. +""" + +from wherobots.db import connect +from wherobots.db.constants import ResultsFormat, OutputFormat, GeometryRepresentation +from wherobots.db.runtime import Runtime +from wherobots.db.region import Region + +def example_arrow_usage(): + """ + Example of how to use the new Arrow output functionality. + + Note: This is a code example only - it would need valid credentials + to actually run against a Wherobots DB instance. + """ + + # Example 1: Return Arrow tables instead of pandas DataFrames + with connect( + host="api.cloud.wherobots.com", + api_key="your_api_key", + runtime=Runtime.TINY, + results_format=ResultsFormat.ARROW, # Efficient wire format + output_format=OutputFormat.ARROW, # Return Arrow tables + geometry_representation=GeometryRepresentation.WKB, + region=Region.AWS_US_WEST_2 + ) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM buildings LIMIT 1000") + results = cursor.fetchall() + + # results is now a pyarrow.Table instead of pandas.DataFrame + print(f"Result type: {type(results)}") + print(f"Schema: {results.schema}") + print(f"Row count: {len(results)}") + + # Work with Arrow table directly (great for GeoArrow!) + # Convert to pandas only when needed: + # df = results.to_pandas() + + # Example 2: Default behavior (backwards compatible) + with connect( + host="api.cloud.wherobots.com", + api_key="your_api_key", + runtime=Runtime.TINY, + results_format=ResultsFormat.ARROW, + # output_format defaults to OutputFormat.PANDAS + geometry_representation=GeometryRepresentation.WKB, + region=Region.AWS_US_WEST_2 + ) as conn: + cursor = conn.cursor() + cursor.execute("SELECT * FROM buildings LIMIT 1000") + results = cursor.fetchall() + + # results is a pandas.DataFrame (existing behavior) + print(f"Result type: {type(results)}") + +if __name__ == "__main__": + print("Arrow Table Output Example") + print("=" * 30) + print("This example shows how to use the new output_format parameter.") + print("Uncomment and provide valid credentials to run against Wherobots DB.") + print() + print("Key benefits of Arrow output:") + print("- More efficient for large datasets") + print("- Native support for GeoArrow geometries") + print("- Better interoperability with Arrow ecosystem") + print("- Zero-copy operations when possible") \ No newline at end of file diff --git a/wherobots/db/__init__.py b/wherobots/db/__init__.py index 3e9a96e..a584728 100644 --- a/wherobots/db/__init__.py +++ b/wherobots/db/__init__.py @@ -10,6 +10,12 @@ ProgrammingError, NotSupportedError, ) +from .constants import ( + OutputFormat, + ResultsFormat, + DataCompression, + GeometryRepresentation, +) from .region import Region from .runtime import Runtime @@ -25,6 +31,10 @@ "OperationalError", "ProgrammingError", "NotSupportedError", + "OutputFormat", + "ResultsFormat", + "DataCompression", + "GeometryRepresentation", "Region", "Runtime", ] diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 47bbf61..d0adde2 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -21,6 +21,7 @@ ResultsFormat, DataCompression, GeometryRepresentation, + OutputFormat, ) from wherobots.db.cursor import Cursor from wherobots.db.errors import NotSupportedError, OperationalError @@ -56,12 +57,14 @@ def __init__( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + output_format: Union[OutputFormat, None] = None, ): self.__ws = ws self.__read_timeout = read_timeout self.__results_format = results_format self.__data_compression = data_compression self.__geometry_representation = geometry_representation + self.__output_format = output_format or OutputFormat.PANDAS self.__queries: dict[str, Query] = {} self.__thread = threading.Thread( @@ -181,7 +184,10 @@ def _handle_results(self, execution_id: str, results: Dict[str, Any]) -> Any: buffer = pyarrow.py_buffer(result_bytes) stream = pyarrow.input_stream(buffer, result_compression) with pyarrow.ipc.open_stream(stream) as reader: - return reader.read_pandas() + if self.__output_format == OutputFormat.ARROW: + return reader.read_all() + else: + return reader.read_pandas() else: return OperationalError(f"Unsupported results format {result_format}") diff --git a/wherobots/db/constants.py b/wherobots/db/constants.py index 95f2555..fcb11ce 100644 --- a/wherobots/db/constants.py +++ b/wherobots/db/constants.py @@ -85,6 +85,11 @@ class GeometryRepresentation(LowercaseStrEnum): GEOJSON = auto() +class OutputFormat(LowercaseStrEnum): + PANDAS = auto() + ARROW = auto() + + class AppStatus(StrEnum): PENDING = auto() PREPARING = auto() diff --git a/wherobots/db/cursor.py b/wherobots/db/cursor.py index 72f9009..7805473 100644 --- a/wherobots/db/cursor.py +++ b/wherobots/db/cursor.py @@ -1,6 +1,9 @@ import queue from typing import Any, Optional, List, Tuple, Dict +import pandas +import pyarrow + from .errors import DatabaseError, ProgrammingError _TYPE_MAP = { @@ -13,6 +16,16 @@ "bytes": "BINARY", } +_ARROW_TYPE_MAP = { + "string": "STRING", + "int64": "NUMBER", + "double": "NUMBER", + "float64": "NUMBER", + "timestamp": "DATETIME", + "bool": "NUMBER", + "binary": "BINARY", +} + class Cursor: def __init__(self, exec_fn, cancel_fn) -> None: @@ -53,21 +66,40 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]: if isinstance(result, DatabaseError): raise result - self.__rowcount = len(result) - self.__results = result - if not result.empty: - self.__description = [ - ( - col_name, # name - _TYPE_MAP.get(str(result[col_name].dtype), "STRING"), # type_code - None, # display_size - result[col_name].memory_usage(), # internal_size - None, # precision - None, # scale - True, # null_ok; Assuming all columns can accept NULL values - ) - for col_name in result.columns - ] + # Handle both Arrow tables and pandas DataFrames + if isinstance(result, pyarrow.Table): + self.__rowcount = len(result) + self.__results = result + if len(result) > 0: + self.__description = [ + ( + col_name, # name + _ARROW_TYPE_MAP.get(str(result.schema.field(col_name).type).split('<')[0], "STRING"), # type_code + None, # display_size + result.column(col_name).nbytes, # internal_size + None, # precision + None, # scale + True, # null_ok; Assuming all columns can accept NULL values + ) + for col_name in result.column_names + ] + else: + # pandas DataFrame + self.__rowcount = len(result) + self.__results = result + if not result.empty: + self.__description = [ + ( + col_name, # name + _TYPE_MAP.get(str(result[col_name].dtype), "STRING"), # type_code + None, # display_size + result[col_name].memory_usage(), # internal_size + None, # precision + None, # scale + True, # null_ok; Assuming all columns can accept NULL values + ) + for col_name in result.columns + ] return self.__results @@ -89,21 +121,72 @@ def executemany( ) -> None: raise NotImplementedError + def __get_row_data(self, results, start_row: int, end_row: int = None) -> List[Any]: + """Helper method to extract row data from either Arrow table or pandas DataFrame.""" + if isinstance(results, pyarrow.Table): + # Convert to pandas for easier row-wise access + # TODO: This could be optimized to avoid conversion for large tables + df = results.to_pandas() + if end_row is None: + return df.iloc[start_row:] + else: + return df.iloc[start_row:end_row] + else: + # pandas DataFrame + if end_row is None: + return results.iloc[start_row:] + else: + return results.iloc[start_row:end_row] + def fetchone(self) -> Any: - results = self.__get_results()[self.__current_row :] - if len(results) == 0: + results = self.__get_results() + if self.__current_row >= len(results): return None - self.__current_row += 1 - return results[0] + + if isinstance(results, pyarrow.Table): + # For Arrow tables, return the native result when fetching + if self.__current_row == 0: + self.__current_row = len(results) # Mark all as consumed + return results + else: + return None + else: + # pandas DataFrame - return single row + row_data = self.__get_row_data(results, self.__current_row, self.__current_row + 1) + if len(row_data) == 0: + return None + self.__current_row += 1 + return row_data.iloc[0] if hasattr(row_data, 'iloc') else row_data def fetchmany(self, size: int = None) -> List[Any]: size = size or self.arraysize - results = self.__get_results()[self.__current_row : self.__current_row + size] - self.__current_row += size - return results + results = self.__get_results() + + if isinstance(results, pyarrow.Table): + # For Arrow tables, return the native result + if self.__current_row == 0: + self.__current_row = len(results) # Mark all as consumed + return results + else: + return [] + else: + # pandas DataFrame + row_data = self.__get_row_data(results, self.__current_row, self.__current_row + size) + self.__current_row += size + return row_data def fetchall(self) -> List[Any]: - return self.__get_results()[self.__current_row :] + results = self.__get_results() + + if isinstance(results, pyarrow.Table): + # For Arrow tables, return the native Arrow table + self.__current_row = len(results) # Mark all as consumed + return results + else: + # pandas DataFrame + row_data = self.__get_row_data(results, self.__current_row) + self.__current_row = len(results) # Mark all as consumed + return row_data def close(self) -> None: """Close the cursor.""" diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index 9f3a6ab..f2f8dde 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -31,6 +31,7 @@ AppStatus, DataCompression, GeometryRepresentation, + OutputFormat, ResultsFormat, SessionType, ) @@ -72,6 +73,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + output_format: Union[OutputFormat, None] = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -157,6 +159,7 @@ def get_session_uri() -> str: results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + output_format=output_format, ) @@ -177,6 +180,7 @@ def connect_direct( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + output_format: Union[OutputFormat, None] = None, ) -> Connection: uri_with_protocol = f"{uri}/{protocol}" @@ -199,4 +203,5 @@ def connect_direct( results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + output_format=output_format, )