diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py new file mode 100644 index 000000000..cf90376bf --- /dev/null +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -0,0 +1,459 @@ +"""``AbstractVersionedDataset`` implementation to access Spark dataframes using +``pyspark``. +""" + +from __future__ import annotations + +import json +import logging +import os +import warnings +from functools import partial +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + get_protocol_and_path, +) + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, SparkSession + from pyspark.sql.types import StructType + +from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + deployed_on_databricks, + get_dbutils, + strip_dbfs_prefix, +) + +logger = logging.getLogger(__name__) + + +class SparkDatasetV2(AbstractVersionedDataset): + """``SparkDatasetV2`` loads and saves Spark dataframes. + + Examples: + Using the [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + weather: + type: spark.SparkDatasetV2 + filepath: s3a://your_bucket/data/01_raw/weather/* + file_format: csv + load_args: + header: True + inferSchema: True + save_args: + sep: '|' + header: True + + weather_with_schema: + type: spark.SparkDatasetV2 + filepath: s3a://your_bucket/data/01_raw/weather/* + file_format: csv + load_args: + header: True + schema: + filepath: path/to/schema.json + save_args: + sep: '|' + header: True + + weather_cleaned: + type: spark.SparkDatasetV2 + filepath: data/02_intermediate/data.parquet + file_format: parquet + ``` + + Using the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from kedro_datasets.spark import SparkDatasetV2 + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType + >>> + >>> schema = StructType( + ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] + ... 
) + >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + >>> + >>> dataset = SparkDatasetV2(filepath="tmp_path/test_data") + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() + >>> assert Row(name="Bob", age=12) in reloaded.take(4) + + """ + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # ``ThreadRunner`` instead + _SINGLE_PROCESS = True + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + file_format: str = "parquet", + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + version: Version | None = None, + credentials: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ): + self.file_format = file_format + self.load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} + self.save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} + self.credentials = credentials or {} + self.metadata = metadata + + # Parse filepath + self.protocol, self.path = self._parse_filepath(filepath) + + # Validate and warn about Databricks paths + self._validate_databricks_path(filepath) + + # Get filesystem for metadata operations (exists, glob) + exists_function, glob_function = self._get_filesystem_ops() + + # Store Spark compatible path for I/O + self._spark_path = self._to_spark_path(filepath) + + # Handle schema if provided + self._schema = self._load_schema_from_file(self.load_args.pop("schema", None)) + + super().__init__( + filepath=PurePosixPath(self.path), + version=version, + exists_function=exists_function, + glob_function=glob_function, + ) + + self._validate_delta_format() + + def _parse_filepath(self, filepath: str) -> tuple[str, str]: + """Parse filepath handling special cases like DBFS.""" + # Handle DBFS paths + if filepath.startswith("/dbfs/"): + # /dbfs/path -> dbfs protocol with /path + return "dbfs", filepath[6:] # Remove /dbfs prefix + elif filepath.startswith("dbfs:/"): + # dbfs:/path -> already in correct format + return get_protocol_and_path(filepath) + elif filepath.startswith("/Volumes"): + # Unity Catalog volumes + return "file", filepath + else: + return get_protocol_and_path(filepath) + + def _validate_databricks_path(self, filepath: str) -> None: + """Warn about potential Databricks path issues.""" + if ( + deployed_on_databricks() + and not ( + filepath.startswith("/dbfs") + or filepath.startswith("dbfs:/") + or filepath.startswith("/Volumes") + ) + and not any( + filepath.startswith(f"{p}://") + for p in ["s3", "s3a", "s3n", "gs", "abfs", "wasbs"] + ) + ): + logger.warning( + "Using SparkDatasetV2 on Databricks without the `/dbfs/`, `dbfs:/`, or `/Volumes` prefix " + "in the filepath is a known source of error. 
You must add this prefix to %s", + filepath, + ) + + def _get_filesystem_ops(self) -> tuple: + """Get filesystem operations with DBFS optimization.""" + # Special handling for DBFS to avoid performance issues + # This addresses the critical performance issue raised by deepyaman + if self.protocol == "dbfs" and deployed_on_databricks(): + try: + spark = self._get_spark() + dbutils = get_dbutils(spark) + if dbutils: + logger.debug("Using optimized DBFS operations via dbutils") + return ( + partial(dbfs_exists, dbutils=dbutils), + partial(dbfs_glob, dbutils=dbutils), + ) + except Exception as e: + logger.warning(f"Failed to get dbutils, falling back to fsspec: {e}") + + # Regular fsspec for everything else + fs = self._get_filesystem() + return fs.exists, fs.glob + + def _get_filesystem(self): + """Get fsspec filesystem with helpful errors for missing deps.""" + try: + import fsspec # noqa: PLC0415 + except ImportError: + raise ImportError( + "fsspec is required for SparkDatasetV2. " + "Install with: pip install fsspec" + ) + + # Map Spark protocols to fsspec protocols + protocol_map = { + "s3a": "s3", + "s3n": "s3", + "dbfs": "file", # DBFS is mounted as local when using fsspec + "": "file", + } + + fsspec_protocol = protocol_map.get(self.protocol, self.protocol) + + try: + return fsspec.filesystem(fsspec_protocol, **self.credentials) + except ImportError as e: + # Provide targeted help for missing filesystem implementations + error_msg = str(e) + if "s3fs" in error_msg: + msg = ( + "s3fs not installed. Install with:\n" + " pip install 'kedro-datasets[spark-s3]' or\n" + " pip install s3fs" + ) + elif "gcsfs" in error_msg: + msg = ( + "gcsfs not installed. Install with:\n" + " pip install 'kedro-datasets[spark-gcs]' or\n" + " pip install gcsfs" + ) + elif "adlfs" in error_msg or "azure" in error_msg: + msg = ( + "adlfs not installed for Azure. Install with:\n" + " pip install 'kedro-datasets[spark-azure]' or\n" + " pip install adlfs" + ) + else: + msg = f"Missing filesystem implementation: {error_msg}" + raise ImportError(msg) from e + + def _to_spark_path(self, filepath: str) -> str: + """Convert to Spark-compatible path format.""" + filepath = str(filepath) + + # Apply DBFS prefix stripping for consistency + filepath = strip_dbfs_prefix(filepath) + + protocol, path = get_protocol_and_path(filepath) + + # Special handling for Databricks paths + if self.protocol == "dbfs": + # Ensure dbfs:/ format for Spark + return f"dbfs:/{path}" + + # Map to Spark protocols + spark_protocols = { + "s3": "s3a", # Spark prefers s3a:// + "gs": "gs", + "abfs": "abfs", + "wasbs": "wasbs", + "file": "file", + "": "file", + } + + spark_protocol = spark_protocols.get(protocol, protocol) + + # Handle local paths + if spark_protocol == "file": + # Ensure absolute path for local files + if not path.startswith("/"): + path = f"/{path}" + return f"file://{path}" + elif not spark_protocol: + return path + else: + return f"{spark_protocol}://{path}" + + def _get_spark(self) -> SparkSession: + """Get Spark session with support for Spark Connect and Databricks Connect.""" + try: + from pyspark.sql import SparkSession # noqa: PLC0415 + except ImportError as e: + # Detect environment and provide specific help + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + msg = ( + "Cannot import PySpark on Databricks. This is usually a " + "databricks-connect conflict. Try:\n" + " pip uninstall pyspark\n" + " pip install databricks-connect" + ) + elif "EMR_RELEASE_LABEL" in os.environ: + msg = "PySpark should be pre-installed on EMR. 
Check your cluster configuration." + else: + msg = ( + "PySpark not installed. Install based on your environment:\n" + " Local: pip install 'kedro-datasets[spark-local]'\n" + " Databricks: Use pre-installed Spark or databricks-connect\n" + " Spark Connect: pip install 'kedro-datasets[spark-connect]'\n" + " Cloud: Check your platform's Spark setup" + ) + raise ImportError(msg) from e + + # Try Databricks Connect first (for remote development) + if "DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ: + try: + # Databricks Connect configuration + logger.debug("Attempting to use Databricks Connect") + builder = SparkSession.builder + builder.remote( + f"sc://{os.environ['DATABRICKS_HOST']}:443/;token={os.environ['DATABRICKS_TOKEN']}" + ) + return builder.getOrCreate() + except Exception as e: + logger.debug(f"Databricks Connect failed, falling back: {e}") + + # Try Spark Connect (Spark 3.4+) + if spark_remote := os.environ.get("SPARK_REMOTE"): + try: + logger.debug(f"Using Spark Connect: {spark_remote}") + return SparkSession.builder.remote(spark_remote).getOrCreate() + except Exception as e: + logger.debug(f"Spark Connect failed, falling back: {e}") + + # Fall back to classic Spark session + logger.debug("Using classic Spark session") + return SparkSession.builder.getOrCreate() + + @staticmethod + def _load_schema_from_file(schema: dict[str, Any] | None) -> StructType | None: + """Load schema from file if provided.""" + if schema is None: + return None + + if not isinstance(schema, dict): + # Assume it's already a StructType + return schema + + filepath = schema.get("filepath") + if not filepath: + raise DatasetError( + "Schema dict must have 'filepath' attribute. " + "Please provide a path to a JSON-serialised 'pyspark.sql.types.StructType'." + ) + + try: + import fsspec # noqa: PLC0415 + from pyspark.sql.types import StructType # noqa: PLC0415 + except ImportError as e: + if "pyspark" in str(e): + raise ImportError("PySpark required to process schema") from e + raise ImportError("fsspec required for schema loading") from e + + protocol, path = get_protocol_and_path(filepath) + fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) + + try: + with fs.open(path, "r") as f: + schema_json = json.load(f) + return StructType.fromJson(schema_json) + except Exception as e: + raise DatasetError( + f"Failed to load schema from {filepath}. " + f"Ensure it contains valid JSON-serialised StructType." 
+ ) from e + + def load(self) -> DataFrame: + """Load data using Spark""" + load_path = self._get_load_path() + spark_load_path = self._to_spark_path(str(load_path)) + + spark = self._get_spark() + + reader = spark.read + if self._schema: + reader = reader.schema(self._schema) + + return ( + reader.format(self.file_format) + .options(**self.load_args) + .load(spark_load_path) + ) + + def save(self, data: DataFrame) -> None: + """Save data using Spark""" + save_path = self._get_save_path() + spark_save_path = self._to_spark_path(str(save_path)) + + # Prepare writer + writer = data.write + + # Apply mode if specified + mode = self.save_args.pop("mode", None) + if mode: + writer = writer.mode(mode) + + # Apply partitioning if specified + partition_by = self.save_args.pop("partitionBy", None) + if partition_by: + writer = writer.partitionBy(partition_by) + + # Save with format and options + writer.format(self.file_format).options(**self.save_args).save(spark_save_path) + + # Restore save_args for potential reuse + if mode: + self.save_args["mode"] = mode + if partition_by: + self.save_args["partitionBy"] = partition_by + + def _exists(self) -> bool: + """Check existence using Spark read attempt for better accuracy.""" + load_path = self._get_load_path() + spark_load_path = self._to_spark_path(str(load_path)) + + try: + spark = self._get_spark() + # Try to read the metadata without loading data + spark.read.format(self.file_format).load(spark_load_path).schema + return True + except Exception as e: + # Check for specific error messages indicating non-existence + error_msg = str(e).lower() + if any( + msg in error_msg + for msg in [ + "path does not exist", + "file not found", + "is not a delta table", + "no such file", + ] + ): + return False + # Re-raise for unexpected errors + logger.warning(f"Error checking existence of {spark_load_path}: {e}") + raise + + def _validate_delta_format(self): + """Validate Delta-specific configurations""" + if self.file_format == "delta": + mode = self.save_args.get("mode") + supported = {"append", "overwrite", "error", "errorifexists", "ignore"} + if mode and mode not in supported: + raise DatasetError( + f"Delta format doesn't support mode '{mode}'. " + f"Use one of {supported} or DeltaTableDataset for advanced operations." 
+ ) + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._spark_path, + "file_format": self.file_format, + "load_args": self.load_args, + "save_args": self.save_args, + "version": self._version, + "protocol": self.protocol, + } diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 7c5cd6e61..b6734bccb 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -18,15 +18,45 @@ dynamic = ["readme", "version"] [project.optional-dependencies] pandas-base = ["pandas>=1.3, <3.0"] -spark-base = ["pyspark>=2.2, <4.0"] -hdfs-base = ["hdfs>=2.5.8, <3.0"] -s3fs-base = ["s3fs>=2021.4"] polars-base = ["polars>=0.18.0"] plotly-base = ["plotly>=4.8.0, <6.0"] delta-base = ["delta-spark>=1.0, <4.0"] networkx-base = ["networkx~=3.4"] +hdfs-base = ["hdfs>=2.5.8, <3.0"] +s3fs-base = ["s3fs>=2021.4"] -# Individual Datasets +# Spark dependencies +spark-core = [] # No dependencies +spark-local = ["pyspark>=2.2,<4.0"] # For local development +spark-databricks = [] # Uses Databricks runtime Spark +spark-emr = [] # Uses EMR runtime Spark + +# Filesystem specific packages for Spark +spark-s3 = ["s3fs>=2021.4"] +spark-gcs = ["gcsfs>=2023.1, <2023.7"] +spark-azure = ["adlfs>=2023.1"] +spark-hdfs = ["pyarrow>=7.0"] # PyArrow includes HDFS support + +# Convenience bundles +spark = ["kedro-datasets[spark-local,spark-s3]"] # Most common setup +spark-cloud = ["kedro-datasets[spark-s3,spark-gcs,spark-azure]"] # All cloud filesystems + +# Individual Spark datasets +spark-deltatabledataset = ["kedro-datasets[spark-core,delta-base]"] +spark-sparkdataset = ["kedro-datasets[spark-core]"] +spark-sparkhivedataset = ["kedro-datasets[spark-core]"] +spark-sparkjdbcdataset = ["kedro-datasets[spark-core]"] +spark-sparkstreamingdataset = ["kedro-datasets[spark-core]"] + +spark-all = [ + """kedro-datasets[spark-deltatabledataset,\ + spark-sparkdataset,\ + spark-sparkhivedataset,\ + spark-sparkjdbcdataset,\ + spark-sparkstreamingdataset,\ + spark-local,\ + spark-cloud]""" +] api-apidataset = ["requests~=2.20"] api = ["kedro-datasets[api-apidataset]"] @@ -37,8 +67,9 @@ dask-csvdataset = ["dask[dataframe]>=2021.10"] dask-parquetdataset = ["dask[complete]>=2021.10", "triad>=0.6.7, <1.0"] dask = ["kedro-datasets[dask-parquetdataset, dask-csvdataset]"] -databricks-managedtabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] -databricks = ["kedro-datasets[databricks-managedtabledataset]"] +databricks-managedtabledataset = ["kedro-datasets[spark-core,spark-s3]"] +databricks-externaltabledataset = ["kedro-datasets[spark-core,spark-s3]"] +databricks = ["kedro-datasets[databricks-managedtabledataset,databricks-externaltabledataset]"] geopandas-genericdataset = ["geopandas>=0.8.0, <2.0", "fiona>=1.8, <2.0"] geopandas = ["kedro-datasets[geopandas-genericdataset]"] @@ -150,19 +181,6 @@ redis = ["kedro-datasets[redis-pickledataset]"] snowflake-snowparktabledataset = ["snowflake-snowpark-python>=1.23"] snowflake = ["kedro-datasets[snowflake-snowparktabledataset]"] -spark-deltatabledataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base,delta-base]"] -spark-sparkdataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] -spark-sparkhivedataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] -spark-sparkjdbcdataset = ["kedro-datasets[spark-base]"] -spark-sparkstreamingdataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] -spark = [ - """kedro-datasets[spark-deltatabledataset,\ - spark-sparkdataset,\ - spark-sparkhivedataset,\ - spark-sparkjdbcdataset,\ - 
spark-sparkstreamingdataset]""" -] - svmlight-svmlightdataset = ["scikit-learn>=1.0.2", "scipy>=1.7.3"] svmlight = ["kedro-datasets[svmlight-svmlightdataset]"] @@ -178,7 +196,6 @@ yaml = ["kedro-datasets[yaml-yamldataset]"] # Experimental Datasets darts-torch-model-dataset = ["u8darts-all"] darts = ["kedro-datasets[darts-torch-model-dataset]"] -databricks-externaltabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] langchain-chatopenaidataset = ["langchain-openai~=0.1.7"] langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"] langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"] diff --git a/kedro-datasets/tests/spark/test_spark_dataset_v2.py b/kedro-datasets/tests/spark/test_spark_dataset_v2.py new file mode 100644 index 000000000..dc1d6e4f4 --- /dev/null +++ b/kedro-datasets/tests/spark/test_spark_dataset_v2.py @@ -0,0 +1,548 @@ +"""Tests for SparkDatasetV2.""" + +import os +import sys +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import pandas as pd +import pytest +from kedro.io import DataCatalog, Version +from kedro.io.core import DatasetError, generate_timestamp +from kedro.pipeline import node, pipeline +from kedro.runner import ParallelRunner, SequentialRunner +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +from pyspark.sql.types import ( + FloatType, + IntegerType, + StringType, + StructField, + StructType, +) + +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from kedro_datasets.spark import SparkDatasetV2 + +# Test constants +FILENAME = "test.parquet" +BUCKET_NAME = "test_bucket" +SCHEMA_FILE_NAME = "schema.json" + + +@pytest.fixture(scope="module") +def spark_session(): + """Create a Spark session for testing.""" + spark = ( + SparkSession.builder.master("local[2]") + .appName("TestSparkDatasetV2") + .getOrCreate() + ) + yield spark + spark.stop() + + +@pytest.fixture +def sample_spark_df(spark_session): + """Create a sample Spark DataFrame.""" + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)] + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def sample_pandas_df(): + """Create a sample pandas DataFrame.""" + return pd.DataFrame( + { + "name": ["Alice", "Bob", "Charlie"], + "age": [30, 25, 35], + } + ) + + +@pytest.fixture +def sample_schema(): + """Create a sample schema.""" + return StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", FloatType(), True), + ] + ) + + +@pytest.fixture +def version(): + """Create a version for testing.""" + return Version(None, generate_timestamp()) + + +class TestSparkDatasetV2Basic: + """Test basic functionality of SparkDatasetV2.""" + + def test_load_save_parquet(self, tmp_path, sample_spark_df): + """Test basic load and save with parquet format.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + + # Save + dataset.save(sample_spark_df) + assert Path(filepath).exists() + + # Load + loaded_df = dataset.load() + assert loaded_df.count() == sample_spark_df.count() + assert set(loaded_df.columns) == set(sample_spark_df.columns) + + def test_load_save_csv(self, tmp_path, sample_spark_df): + """Test load and save with CSV format.""" + filepath = str(tmp_path / "test.csv") + dataset = SparkDatasetV2( + filepath=filepath, + file_format="csv", + 
save_args={"header": True}, + load_args={"header": True, "inferSchema": True}, + ) + + dataset.save(sample_spark_df) + loaded_df = dataset.load() + + assert loaded_df.count() == sample_spark_df.count() + assert set(loaded_df.columns) == set(sample_spark_df.columns) + + def test_load_save_json(self, tmp_path, sample_spark_df): + """Test load and save with JSON format.""" + filepath = str(tmp_path / "test.json") + dataset = SparkDatasetV2(filepath=filepath, file_format="json") + + dataset.save(sample_spark_df) + loaded_df = dataset.load() + + assert loaded_df.count() == sample_spark_df.count() + + def test_save_modes(self, tmp_path, sample_spark_df): + """Test different save modes.""" + filepath = str(tmp_path / "test.parquet") + + # Test overwrite mode + dataset = SparkDatasetV2(filepath=filepath, save_args={"mode": "overwrite"}) + dataset.save(sample_spark_df) + dataset.save(sample_spark_df) # Should not fail + + # Test append mode + dataset_append = SparkDatasetV2( + filepath=str(tmp_path / "test_append.parquet"), save_args={"mode": "append"} + ) + dataset_append.save(sample_spark_df) + dataset_append.save(sample_spark_df) + loaded = dataset_append.load() + assert loaded.count() == sample_spark_df.count() * 2 + + def test_partitioning(self, tmp_path, sample_spark_df): + """Test data partitioning.""" + filepath = str(tmp_path / "test_partitioned.parquet") + dataset = SparkDatasetV2(filepath=filepath, save_args={"partitionBy": ["name"]}) + + dataset.save(sample_spark_df) + + # Check partition directories exist + base_path = Path(filepath) + partitions = [d for d in base_path.iterdir() if d.is_dir()] + assert len(partitions) > 0 + assert any("name=" in d.name for d in partitions) + + def test_exists(self, tmp_path, sample_spark_df): + """Test exists functionality.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + + assert not dataset.exists() + dataset.save(sample_spark_df) + assert dataset.exists() + + def test_describe(self, tmp_path): + """Test _describe method.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2( + filepath=filepath, + file_format="parquet", + load_args={"mergeSchema": True}, + save_args={"compression": "snappy"}, + ) + + description = dataset._describe() + assert description["file_format"] == "parquet" + assert description["load_args"] == {"mergeSchema": True} + assert description["save_args"] == {"compression": "snappy"} + + def test_str_representation(self, tmp_path): + """Test string representation.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + + assert "SparkDatasetV2" in str(dataset) + assert filepath in str(dataset) + + +class TestSparkDatasetV2Schema: + """Test schema handling in SparkDatasetV2.""" + + def test_schema_from_dict(self, tmp_path, sample_pandas_df, sample_schema): + """Test loading schema from dict.""" + # Save schema to file + schema_path = tmp_path / "schema.json" + schema_path.write_text(sample_schema.json()) + + # Save CSV data + csv_path = str(tmp_path / "test.csv") + sample_pandas_df.to_csv(csv_path, index=False) + + # Load with schema + dataset = SparkDatasetV2( + filepath=csv_path, + file_format="csv", + load_args={"header": True, "schema": {"filepath": str(schema_path)}}, + ) + + loaded_df = dataset.load() + assert loaded_df.schema == sample_schema + + def test_schema_invalid_filepath(self, tmp_path): + """Test error when schema filepath is invalid.""" + csv_path = str(tmp_path / "test.csv") + schema_path = tmp_path / 
"bad_schema.json" + schema_path.write_text("invalid json {") + + with pytest.raises(DatasetError, match="Failed to load schema"): + SparkDatasetV2( + filepath=csv_path, + file_format="csv", + load_args={"schema": {"filepath": str(schema_path)}}, + ) + + def test_schema_missing_filepath(self, tmp_path): + """Test error when schema dict missing filepath.""" + csv_path = str(tmp_path / "test.csv") + + with pytest.raises(DatasetError, match="Schema dict must have 'filepath'"): + SparkDatasetV2( + filepath=csv_path, file_format="csv", load_args={"schema": {}} + ) + + +class TestSparkDatasetV2PathHandling: + """Test path handling in SparkDatasetV2.""" + + def test_local_path(self, tmp_path): + """Test local path handling.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + + assert dataset.protocol == "" + assert dataset._spark_path == filepath + + def test_s3_path_normalization(self): + """Test S3 path normalization to s3a://.""" + # All S3 variants should normalize to s3a:// + for prefix in ["s3://", "s3n://", "s3a://"]: + filepath = f"{prefix}bucket/path/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path.startswith("s3a://") + + @pytest.mark.skipif( + "DATABRICKS_RUNTIME_VERSION" not in os.environ, + reason="Not running on Databricks", + ) + def test_dbfs_path_on_databricks(self): + """Test DBFS path handling on Databricks.""" + filepath = "/dbfs/path/to/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path == "dbfs:/path/to/data.parquet" + + def test_dbfs_path_not_on_databricks(self, monkeypatch): + """Test DBFS path handling when not on Databricks.""" + # Ensure we're not on Databricks + monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False) + + filepath = "/dbfs/path/to/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path == filepath + + def test_other_protocols(self): + """Test other protocol handling.""" + protocols = { + "gs://bucket/path": "gs://", + "abfs://container@account.dfs.core.windows.net/path": "abfs://", + } + + for filepath, expected_prefix in protocols.items(): + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path.startswith(expected_prefix) + + +class TestSparkDatasetV2ErrorMessages: + """Test improved error messages in SparkDatasetV2.""" + + def test_missing_s3fs_error(self, mocker): + """Test helpful error for missing s3fs.""" + import_error = ImportError("No module named 's3fs'") + mocker.patch("fsspec.filesystem", side_effect=import_error) + + with pytest.raises( + ImportError, match="pip install 'kedro-datasets\\[spark-s3\\]'" + ): + SparkDatasetV2(filepath="s3://bucket/data.parquet") + + def test_missing_gcsfs_error(self, mocker): + """Test helpful error for missing gcsfs.""" + import_error = ImportError("No module named 'gcsfs'") + mocker.patch("fsspec.filesystem", side_effect=import_error) + + with pytest.raises(ImportError, match="pip install gcsfs"): + SparkDatasetV2(filepath="gs://bucket/data.parquet") + + def test_missing_pyspark_databricks(self, mocker, monkeypatch): + """Test helpful error for PySpark on Databricks.""" + monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") + ) + + with pytest.raises(ImportError, match="databricks-connect"): + dataset.load() + + def test_missing_pyspark_emr(self, mocker, monkeypatch): + 
"""Test helpful error for PySpark on EMR.""" + monkeypatch.setenv("EMR_RELEASE_LABEL", "emr-7.0.0") + + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") + ) + + with pytest.raises(ImportError, match="should be pre-installed on EMR"): + dataset.load() + + def test_missing_pyspark_local(self, mocker, monkeypatch): + """Test helpful error for PySpark locally.""" + monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False) + monkeypatch.delenv("EMR_RELEASE_LABEL", raising=False) + + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") + ) + + with pytest.raises( + ImportError, match="pip install 'kedro-datasets\\[spark-local\\]'" + ): + dataset.load() + + +class TestSparkDatasetV2Delta: + """Test Delta format handling in SparkDatasetV2.""" + + @pytest.mark.parametrize("mode", ["merge", "update", "delete"]) + def test_delta_unsupported_modes(self, tmp_path, mode): + """Test that unsupported Delta modes raise errors.""" + filepath = str(tmp_path / "test.delta") + + with pytest.raises( + DatasetError, match=f"Delta format doesn't support mode '{mode}'" + ): + SparkDatasetV2( + filepath=filepath, file_format="delta", save_args={"mode": mode} + ) + + @pytest.mark.parametrize( + "mode", ["append", "overwrite", "error", "errorifexists", "ignore"] + ) + def test_delta_supported_modes(self, tmp_path, mode): + """Test that supported Delta modes work.""" + filepath = str(tmp_path / "test.delta") + + # Should not raise + dataset = SparkDatasetV2( + filepath=filepath, file_format="delta", save_args={"mode": mode} + ) + assert dataset.file_format == "delta" + + +class TestSparkDatasetV2Versioning: + """Test versioning functionality in SparkDatasetV2.""" + + def test_versioned_save_and_load(self, tmp_path, sample_spark_df, version): + """Test versioned save and load.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath, version=version) + + # Save versioned + dataset.save(sample_spark_df) + + # Check versioned path exists + versioned_path = tmp_path / "test.parquet" / version.save / "test.parquet" + assert versioned_path.exists() + + # Load versioned + loaded_df = dataset.load() + assert loaded_df.count() == sample_spark_df.count() + + def test_no_version_error(self, tmp_path): + """Test error when no versions exist.""" + filepath = str(tmp_path / "test.parquet") + version = Version(None, None) # Load latest + dataset = SparkDatasetV2(filepath=filepath, version=version) + + with pytest.raises(DatasetError, match="Did not find any versions"): + dataset.load() + + def test_version_str_representation(self, tmp_path, version): + """Test version in string representation.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath, version=version) + + assert "version=" in str(dataset._describe()) + + +class TestSparkDatasetV2Integration: + """Integration tests for SparkDatasetV2.""" + + def test_parallel_runner_restriction(self, tmp_path, sample_spark_df): + """Test that ParallelRunner is restricted.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + dataset.save(sample_spark_df) + + catalog = DataCatalog({"spark_data": dataset}) + test_pipeline = pipeline([node(lambda x: x, "spark_data", "output")]) + + with pytest.raises(AttributeError, match="cannot be used with multiprocessing"): + 
ParallelRunner().run(test_pipeline, catalog) + + def test_sequential_runner(self, tmp_path, sample_spark_df): + """Test that SequentialRunner works.""" + filepath_in = str(tmp_path / "input.parquet") + filepath_out = str(tmp_path / "output.parquet") + + dataset_in = SparkDatasetV2(filepath=filepath_in) + dataset_out = SparkDatasetV2(filepath=filepath_out) + dataset_in.save(sample_spark_df) + + catalog = DataCatalog({"input": dataset_in, "output": dataset_out}) + + test_pipeline = pipeline([node(lambda x: x, "input", "output")]) + SequentialRunner().run(test_pipeline, catalog) + + assert Path(filepath_out).exists() + + def test_interop_with_pandas_dataset(self, tmp_path, sample_pandas_df): + """Test interoperability with pandas datasets.""" + # Save with pandas + pandas_path = str(tmp_path / "pandas.parquet") + pandas_dataset = ParquetDataset(filepath=pandas_path) + pandas_dataset.save(sample_pandas_df) + + # Load with SparkDatasetV2 + spark_dataset = SparkDatasetV2(filepath=pandas_path) + spark_df = spark_dataset.load() + + assert spark_df.count() == len(sample_pandas_df) + assert set(spark_df.columns) == set(sample_pandas_df.columns) + + +class TestSparkDatasetV2Compatibility: + """Test compatibility between V1 and V2.""" + + def test_v1_v2_coexistence(self, tmp_path, sample_spark_df): + """Test that V1 and V2 can coexist in the same catalog.""" + from kedro_datasets.spark import SparkDataset # noqa: PLC0415 + + filepath_v1 = str(tmp_path / "v1.parquet") + filepath_v2 = str(tmp_path / "v2.parquet") + + dataset_v1 = SparkDataset(filepath=filepath_v1) + dataset_v2 = SparkDatasetV2(filepath=filepath_v2) + + catalog = DataCatalog( + { + "v1_data": dataset_v1, + "v2_data": dataset_v2, + } + ) + + # Both should work + dataset_v1.save(sample_spark_df) + dataset_v2.save(sample_spark_df) + + assert catalog.exists("v1_data") + assert catalog.exists("v2_data") + + def test_v2_reads_v1_data(self, tmp_path, sample_spark_df): + """Test that V2 can read data saved by V1.""" + from kedro_datasets.spark import SparkDataset # noqa: PLC0415 + + filepath = str(tmp_path / "shared.parquet") + + # V1 saves + dataset_v1 = SparkDataset(filepath=filepath) + dataset_v1.save(sample_spark_df) + + # V2 loads + dataset_v2 = SparkDatasetV2(filepath=filepath) + loaded_df = dataset_v2.load() + + assert loaded_df.count() == sample_spark_df.count() + + +# Fixtures for mocking cloud storage +@pytest.fixture +def mock_s3_filesystem(): + """Mock S3 filesystem.""" + with patch("fsspec.filesystem") as mock_fs: + mock_filesystem = MagicMock() + mock_filesystem.exists.return_value = True + mock_filesystem.glob.return_value = [] + mock_fs.return_value = mock_filesystem + yield mock_filesystem + + +class TestSparkDatasetV2CloudStorage: + """Test cloud storage handling.""" + + def test_s3_credentials(self, mock_s3_filesystem): + """Test S3 with credentials.""" + dataset = SparkDatasetV2( + filepath="s3://bucket/data.parquet", + credentials={"key": "test_key", "secret": "test_secret"}, + ) + + # Verify s3a:// normalization + assert dataset._spark_path.startswith("s3a://") + + def test_gcs_handling(self, mock_s3_filesystem): + """Test GCS handling.""" + dataset = SparkDatasetV2( + filepath="gs://bucket/data.parquet", + credentials={"token": "path/to/token.json"}, + ) + + assert dataset._spark_path.startswith("gs://") + + def test_azure_handling(self, mock_s3_filesystem): + """Test Azure Blob Storage handling.""" + dataset = SparkDatasetV2( + filepath="abfs://container@account.dfs.core.windows.net/data.parquet", + 
credentials={"account_key": "test_key"}, + ) + + assert dataset._spark_path.startswith("abfs://")