169 changes: 144 additions & 25 deletions python/sedona/utils/geoarrow.py
@@ -26,6 +26,11 @@
from pyspark.sql import DataFrame
from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType

# _load_from_socket moved from pyspark.rdd to pyspark.util in newer PySpark,
# so fall back for older versions.
try:
    from pyspark.util import _load_from_socket
except ImportError:
    from pyspark.rdd import _load_from_socket

from sedona.sql.types import GeometryType
import geopandas as gpd
from pyspark.sql.pandas.types import (
@@ -50,22 +55,138 @@ def dataframe_to_arrow(df, crs=None):
    the output if exactly one CRS is present.
    :return: A pyarrow.Table with geometry columns annotated as GeoArrow.
"""

col_is_geometry = [isinstance(f.dataType, GeometryType) for f in df.schema.fields]

if not any(col_is_geometry):
return dataframe_to_arrow_raw(df)

df_projected = project_dataframe_geoarrow(df, col_is_geometry)
table = dataframe_to_arrow_raw(df_projected)
return wrap_table_or_batch(table, col_is_geometry, crs)
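
A minimal call-site sketch for the function above, assuming a Sedona-enabled SparkSession named `sedona`; the query and column names are illustrative:

```python
from sedona.utils.geoarrow import dataframe_to_arrow

# Hypothetical single-geometry DataFrame built with Sedona SQL.
df = sedona.sql("SELECT ST_Point(1.0, 2.0) AS geom, 'a' AS label")

# crs=None would infer a CRS from EWKB SRIDs; a string appears to be
# coerced through pyproj (see wrap_geoarrow_extension below).
table = dataframe_to_arrow(df, crs="EPSG:4326")
print(table.schema)  # "geom" carries geoarrow.wkb type or field metadata
```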


class GeoArrowDataFrameReader:
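    """Stream a Spark DataFrame as GeoArrow-annotated Arrow record batches.

    Iterating yields pyarrow.RecordBatch objects as they arrive from the
    JVM, not necessarily in their original order; once the stream is
    exhausted, batch_order holds the arrival indices in original order,
    which to_table() uses to assemble an ordered pyarrow.Table.
    """
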
def __init__(self, df, crs=None):
from pyspark.sql.pandas.types import to_arrow_schema
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.pandas.serializers import ArrowCollectSerializer

self._crs = crs
self._batch_order = None
self._col_is_geometry = [
isinstance(f.dataType, GeometryType) for f in df.schema.fields
]

if any(self._col_is_geometry):
df = project_dataframe_geoarrow(df, self._col_is_geometry)
raw_schema = to_arrow_schema(df.schema)
self._schema = raw_schema_to_geoarrow_schema(
raw_schema, self._col_is_geometry, self._crs
)
else:
self._schema = to_arrow_schema(df.schema)

with SCCallSiteSync(df._sc):
(
port,
auth_secret,
self._jsocket_auth_server,
) = df._jdf.collectAsArrowToPython()

self._batch_stream = _load_from_socket(
(port, auth_secret), ArrowCollectSerializer()
)

@property
def schema(self):
return self._schema

@property
def batch_order(self):
return self._batch_order

def to_table(self):
import pyarrow as pa

batches = list(self)
if not batches:
return pa.Table.from_batches([], schema=self.schema)

batches_in_order = [batches[i] for i in self.batch_order]
return pa.Table.from_batches(batches_in_order)

def __iter__(self):
import pyarrow as pa

try:
for batch_or_indices in self._batch_stream:
if isinstance(batch_or_indices, pa.RecordBatch):
yield wrap_table_or_batch(
batch_or_indices, self._col_is_geometry, self._crs
)
else:
self._batch_order = batch_or_indices
finally:
self._finish_stream()

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
self._finish_stream()

def __del__(self):
self._finish_stream()

def _finish_stream(self):
from pyspark.errors.exceptions.captured import unwrap_spark_exception

if self._jsocket_auth_server is None:
return

with unwrap_spark_exception():
# Join serving thread and raise any exceptions from collectAsArrowToPython
auth_server = self._jsocket_auth_server
self._jsocket_auth_server = None
auth_server.getResult()
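
A sketch of consuming the reader, with a hypothetical DataFrame `df`. Iteration streams wrapped batches as they arrive; `to_table()` is the one-shot alternative:

```python
with GeoArrowDataFrameReader(df, crs="EPSG:4326") as reader:
    for batch in reader:       # pyarrow.RecordBatch with GeoArrow columns
        print(batch.num_rows)
    print(reader.batch_order)  # arrival indices, in original batch order

# Or materialize everything as a single ordered table:
table = GeoArrowDataFrameReader(df, crs="EPSG:4326").to_table()
```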


def project_dataframe_geoarrow(df, col_is_geometry):
df_columns = list(df)
df_column_names = df.schema.fieldNames()
for i, is_geom in enumerate(col_is_geometry):
if is_geom:
df_columns[i] = ST_AsEWKB(df_columns[i]).alias(df_column_names[i])

    return df.select(*df_columns)
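
Illustratively, the projection swaps each geometry column for its EWKB bytes under the same name, so the Arrow conversion only ever sees binary columns (hypothetical query):

```python
df = sedona.sql("SELECT ST_Point(0.0, 0.0) AS geom, 1 AS id")
projected = project_dataframe_geoarrow(df, [True, False])
projected.printSchema()
# root
#  |-- geom: binary (nullable = true)
#  |-- id: integer (nullable = false)
```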


def raw_schema_to_geoarrow_schema(raw_schema, col_is_geometry, crs, columns=None):
import pyarrow as pa

try:
import geoarrow.types as gat

spec = gat.wkb()
except ImportError:
spec = None

if columns is None:
columns = [None] * len(col_is_geometry)

new_fields = [
(
wrap_geoarrow_field(raw_schema.field(i), columns[i], crs, spec)
if is_geom
else raw_schema.field(i)
)
for i, is_geom in enumerate(col_is_geometry)
]

return pa.schema(new_fields)
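
A sketch of this helper in isolation, on a hand-built schema whose first column is assumed to hold WKB geometry (names are hypothetical):

```python
import pyarrow as pa

raw = pa.schema([pa.field("geom", pa.binary()), pa.field("id", pa.int64())])
geo = raw_schema_to_geoarrow_schema(raw, [True, False], crs=None)
# geo.field(0) is a geoarrow.wkb extension field when geoarrow-types is
# installed, or a binary field tagged with ARROW:extension:* metadata.
```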


def wrap_table_or_batch(table_or_batch, col_is_geometry, crs):
try:
# Using geoarrow-types is the preferred mechanism for Arrow output.
# Using the extension type ensures that the type and its metadata will
@@ -78,25 +199,19 @@ def dataframe_to_arrow(df, crs=None):

new_cols = [
wrap_geoarrow_extension(col, spec, crs) if is_geom else col
            for is_geom, col in zip(col_is_geometry, table_or_batch.columns)
]

        return table_or_batch.from_arrays(new_cols, table_or_batch.column_names)
except ImportError:
# In the event that we don't have access to GeoArrow extension types,
# we can still add field metadata that will propagate through some types
# of operations (e.g., writing this table to a file or passing it to
# DuckDB as long as no intermediate transformations were applied).
schema = raw_schema_to_geoarrow_schema(
table_or_batch.schema, col_is_geometry, crs, table_or_batch.columns
)
return table_or_batch.from_arrays(table_or_batch.columns, schema=schema)
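
In the fallback path the annotation is recoverable as plain field metadata, for example (illustrative field name):

```python
field = table.schema.field("geom")
print(field.metadata[b"ARROW:extension:name"])      # b'geoarrow.wkb'
print(field.metadata[b"ARROW:extension:metadata"])  # b'{"crs": ...}' or b'{}'
```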


def dataframe_to_arrow_raw(df):
@@ -125,7 +240,7 @@ def dataframe_to_arrow_raw(df):


def wrap_geoarrow_extension(col, spec, crs):
    if crs is None and col is not None:
crs = unique_srid_from_ewkb(col)
elif not hasattr(crs, "to_json"):
import pyproj
@@ -135,21 +250,25 @@ def wrap_geoarrow_extension(col, spec, crs):
return spec.override(crs=crs).to_pyarrow().wrap_array(col)


def wrap_geoarrow_field(field, col, crs, spec=None):
    if crs is None and col is not None:
crs = unique_srid_from_ewkb(col)

if crs is not None:
metadata = f'"crs": {crs_to_json(crs)}'
else:
metadata = ""

if spec is None:
return field.with_metadata(
{
"ARROW:extension:name": "geoarrow.wkb",
"ARROW:extension:metadata": "{" + metadata + "}",
}
)
else:
spec_metadata = spec.from_extension_metadata("{" + metadata + "}")
return field.with_type(spec_metadata.coalesce(spec).to_pyarrow())


def crs_to_json(crs):