From 6623c7378d60efd01a04702ee938ee5a5574accc Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Thu, 11 Sep 2025 16:04:24 +0100 Subject: [PATCH 01/17] rework spark in pyproject.toml Signed-off-by: Sajid Alam --- kedro-datasets/pyproject.toml | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index bfaf29e8a..f0440d752 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -18,7 +18,12 @@ dynamic = ["readme", "version"] [project.optional-dependencies] pandas-base = ["pandas>=1.3, <3.0"] -spark-base = ["pyspark>=2.2, <4.0"] +spark-core = [] +spark-local = ["kedro-datasets[spark-core]", "pyspark>=2.2,<4.0"] +spark-s3 = ["kedro-datasets[spark-core,s3fs-base]"] +spark-hdfs = ["kedro-datasets[spark-core,hdfs-base]"] +spark-databricks = ["kedro-datasets[spark-core]"] +spark-full = ["kedro-datasets[spark-local,spark-s3,spark-hdfs,delta-base]"] hdfs-base = ["hdfs>=2.5.8, <3.0"] s3fs-base = ["s3fs>=2021.4"] polars-base = ["polars>=0.18.0"] @@ -37,7 +42,7 @@ dask-csvdataset = ["dask[dataframe]>=2021.10"] dask-parquetdataset = ["dask[complete]>=2021.10", "triad>=0.6.7, <1.0"] dask = ["kedro-datasets[dask-parquetdataset, dask-csvdataset]"] -databricks-managedtabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] +databricks-managedtabledataset = ["kedro-datasets[spark-core,s3fs-base]"] databricks = ["kedro-datasets[databricks-managedtabledataset]"] geopandas-genericdataset = ["geopandas>=0.8.0, <2.0", "fiona>=1.8, <2.0"] @@ -150,11 +155,11 @@ redis = ["kedro-datasets[redis-pickledataset]"] snowflake-snowparktabledataset = ["snowflake-snowpark-python>=1.23"] snowflake = ["kedro-datasets[snowflake-snowparktabledataset]"] -spark-deltatabledataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base,delta-base]"] -spark-sparkdataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] -spark-sparkhivedataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] -spark-sparkjdbcdataset = ["kedro-datasets[spark-base]"] -spark-sparkstreamingdataset = ["kedro-datasets[spark-base,hdfs-base,s3fs-base]"] +spark-deltatabledataset = ["kedro-datasets[spark-local,s3fs-base,delta-base]"] +spark-sparkdataset = ["kedro-datasets[spark-local,spark-s3]"] +spark-sparkhivedataset = ["kedro-datasets[spark-local,s3fs-base]"] +spark-sparkjdbcdataset = ["kedro-datasets[spark-local]"] +spark-sparkstreamingdataset = ["kedro-datasets[spark-local,s3fs-base]"] spark = [ """kedro-datasets[spark-deltatabledataset,\ spark-sparkdataset,\ From e28b980456e5bf9f6ed24ef2bfd243ff5352e9f7 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 22 Sep 2025 12:29:41 +0100 Subject: [PATCH 02/17] Update pyproject.toml Signed-off-by: Sajid Alam --- kedro-datasets/pyproject.toml | 62 +++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml index 9414d7add..4357edcdb 100644 --- a/kedro-datasets/pyproject.toml +++ b/kedro-datasets/pyproject.toml @@ -18,20 +18,45 @@ dynamic = ["readme", "version"] [project.optional-dependencies] pandas-base = ["pandas>=1.3, <3.0"] -spark-core = [] -spark-local = ["kedro-datasets[spark-core]", "pyspark>=2.2,<4.0"] -spark-s3 = ["kedro-datasets[spark-core,s3fs-base]"] -spark-hdfs = ["kedro-datasets[spark-core,hdfs-base]"] -spark-databricks = ["kedro-datasets[spark-core]"] -spark-full = ["kedro-datasets[spark-local,spark-s3,spark-hdfs,delta-base]"] -hdfs-base = ["hdfs>=2.5.8, <3.0"] -s3fs-base = 
["s3fs>=2021.4"] polars-base = ["polars>=0.18.0"] plotly-base = ["plotly>=4.8.0, <6.0"] delta-base = ["delta-spark>=1.0, <4.0"] networkx-base = ["networkx~=3.4"] +hdfs-base = ["hdfs>=2.5.8, <3.0"] +s3fs-base = ["s3fs>=2021.4"] -# Individual Datasets +# Spark dependencies +spark-core = [] # No dependencies +spark-local = ["pyspark>=2.2,<4.0"] # For local development +spark-databricks = [] # Uses Databricks runtime Spark +spark-emr = [] # Uses EMR runtime Spark + +# Filesystem specific packages for Spark +spark-s3 = ["s3fs>=2021.4"] +spark-gcs = ["gcsfs>=2023.1, <2023.7"] +spark-azure = ["adlfs>=2023.1"] +spark-hdfs = ["pyarrow>=7.0"] # PyArrow includes HDFS support + +# Convenience bundles +spark = ["kedro-datasets[spark-local,spark-s3]"] # Most common setup +spark-cloud = ["kedro-datasets[spark-s3,spark-gcs,spark-azure]"] # All cloud filesystems + +# Individual Spark datasets +spark-deltatabledataset = ["kedro-datasets[spark-core,delta-base]"] +spark-sparkdataset = ["kedro-datasets[spark-core]"] +spark-sparkhivedataset = ["kedro-datasets[spark-core]"] +spark-sparkjdbcdataset = ["kedro-datasets[spark-core]"] +spark-sparkstreamingdataset = ["kedro-datasets[spark-core]"] + +spark-all = [ + """kedro-datasets[spark-deltatabledataset,\ + spark-sparkdataset,\ + spark-sparkhivedataset,\ + spark-sparkjdbcdataset,\ + spark-sparkstreamingdataset,\ + spark-local,\ + spark-cloud]""" +] api-apidataset = ["requests~=2.20"] api = ["kedro-datasets[api-apidataset]"] @@ -42,8 +67,9 @@ dask-csvdataset = ["dask[dataframe]>=2021.10"] dask-parquetdataset = ["dask[complete]>=2021.10", "triad>=0.6.7, <1.0"] dask = ["kedro-datasets[dask-parquetdataset, dask-csvdataset]"] -databricks-managedtabledataset = ["kedro-datasets[spark-core,s3fs-base]"] -databricks = ["kedro-datasets[databricks-managedtabledataset]"] +databricks-managedtabledataset = ["kedro-datasets[spark-core,spark-s3]"] +databricks-externaltabledataset = ["kedro-datasets[spark-core,spark-s3]"] +databricks = ["kedro-datasets[databricks-managedtabledataset,databricks-externaltabledataset]"] geopandas-genericdataset = ["geopandas>=0.8.0, <2.0", "fiona>=1.8, <2.0"] geopandas = ["kedro-datasets[geopandas-genericdataset]"] @@ -155,19 +181,6 @@ redis = ["kedro-datasets[redis-pickledataset]"] snowflake-snowparktabledataset = ["snowflake-snowpark-python>=1.23"] snowflake = ["kedro-datasets[snowflake-snowparktabledataset]"] -spark-deltatabledataset = ["kedro-datasets[spark-local,s3fs-base,delta-base]"] -spark-sparkdataset = ["kedro-datasets[spark-local,spark-s3]"] -spark-sparkhivedataset = ["kedro-datasets[spark-local,s3fs-base]"] -spark-sparkjdbcdataset = ["kedro-datasets[spark-local]"] -spark-sparkstreamingdataset = ["kedro-datasets[spark-local,s3fs-base]"] -spark = [ - """kedro-datasets[spark-deltatabledataset,\ - spark-sparkdataset,\ - spark-sparkhivedataset,\ - spark-sparkjdbcdataset,\ - spark-sparkstreamingdataset]""" -] - svmlight-svmlightdataset = ["scikit-learn>=1.0.2", "scipy>=1.7.3"] svmlight = ["kedro-datasets[svmlight-svmlightdataset]"] @@ -183,7 +196,6 @@ yaml = ["kedro-datasets[yaml-yamldataset]"] # Experimental Datasets darts-torch-model-dataset = ["u8darts-all"] darts = ["kedro-datasets[darts-torch-model-dataset]"] -databricks-externaltabledataset = ["kedro-datasets[hdfs-base,s3fs-base]"] langchain-chatopenaidataset = ["langchain-openai~=0.1.7"] langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"] langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"] From 27f9bd364950419d013f337fe611f557bf3a572f 
Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 22 Sep 2025 12:55:11 +0100 Subject: [PATCH 03/17] Update spark_dataset.py Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset.py | 418 ++++++++---------- 1 file changed, 187 insertions(+), 231 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 0cd84f570..f89b22925 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -3,88 +3,26 @@ """ from __future__ import annotations - -import json +from typing import TYPE_CHECKING, Any +import os import logging -from copy import deepcopy -from fnmatch import fnmatch -from functools import partial from pathlib import PurePosixPath -from typing import Any -from warnings import warn -import fsspec -from hdfs import HdfsError, InsecureClient from kedro.io.core import ( - CLOUD_PROTOCOLS, AbstractVersionedDataset, DatasetError, Version, - get_filepath_str, get_protocol_and_path, ) -from pyspark.sql import DataFrame -from pyspark.sql.types import StructType -from pyspark.sql.utils import AnalysisException -from s3fs import S3FileSystem - -from kedro_datasets._utils.databricks_utils import ( - dbfs_exists, - dbfs_glob, - deployed_on_databricks, - get_dbutils, - parse_glob_pattern, - split_filepath, - strip_dbfs_prefix, -) -from kedro_datasets._utils.spark_utils import get_spark - -logger = logging.getLogger(__name__) - - -class KedroHdfsInsecureClient(InsecureClient): - """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists`` - and ``hdfs_glob`` methods required by ``SparkDataset``""" - def hdfs_exists(self, hdfs_path: str) -> bool: - """Determines whether given ``hdfs_path`` exists in HDFS. +if TYPE_CHECKING: + from pyspark.sql import DataFrame, SparkSession + from pyspark.sql.types import StructType - Args: - hdfs_path: Path to check. - - Returns: - True if ``hdfs_path`` exists in HDFS, False otherwise. - """ - return bool(self.status(hdfs_path, strict=False)) - - def hdfs_glob(self, pattern: str) -> list[str]: - """Perform a glob search in HDFS using the provided pattern. +logger = logging.getLogger(__name__) - Args: - pattern: Glob pattern to search for. - Returns: - List of HDFS paths that satisfy the glob pattern. - """ - prefix = parse_glob_pattern(pattern) or "/" - matched = set() - try: - for dpath, _, fnames in self.walk(prefix): - if fnmatch(dpath, pattern): - matched.add(dpath) - matched |= { - f"{dpath}/{fname}" - for fname in fnames - if fnmatch(f"{dpath}/{fname}", pattern) - } - except HdfsError: # pragma: no cover - # HdfsError is raised by `self.walk()` if prefix does not exist in HDFS. - # Ignore and return an empty list. - pass - return sorted(matched) - - -class SparkDataset(AbstractVersionedDataset[DataFrame, DataFrame]): +class SparkDataset(AbstractVersionedDataset): """``SparkDataset`` loads and saves Spark dataframes. Examples: @@ -157,183 +95,201 @@ def __init__( # noqa: PLR0913 version: Version | None = None, credentials: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, - ) -> None: - """Creates a new instance of ``SparkDataset``. - - Args: - filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks - specify ``filepath``s starting with ``/dbfs/``. - file_format: File format used during load and save - operations. These are formats supported by the running - SparkContext include parquet, csv, delta. 
For a list of supported - formats please refer to Apache Spark documentation at - https://spark.apache.org/docs/latest/sql-programming-guide.html - load_args: Load args passed to Spark DataFrameReader load method. - It is dependent on the selected file format. You can find - a list of read options for each supported format - in Spark DataFrame read documentation: - https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html - save_args: Save args passed to Spark DataFrame write options. - Similar to load_args this is dependent on the selected file - format. You can pass ``mode`` and ``partitionBy`` to specify - your overwrite mode and partitioning respectively. You can find - a list of options for each format in Spark DataFrame - write documentation: - https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html - version: If specified, should be an instance of - ``kedro.io.core.Version``. If its ``load`` attribute is - None, the latest version will be loaded. If its ``save`` - attribute is None, save version will be autogenerated. - credentials: Credentials to access the S3 bucket, such as - ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``. - Optional keyword arguments passed to ``hdfs.client.InsecureClient`` - if ``filepath`` prefix is ``hdfs://``. Ignored otherwise. - metadata: Any arbitrary metadata. - This is ignored by Kedro, but may be consumed by users or external plugins. - """ - credentials = deepcopy(credentials) or {} - fs_prefix, filepath = split_filepath(filepath) - path = PurePosixPath(filepath) - exists_function = None - glob_function = None + ): + self.file_format = file_format + self.load_args = load_args or {} + self.save_args = save_args or {} + self.credentials = credentials or {} self.metadata = metadata - if ( - not (filepath.startswith("/dbfs") or filepath.startswith("/Volumes")) - and fs_prefix not in (protocol + "://" for protocol in CLOUD_PROTOCOLS) - and deployed_on_databricks() - ): - logger.warning( - "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix in the " - "filepath is a known source of error. You must add this prefix to %s", - filepath, - ) - if fs_prefix and fs_prefix in ("s3a://"): - _s3 = S3FileSystem(**credentials) - exists_function = _s3.exists - # Ensure cache is not used so latest version is retrieved correctly. 
- glob_function = partial(_s3.glob, refresh=True) - - elif fs_prefix == "hdfs://": - if version: - warn( - f"HDFS filesystem support for versioned {self.__class__.__name__} is " - f"in beta and uses 'hdfs.client.InsecureClient', please use with " - f"caution" - ) + # Parse filepath + self.protocol, self.path = get_protocol_and_path(filepath) - # default namenode address - credentials.setdefault("url", "http://localhost:9870") - credentials.setdefault("user", "hadoop") - - _hdfs_client = KedroHdfsInsecureClient(**credentials) - exists_function = _hdfs_client.hdfs_exists - glob_function = _hdfs_client.hdfs_glob # type: ignore - - elif filepath.startswith("/dbfs/"): - # dbfs add prefix to Spark path by default - # See https://github.com/kedro-org/kedro-plugins/issues/117 - dbutils = get_dbutils(get_spark()) - if dbutils: - glob_function = partial(dbfs_glob, dbutils=dbutils) - exists_function = partial(dbfs_exists, dbutils=dbutils) - else: - filesystem = fsspec.filesystem(fs_prefix.strip("://"), **credentials) - exists_function = filesystem.exists - glob_function = filesystem.glob + # Get filesystem for metadata operations (exists, glob) + self._fs = self._get_filesystem() + + # Store Spark compatible path for I/O + self._spark_path = self._to_spark_path(filepath) + + # Handle schema if provided + self._schema = self._process_schema(self.load_args.pop("schema", None)) super().__init__( - filepath=path, + filepath=PurePosixPath(self.path), version=version, - exists_function=exists_function, - glob_function=glob_function, + exists_function=self._fs.exists, + glob_function=self._fs.glob, ) - # Handle default load and save arguments - self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} - self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} - - # Handle schema load argument - self._schema = self._load_args.pop("schema", None) - if self._schema is not None: - if isinstance(self._schema, dict): - self._schema = self._load_schema_from_file(self._schema) - - self._file_format = file_format - self._fs_prefix = fs_prefix - self._handle_delta_format() - - @staticmethod - def _load_schema_from_file(schema: dict[str, Any]) -> StructType: - filepath = schema.get("filepath") - if not filepath: - raise DatasetError( - "Schema load argument does not specify a 'filepath' attribute. Please" - "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." - ) - - credentials = deepcopy(schema.get("credentials")) or {} - protocol, schema_path = get_protocol_and_path(filepath) - file_system = fsspec.filesystem(protocol, **credentials) - pure_posix_path = PurePosixPath(schema_path) - load_path = get_filepath_str(pure_posix_path, protocol) - - # Open schema file - with file_system.open(load_path) as fs_file: - try: - return StructType.fromJson(json.loads(fs_file.read())) - except Exception as exc: - raise DatasetError( - f"Contents of 'schema.filepath' ({schema_path}) are invalid. Please" - f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." 
- ) from exc + self._validate_delta_format() - def _describe(self) -> dict[str, Any]: - return { - "filepath": self._fs_prefix + str(self._filepath), - "file_format": self._file_format, - "load_args": self._load_args, - "save_args": self._save_args, - "version": self._version, + def _get_filesystem(self): + """Get fsspec filesystem with helpful errors for missing deps""" + try: + import fsspec + except ImportError: + raise ImportError("fsspec is required") + + # Normalise protocols + protocol_map = { + "s3a": "s3", "s3n": "s3", # Spark S3 variants + "dbfs": "file", # DBFS is mounted as local + "": "file", # Default to local } - def load(self) -> DataFrame: - load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - read_obj = get_spark().read + fsspec_protocol = protocol_map.get(self.protocol, self.protocol) + + try: + return fsspec.filesystem(fsspec_protocol, **self.credentials) + except ImportError as e: + # Provide targeted help + if "s3fs" in str(e): + msg = "s3fs not installed. Install with: pip install 'kedro-datasets[spark-s3]'" + elif "gcsfs" in str(e): + msg = "gcsfs not installed. Install with: pip install gcsfs" + elif "adlfs" in str(e): + msg = "adlfs not installed. Install with: pip install adlfs" + else: + msg = str(e) + raise ImportError(msg) from e + + def _to_spark_path(self, filepath: str) -> str: + """Convert to Spark-compatible path format""" + protocol, path = get_protocol_and_path(filepath) + + # Handle special cases + if filepath.startswith("/dbfs/"): + # Databricks: /dbfs/path -> dbfs:/path + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + return "dbfs:/" + filepath[6:] + return filepath + + # Map to Spark protocols + spark_protocols = { + "s3": "s3a", # Critical: Spark prefers s3a:// + "gs": "gs", + "abfs": "abfs", + "file": "", # Local paths don't need protocol + "": "", + } - # Pass schema if defined + spark_protocol = spark_protocols.get(protocol, protocol) + + if not spark_protocol: + return path + return f"{spark_protocol}://{path}" + + def _get_spark(self) -> "SparkSession": + """Lazy load Spark with environment specific guidance""" + try: + from pyspark.sql import SparkSession + return SparkSession.builder.getOrCreate() + except ImportError as e: + # Detect environment and provide specific help + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + msg = ( + "Cannot import PySpark on Databricks. This is usually a " + "databricks-connect conflict. Try:\n" + " pip uninstall pyspark\n" + " pip install databricks-connect" + ) + elif "EMR_RELEASE_LABEL" in os.environ: + msg = "PySpark should be pre-installed on EMR. Check your cluster configuration." + else: + msg = ( + "PySpark not installed. 
Install based on your environment:\n" + " Local: pip install 'kedro-datasets[spark-local]'\n" + " Databricks: Use pre-installed Spark or databricks-connect\n" + " Cloud: Check your platform's Spark setup" + ) + raise ImportError(msg) from e + + def _process_schema(self, schema: Any) -> Any: + """Process schema argument if provided""" + if schema is None: + return None + + if isinstance(schema, dict): + # Load from file + schema_path = schema.get("filepath") + if not schema_path: + raise DatasetError("Schema dict must have 'filepath'") + + # Use fsspec to load + import json + protocol, path = get_protocol_and_path(schema_path) + + try: + import fsspec + fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) + with fs.open(path, "r") as f: + schema_json = json.load(f) + + # Lazy import StructType + from pyspark.sql.types import StructType + return StructType.fromJson(schema_json) + except ImportError as e: + if "pyspark" in str(e): + raise ImportError("PySpark required to process schema") from e + raise + except Exception as e: + raise DatasetError(f"Failed to load schema from {schema_path}") from e + + return schema + + def load(self) -> "DataFrame": + """Load data using Spark""" + spark = self._get_spark() + + reader = spark.read if self._schema: - read_obj = read_obj.schema(self._schema) + reader = reader.schema(self._schema) - return read_obj.load(load_path, self._file_format, **self._load_args) + return reader.format(self.file_format).options(**self.load_args).load(self._spark_path) - def save(self, data: DataFrame) -> None: - save_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) - data.write.save(save_path, self._file_format, **self._save_args) + def save(self, data: "DataFrame") -> None: + """Save data using Spark""" + writer = data.write - def _exists(self) -> bool: - load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) + if mode := self.save_args.pop("mode", None): + writer = writer.mode(mode) + if partition_by := self.save_args.pop("partitionBy", None): + writer = writer.partitionBy(partition_by) + + writer.format(self.file_format).options(**self.save_args).save(self._spark_path) + + def _exists(self) -> bool: + """Existence check using fsspec""" try: - get_spark().read.load(load_path, self._file_format) - except AnalysisException as exception: - # `AnalysisException.desc` is deprecated with pyspark >= 3.4 - message = exception.desc if hasattr(exception, "desc") else str(exception) - if "Path does not exist:" in message or "is not a Delta table" in message: - return False - raise - return True - - def _handle_delta_format(self) -> None: - supported_modes = {"append", "overwrite", "error", "errorifexists", "ignore"} - write_mode = self._save_args.get("mode") - if ( - write_mode - and self._file_format == "delta" - and write_mode not in supported_modes - ): - raise DatasetError( - f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{write_mode}' on 'SparkDataset'. " - f"Please use 'spark.DeltaTableDataset' instead." 
- ) + return self._fs.exists(self.path) + except Exception: + # Fallback to Spark check for special cases (e.g., Delta tables) + if self.file_format == "delta": + try: + spark = self._get_spark() + spark.read.format("delta").load(self._spark_path) + return True + except Exception: + return False + return False + + def _validate_delta_format(self): + """Validate Delta-specific configurations""" + if self.file_format == "delta": + mode = self.save_args.get("mode") + supported = {"append", "overwrite", "error", "errorifexists", "ignore"} + if mode and mode not in supported: + raise DatasetError( + f"Delta format doesn't support mode '{mode}'. " + f"Use one of {supported} or DeltaTableDataset for advanced operations." + ) + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._spark_path, + "file_format": self.file_format, + "load_args": self.load_args, + "save_args": self.save_args, + "version": self._version, + } From c82f440017696169adb838fb50d18228bf38ba7e Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 22 Sep 2025 13:07:01 +0100 Subject: [PATCH 04/17] Update spark_dataset.py Signed-off-by: Sajid Alam --- kedro-datasets/kedro_datasets/spark/spark_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index f89b22925..c7417c5a4 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -70,7 +70,7 @@ class SparkDataset(AbstractVersionedDataset): >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) >>> - >>> dataset = SparkDataset(filepath=tmp_path / "test_data") + >>> dataset = SparkDataset(filepath="tmp_path/test_data") >>> dataset.save(spark_df) >>> reloaded = dataset.load() >>> assert Row(name="Bob", age=12) in reloaded.take(4) @@ -155,6 +155,7 @@ def _get_filesystem(self): def _to_spark_path(self, filepath: str) -> str: """Convert to Spark-compatible path format""" + filepath = str(filepath) # Convert PosixPath to string protocol, path = get_protocol_and_path(filepath) # Handle special cases From 460cf31767e68e74d853aafd35440278d9d2f711 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 22 Sep 2025 13:25:50 +0100 Subject: [PATCH 05/17] lint Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index c7417c5a4..231e9e60d 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -3,10 +3,11 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any -import os + import logging +import os from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any from kedro.io.core import ( AbstractVersionedDataset, @@ -126,13 +127,14 @@ def __init__( # noqa: PLR0913 def _get_filesystem(self): """Get fsspec filesystem with helpful errors for missing deps""" try: - import fsspec + import fsspec # noqa: PLC0415 except ImportError: raise ImportError("fsspec is required") # Normalise protocols protocol_map = { - "s3a": "s3", "s3n": "s3", # Spark S3 variants + "s3a": "s3", + "s3n": "s3", # Spark S3 variants "dbfs": "file", # DBFS is mounted as local "": "file", # Default to local } 
@@ -180,10 +182,11 @@ def _to_spark_path(self, filepath: str) -> str: return path return f"{spark_protocol}://{path}" - def _get_spark(self) -> "SparkSession": + def _get_spark(self) -> SparkSession: """Lazy load Spark with environment specific guidance""" try: - from pyspark.sql import SparkSession + from pyspark.sql import SparkSession # noqa: PLC0415 + return SparkSession.builder.getOrCreate() except ImportError as e: # Detect environment and provide specific help @@ -217,17 +220,20 @@ def _process_schema(self, schema: Any) -> Any: raise DatasetError("Schema dict must have 'filepath'") # Use fsspec to load - import json + import json # noqa: PLC0415 + protocol, path = get_protocol_and_path(schema_path) try: - import fsspec + import fsspec # noqa: PLC0415 + fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) with fs.open(path, "r") as f: schema_json = json.load(f) # Lazy import StructType - from pyspark.sql.types import StructType + from pyspark.sql.types import StructType # noqa: PLC0415 + return StructType.fromJson(schema_json) except ImportError as e: if "pyspark" in str(e): @@ -238,7 +244,7 @@ def _process_schema(self, schema: Any) -> Any: return schema - def load(self) -> "DataFrame": + def load(self) -> DataFrame: """Load data using Spark""" spark = self._get_spark() @@ -246,9 +252,13 @@ def load(self) -> "DataFrame": if self._schema: reader = reader.schema(self._schema) - return reader.format(self.file_format).options(**self.load_args).load(self._spark_path) + return ( + reader.format(self.file_format) + .options(**self.load_args) + .load(self._spark_path) + ) - def save(self, data: "DataFrame") -> None: + def save(self, data: DataFrame) -> None: """Save data using Spark""" writer = data.write From 7514ccbfacec28c22883803ac14edcefb25e646d Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 22 Sep 2025 13:58:29 +0100 Subject: [PATCH 06/17] Update spark_dataset.py Signed-off-by: Sajid Alam --- kedro-datasets/kedro_datasets/spark/spark_dataset.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index 231e9e60d..d63c7c5ba 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -113,7 +113,9 @@ def __init__( # noqa: PLR0913 self._spark_path = self._to_spark_path(filepath) # Handle schema if provided - self._schema = self._process_schema(self.load_args.pop("schema", None)) + self._schema = SparkDataset._load_schema_from_file( + self.load_args.pop("schema", None) + ) super().__init__( filepath=PurePosixPath(self.path), @@ -208,7 +210,8 @@ def _get_spark(self) -> SparkSession: ) raise ImportError(msg) from e - def _process_schema(self, schema: Any) -> Any: + @staticmethod + def _load_schema_from_file(schema: Any) -> Any: """Process schema argument if provided""" if schema is None: return None From e98ce2a9685f79dd22ffb1b61f6990f8dcd04c97 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 24 Sep 2025 10:02:54 +0100 Subject: [PATCH 07/17] Update test_spark_dataset.py Signed-off-by: Sajid Alam --- .../tests/spark/test_spark_dataset.py | 360 ++---------------- 1 file changed, 36 insertions(+), 324 deletions(-) diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index 18bd1066f..226d3e3a3 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -26,11 
+26,6 @@ ) from pyspark.sql.utils import AnalysisException -from kedro_datasets._utils.databricks_utils import ( - dbfs_exists, - dbfs_glob, - get_dbutils, -) from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.pickle import PickleDataset from kedro_datasets.spark import SparkDataset @@ -41,28 +36,6 @@ SCHEMA_FILE_NAME = "schema.json" AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} -HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}" -HDFS_FOLDER_STRUCTURE = [ - ( - HDFS_PREFIX, - [ - "2019-01-01T23.59.59.999Z", - "2019-01-02T00.00.00.000Z", - "2019-01-02T00.00.00.001Z", - "2019-01-02T01.00.00.000Z", - "2019-02-01T00.00.00.000Z", - ], - [], - ), - (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z", [FILENAME], []), - (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z/" + FILENAME, [], ["part1", "part2"]), - (HDFS_PREFIX + "/2019-01-02T00.00.00.000Z", [], ["other_file"]), - (HDFS_PREFIX + "/2019-01-02T00.00.00.001Z", [], []), - (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z", [FILENAME], []), - (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z/" + FILENAME, [], ["part1"]), - (HDFS_PREFIX + "/2019-02-01T00.00.00.000Z", [], ["other_file"]), -] - SPARK_VERSION = PackagingVersion(__version__) @@ -162,14 +135,6 @@ def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructT return f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}" -class FileInfo: - def __init__(self, path): - self.path = "dbfs:" + path - - def isDir(self): - return "." not in self.path.split("/")[-1] - - class TestSparkDataset: def test_load_parquet(self, tmp_path, sample_pandas_df): temp_path = (tmp_path / "data").as_posix() @@ -286,12 +251,9 @@ def test_load_options_invalid_schema_file(self, tmp_path): schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() Path(schemapath).write_text("dummy", encoding="utf-8") - pattern = ( - f"Contents of 'schema.filepath' ({schemapath}) are invalid. Please" - f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." - ) + pattern = f"Failed to load schema from {schemapath}" - with pytest.raises(DatasetError, match=re.escape(pattern)): + with pytest.raises(DatasetError, match=pattern): SparkDataset( filepath=filepath, file_format="csv", @@ -301,10 +263,7 @@ def test_load_options_invalid_schema_file(self, tmp_path): def test_load_options_invalid_schema(self, tmp_path): filepath = (tmp_path / "data").as_posix() - pattern = ( - "Schema load argument does not specify a 'filepath' attribute. Please" - "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." - ) + pattern = "Schema dict must have 'filepath'" with pytest.raises(DatasetError, match=pattern): SparkDataset( @@ -369,12 +328,11 @@ def test_save_overwrite_mode(self, tmp_path, sample_spark_df): def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): filepath = (tmp_path / "test_data").as_posix() pattern = ( - f"It is not possible to perform 'save()' for file format 'delta' " - f"with mode '{mode}' on 'SparkDataset'. " - f"Please use 'spark.DeltaTableDataset' instead." + f"Delta format doesn't support mode '{mode}'. 
" + f"Use one of" ) - with pytest.raises(DatasetError, match=re.escape(pattern)): + with pytest.raises(DatasetError, match=pattern): _ = SparkDataset( filepath=filepath, file_format="delta", save_args={"mode": mode} ) @@ -411,13 +369,15 @@ def test_exists_raises_error(self, mocker): # AnalysisExceptions clearly indicating a missing file spark_dataset = SparkDataset(filepath="") if SPARK_VERSION >= PackagingVersion("3.4.0"): - mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", + mocker.patch.object( + spark_dataset, + "_get_spark", side_effect=AnalysisException("Other Exception"), ) else: - mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", + mocker.patch.object( + spark_dataset, + "_get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): @@ -439,26 +399,11 @@ def test_s3_glob_refresh(self): spark_dataset = SparkDataset(filepath="s3a://bucket/data") assert spark_dataset._glob_function.keywords == {"refresh": True} - def test_copy(self): - spark_dataset = SparkDataset( - filepath="/tmp/data", save_args={"mode": "overwrite"} - ) - assert spark_dataset._file_format == "parquet" - - spark_dataset_copy = spark_dataset._copy(_file_format="csv") - - assert spark_dataset is not spark_dataset_copy - assert spark_dataset._file_format == "parquet" - assert spark_dataset._save_args == {"mode": "overwrite"} - assert spark_dataset_copy._file_format == "csv" - assert spark_dataset_copy._save_args == {"mode": "overwrite"} - def test_dbfs_prefix_warning_no_databricks(self, caplog): # test that warning is not raised when not on Databricks filepath = "my_project/data/02_intermediate/processed_data" expected_message = ( - "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix in the " - f"filepath is a known source of error. You must add this prefix to {filepath}." 
+ "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" ) SparkDataset(filepath="my_project/data/02_intermediate/processed_data") assert expected_message not in caplog.text @@ -477,6 +422,14 @@ def test_prefix_warning_on_databricks( ): monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + # Mock deployed_on_databricks to return True + import kedro_datasets._utils.databricks_utils + monkeypatch.setattr( + kedro_datasets._utils.databricks_utils, + "deployed_on_databricks", + lambda: True + ) + SparkDataset(filepath=filepath) warning_msg = ( @@ -627,91 +580,6 @@ def test_exists( ] * 2 assert mocked_glob.call_args_list == expected_calls - def test_dbfs_glob(self, mocker): - dbutils_mock = mocker.Mock() - dbutils_mock.fs.ls.return_value = [ - FileInfo("/tmp/file/date1"), - FileInfo("/tmp/file/date2"), - FileInfo("/tmp/file/file.csv"), - FileInfo("/tmp/file/"), - ] - pattern = "/tmp/file/*/file" - expected = ["/dbfs/tmp/file/date1/file", "/dbfs/tmp/file/date2/file"] - - result = dbfs_glob(pattern, dbutils_mock) - assert result == expected - dbutils_mock.fs.ls.assert_called_once_with("/tmp/file") - - def test_dbfs_exists(self, mocker): - dbutils_mock = mocker.Mock() - test_path = "/dbfs/tmp/file/date1/file" - dbutils_mock.fs.ls.return_value = [ - FileInfo("/tmp/file/date1"), - FileInfo("/tmp/file/date2"), - FileInfo("/tmp/file/file.csv"), - FileInfo("/tmp/file/"), - ] - - assert dbfs_exists(test_path, dbutils_mock) - - # add side effect to test that non-existence is handled - dbutils_mock.fs.ls.side_effect = Exception() - assert not dbfs_exists(test_path, dbutils_mock) - - def test_ds_init_no_dbutils(self, mocker): - get_dbutils_mock = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_dbutils", - return_value=None, - ) - - dataset = SparkDataset(filepath="/dbfs/tmp/data") - - get_dbutils_mock.assert_called_once() - assert dataset._glob_function.__name__ == "iglob" - - def test_ds_init_dbutils_available(self, mocker): - get_dbutils_mock = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_dbutils", - return_value="mock", - ) - - dataset = SparkDataset(filepath="/dbfs/tmp/data") - - get_dbutils_mock.assert_called_once() - assert dataset._glob_function.__class__.__name__ == "partial" - assert dataset._glob_function.func.__name__ == "dbfs_glob" - assert dataset._glob_function.keywords == { - "dbutils": get_dbutils_mock.return_value - } - - def test_get_dbutils_from_globals(self, mocker): - mocker.patch( - "kedro_datasets._utils.databricks_utils.globals", - return_value={"dbutils": "dbutils_from_globals"}, - ) - assert get_dbutils("spark") == "dbutils_from_globals" - - def test_get_dbutils_from_pyspark(self, mocker): - dbutils_mock = mocker.Mock() - dbutils_mock.DBUtils.return_value = "dbutils_from_pyspark" - mocker.patch.dict("sys.modules", {"pyspark.dbutils": dbutils_mock}) - assert get_dbutils("spark") == "dbutils_from_pyspark" - dbutils_mock.DBUtils.assert_called_once_with("spark") - - def test_get_dbutils_from_ipython(self, mocker): - ipython_mock = mocker.Mock() - ipython_mock.get_ipython.return_value.user_ns = { - "dbutils": "dbutils_from_ipython" - } - mocker.patch.dict("sys.modules", {"IPython": ipython_mock}) - assert get_dbutils("spark") == "dbutils_from_ipython" - ipython_mock.get_ipython.assert_called_once_with() - - def test_get_dbutils_no_modules(self, mocker): - mocker.patch("kedro_datasets.spark.spark_dataset.globals", return_value={}) - mocker.patch.dict("sys.modules", {"pyspark": None, "IPython": None}) - assert get_dbutils("spark") is None - 
@pytest.mark.parametrize("os_name", ["nt", "posix"]) def test_regular_path_in_different_os(self, os_name, mocker): """Check that class of filepath depends on OS for regular path.""" @@ -738,24 +606,21 @@ def test_no_version(self, versioned_dataset_s3): versioned_dataset_s3.load() def test_load_latest(self, mocker, versioned_dataset_s3): - get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", - ) + mocker.patch.object(versioned_dataset_s3, '_get_spark') mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") mocked_glob.return_value = [ "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") ] mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) + # Mock the actual Spark read + mock_spark = mocker.MagicMock() + mocker.patch.object(versioned_dataset_s3, '_get_spark', return_value=mock_spark) + versioned_dataset_s3.load() mocked_glob.assert_called_once_with(f"{BUCKET_NAME}/{FILENAME}/*/{FILENAME}") - get_spark.return_value.read.load.assert_called_once_with( - "s3a://{b}/{f}/{v}/{f}".format( - b=BUCKET_NAME, f=FILENAME, v="mocked_version" - ), - "parquet", - ) + mock_spark.read.load.assert_called_once() def test_load_exact(self, mocker): ts = generate_timestamp() @@ -763,14 +628,13 @@ def test_load_exact(self, mocker): filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=Version(ts, None), ) - get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", - ) + + mock_spark = mocker.MagicMock() + mocker.patch.object(ds_s3, '_get_spark', return_value=mock_spark) + ds_s3.load() - get_spark.return_value.read.load.assert_called_once_with( - f"s3a://{BUCKET_NAME}/{FILENAME}/{ts}/{FILENAME}", "parquet" - ) + mock_spark.read.load.assert_called_once() def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): mocked_spark_df = mocker.Mock() @@ -783,10 +647,7 @@ def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): # matches save version due to consistency check in versioned_dataset_s3.save() mocker.patch.object(ds_s3, "resolve_load_version", return_value=version.save) ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - f"s3a://{BUCKET_NAME}/{FILENAME}/{version.save}/{FILENAME}", - "parquet", - ) + mocked_spark_df.write.save.assert_called_once() def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") @@ -802,10 +663,7 @@ def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mock ) with pytest.warns(UserWarning, match=pattern): ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - f"s3a://{BUCKET_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}", - "parquet", - ) + mocked_spark_df.write.save.assert_called_once() def test_prevent_overwrite(self, mocker, versioned_dataset_s3): mocked_spark_df = mocker.Mock() @@ -821,162 +679,16 @@ def test_prevent_overwrite(self, mocker, versioned_dataset_s3): mocked_spark_df.write.save.assert_not_called() def test_repr(self, versioned_dataset_s3, version): - assert "filepath='s3a://" in str(versioned_dataset_s3) + assert "filepath=" in str(versioned_dataset_s3) assert f"version=Version(load=None, save='{version.save}')" in str( versioned_dataset_s3 ) dataset_s3 = SparkDataset(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") - assert "filepath='s3a://" in str(dataset_s3) + assert "filepath=" in str(dataset_s3) assert "version=" not in str(dataset_s3) 
-class TestSparkDatasetVersionedHdfs: - def test_no_version(self, mocker, version): - hdfs_walk = mocker.patch( - "kedro_datasets.spark.spark_dataset.InsecureClient.walk" - ) - hdfs_walk.return_value = [] - - versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_hdfs.load() - - hdfs_walk.assert_called_once_with(HDFS_PREFIX) - - def test_load_latest(self, mocker, version): - mocker.patch( - "kedro_datasets.spark.spark_dataset.InsecureClient.status", - return_value=True, - ) - hdfs_walk = mocker.patch( - "kedro_datasets.spark.spark_dataset.InsecureClient.walk" - ) - hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE - - versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", - ) - - versioned_hdfs.load() - - hdfs_walk.assert_called_once_with(HDFS_PREFIX) - get_spark.return_value.read.load.assert_called_once_with( - "hdfs://{fn}/{f}/{v}/{f}".format( - fn=FOLDER_NAME, v="2019-01-02T01.00.00.000Z", f=FILENAME - ), - "parquet", - ) - - def test_load_exact(self, mocker): - ts = generate_timestamp() - versioned_hdfs = SparkDataset( - filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) - ) - get_spark = mocker.patch( - "kedro_datasets.spark.spark_dataset.get_spark", - ) - - versioned_hdfs.load() - - get_spark.return_value.read.load.assert_called_once_with( - f"hdfs://{FOLDER_NAME}/{FILENAME}/{ts}/{FILENAME}", - "parquet", - ) - - def test_save(self, mocker, version): - hdfs_status = mocker.patch( - "kedro_datasets.spark.spark_dataset.InsecureClient.status" - ) - hdfs_status.return_value = None - - versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - # need resolve_load_version() call to return a load version that - # matches save version due to consistency check in versioned_hdfs.save() - mocker.patch.object( - versioned_hdfs, "resolve_load_version", return_value=version.save - ) - - mocked_spark_df = mocker.Mock() - versioned_hdfs.save(mocked_spark_df) - - hdfs_status.assert_called_once_with( - f"{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", - strict=False, - ) - mocked_spark_df.write.save.assert_called_once_with( - f"hdfs://{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", - "parquet", - ) - - def test_save_version_warning(self, mocker): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - versioned_hdfs = SparkDataset( - filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version - ) - mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False) - mocked_spark_df = mocker.Mock() - - pattern = ( - rf"Save version '{exact_version.save}' did not match load version " - rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - ) - - with pytest.warns(UserWarning, match=pattern): - versioned_hdfs.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once_with( - f"hdfs://{FOLDER_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}", - "parquet", - ) - - def test_prevent_overwrite(self, mocker, version): - hdfs_status = mocker.patch( - "kedro_datasets.spark.spark_dataset.InsecureClient.status" - ) - hdfs_status.return_value = True - - versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - mocked_spark_df = mocker.Mock() - - pattern = ( - r"Save path '.+' for 
kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " - r"if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_hdfs.save(mocked_spark_df) - - hdfs_status.assert_called_once_with( - f"{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", - strict=False, - ) - mocked_spark_df.write.save.assert_not_called() - - def test_hdfs_warning(self, version): - pattern = ( - "HDFS filesystem support for versioned SparkDataset is in beta " - "and uses 'hdfs.client.InsecureClient', please use with caution" - ) - with pytest.warns(UserWarning, match=pattern): - SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - - def test_repr(self, version): - versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) - assert "filepath='hdfs://" in str(versioned_hdfs) - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_hdfs - ) - - dataset_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}") - assert "filepath='hdfs://" in str(dataset_hdfs) - assert "version=" not in str(dataset_hdfs) - - @pytest.fixture def data_catalog(tmp_path): source_path = Path(__file__).parent / "data/test.parquet" From f800f33bb70ca3232a3f575d406dbca4445b152f Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 24 Sep 2025 10:05:21 +0100 Subject: [PATCH 08/17] lint Signed-off-by: Sajid Alam --- kedro-datasets/tests/spark/test_spark_dataset.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index 226d3e3a3..61c921aef 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -26,6 +26,7 @@ ) from pyspark.sql.utils import AnalysisException +import kedro_datasets._utils.databricks_utils from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.pickle import PickleDataset from kedro_datasets.spark import SparkDataset @@ -327,10 +328,7 @@ def test_save_overwrite_mode(self, tmp_path, sample_spark_df): @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): filepath = (tmp_path / "test_data").as_posix() - pattern = ( - f"Delta format doesn't support mode '{mode}'. " - f"Use one of" - ) + pattern = f"Delta format doesn't support mode '{mode}'. 
" f"Use one of" with pytest.raises(DatasetError, match=pattern): _ = SparkDataset( @@ -401,7 +399,6 @@ def test_s3_glob_refresh(self): def test_dbfs_prefix_warning_no_databricks(self, caplog): # test that warning is not raised when not on Databricks - filepath = "my_project/data/02_intermediate/processed_data" expected_message = ( "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" ) @@ -423,11 +420,10 @@ def test_prefix_warning_on_databricks( monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") # Mock deployed_on_databricks to return True - import kedro_datasets._utils.databricks_utils monkeypatch.setattr( kedro_datasets._utils.databricks_utils, "deployed_on_databricks", - lambda: True + lambda: True, ) SparkDataset(filepath=filepath) @@ -606,7 +602,7 @@ def test_no_version(self, versioned_dataset_s3): versioned_dataset_s3.load() def test_load_latest(self, mocker, versioned_dataset_s3): - mocker.patch.object(versioned_dataset_s3, '_get_spark') + mocker.patch.object(versioned_dataset_s3, "_get_spark") mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") mocked_glob.return_value = [ "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") @@ -615,7 +611,7 @@ def test_load_latest(self, mocker, versioned_dataset_s3): # Mock the actual Spark read mock_spark = mocker.MagicMock() - mocker.patch.object(versioned_dataset_s3, '_get_spark', return_value=mock_spark) + mocker.patch.object(versioned_dataset_s3, "_get_spark", return_value=mock_spark) versioned_dataset_s3.load() @@ -630,7 +626,7 @@ def test_load_exact(self, mocker): ) mock_spark = mocker.MagicMock() - mocker.patch.object(ds_s3, '_get_spark', return_value=mock_spark) + mocker.patch.object(ds_s3, "_get_spark", return_value=mock_spark) ds_s3.load() From c74ff125dce2baac7326ec11c4c600fa1de6a923 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 24 Sep 2025 15:23:20 +0100 Subject: [PATCH 09/17] revert and split sparkdataset rewrite into v2 Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset.py | 420 +++++----- .../kedro_datasets/spark/spark_dataset_v2.py | 309 ++++++++ .../tests/spark/test_spark_dataset.py | 362 ++++++++- .../tests/spark/test_spark_dataset_v2.py | 731 ++++++++++++++++++ 4 files changed, 1592 insertions(+), 230 deletions(-) create mode 100644 kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py create mode 100644 kedro-datasets/tests/spark/test_spark_dataset_v2.py diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_dataset.py index d63c7c5ba..0cd84f570 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset.py @@ -4,26 +4,87 @@ from __future__ import annotations +import json import logging -import os +from copy import deepcopy +from fnmatch import fnmatch +from functools import partial from pathlib import PurePosixPath -from typing import TYPE_CHECKING, Any +from typing import Any +from warnings import warn +import fsspec +from hdfs import HdfsError, InsecureClient from kedro.io.core import ( + CLOUD_PROTOCOLS, AbstractVersionedDataset, DatasetError, Version, + get_filepath_str, get_protocol_and_path, ) - -if TYPE_CHECKING: - from pyspark.sql import DataFrame, SparkSession - from pyspark.sql.types import StructType +from pyspark.sql import DataFrame +from pyspark.sql.types import StructType +from pyspark.sql.utils import AnalysisException +from s3fs import S3FileSystem + +from kedro_datasets._utils.databricks_utils 
import ( + dbfs_exists, + dbfs_glob, + deployed_on_databricks, + get_dbutils, + parse_glob_pattern, + split_filepath, + strip_dbfs_prefix, +) +from kedro_datasets._utils.spark_utils import get_spark logger = logging.getLogger(__name__) -class SparkDataset(AbstractVersionedDataset): +class KedroHdfsInsecureClient(InsecureClient): + """Subclasses ``hdfs.InsecureClient`` and implements ``hdfs_exists`` + and ``hdfs_glob`` methods required by ``SparkDataset``""" + + def hdfs_exists(self, hdfs_path: str) -> bool: + """Determines whether given ``hdfs_path`` exists in HDFS. + + Args: + hdfs_path: Path to check. + + Returns: + True if ``hdfs_path`` exists in HDFS, False otherwise. + """ + return bool(self.status(hdfs_path, strict=False)) + + def hdfs_glob(self, pattern: str) -> list[str]: + """Perform a glob search in HDFS using the provided pattern. + + Args: + pattern: Glob pattern to search for. + + Returns: + List of HDFS paths that satisfy the glob pattern. + """ + prefix = parse_glob_pattern(pattern) or "/" + matched = set() + try: + for dpath, _, fnames in self.walk(prefix): + if fnmatch(dpath, pattern): + matched.add(dpath) + matched |= { + f"{dpath}/{fname}" + for fname in fnames + if fnmatch(f"{dpath}/{fname}", pattern) + } + except HdfsError: # pragma: no cover + # HdfsError is raised by `self.walk()` if prefix does not exist in HDFS. + # Ignore and return an empty list. + pass + return sorted(matched) + + +class SparkDataset(AbstractVersionedDataset[DataFrame, DataFrame]): """``SparkDataset`` loads and saves Spark dataframes. Examples: @@ -71,7 +132,7 @@ class SparkDataset(AbstractVersionedDataset): >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) >>> - >>> dataset = SparkDataset(filepath="tmp_path/test_data") + >>> dataset = SparkDataset(filepath=tmp_path / "test_data") >>> dataset.save(spark_df) >>> reloaded = dataset.load() >>> assert Row(name="Bob", age=12) in reloaded.take(4) @@ -96,214 +157,183 @@ def __init__( # noqa: PLR0913 version: Version | None = None, credentials: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None, - ): - self.file_format = file_format - self.load_args = load_args or {} - self.save_args = save_args or {} - self.credentials = credentials or {} + ) -> None: + """Creates a new instance of ``SparkDataset``. + + Args: + filepath: Filepath in POSIX format to a Spark dataframe. When using Databricks + specify ``filepath``s starting with ``/dbfs/``. + file_format: File format used during load and save + operations. These are formats supported by the running + SparkContext include parquet, csv, delta. For a list of supported + formats please refer to Apache Spark documentation at + https://spark.apache.org/docs/latest/sql-programming-guide.html + load_args: Load args passed to Spark DataFrameReader load method. + It is dependent on the selected file format. You can find + a list of read options for each supported format + in Spark DataFrame read documentation: + https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html + save_args: Save args passed to Spark DataFrame write options. + Similar to load_args this is dependent on the selected file + format. You can pass ``mode`` and ``partitionBy`` to specify + your overwrite mode and partitioning respectively. 
You can find + a list of options for each format in Spark DataFrame + write documentation: + https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_df.html + version: If specified, should be an instance of + ``kedro.io.core.Version``. If its ``load`` attribute is + None, the latest version will be loaded. If its ``save`` + attribute is None, save version will be autogenerated. + credentials: Credentials to access the S3 bucket, such as + ``key``, ``secret``, if ``filepath`` prefix is ``s3a://`` or ``s3n://``. + Optional keyword arguments passed to ``hdfs.client.InsecureClient`` + if ``filepath`` prefix is ``hdfs://``. Ignored otherwise. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. + """ + credentials = deepcopy(credentials) or {} + fs_prefix, filepath = split_filepath(filepath) + path = PurePosixPath(filepath) + exists_function = None + glob_function = None self.metadata = metadata - # Parse filepath - self.protocol, self.path = get_protocol_and_path(filepath) - - # Get filesystem for metadata operations (exists, glob) - self._fs = self._get_filesystem() - - # Store Spark compatible path for I/O - self._spark_path = self._to_spark_path(filepath) + if ( + not (filepath.startswith("/dbfs") or filepath.startswith("/Volumes")) + and fs_prefix not in (protocol + "://" for protocol in CLOUD_PROTOCOLS) + and deployed_on_databricks() + ): + logger.warning( + "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix in the " + "filepath is a known source of error. You must add this prefix to %s", + filepath, + ) + if fs_prefix and fs_prefix in ("s3a://"): + _s3 = S3FileSystem(**credentials) + exists_function = _s3.exists + # Ensure cache is not used so latest version is retrieved correctly. 
+ glob_function = partial(_s3.glob, refresh=True) + + elif fs_prefix == "hdfs://": + if version: + warn( + f"HDFS filesystem support for versioned {self.__class__.__name__} is " + f"in beta and uses 'hdfs.client.InsecureClient', please use with " + f"caution" + ) - # Handle schema if provided - self._schema = SparkDataset._load_schema_from_file( - self.load_args.pop("schema", None) - ) + # default namenode address + credentials.setdefault("url", "http://localhost:9870") + credentials.setdefault("user", "hadoop") + + _hdfs_client = KedroHdfsInsecureClient(**credentials) + exists_function = _hdfs_client.hdfs_exists + glob_function = _hdfs_client.hdfs_glob # type: ignore + + elif filepath.startswith("/dbfs/"): + # dbfs add prefix to Spark path by default + # See https://github.com/kedro-org/kedro-plugins/issues/117 + dbutils = get_dbutils(get_spark()) + if dbutils: + glob_function = partial(dbfs_glob, dbutils=dbutils) + exists_function = partial(dbfs_exists, dbutils=dbutils) + else: + filesystem = fsspec.filesystem(fs_prefix.strip("://"), **credentials) + exists_function = filesystem.exists + glob_function = filesystem.glob super().__init__( - filepath=PurePosixPath(self.path), + filepath=path, version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, + exists_function=exists_function, + glob_function=glob_function, ) - self._validate_delta_format() - - def _get_filesystem(self): - """Get fsspec filesystem with helpful errors for missing deps""" - try: - import fsspec # noqa: PLC0415 - except ImportError: - raise ImportError("fsspec is required") - - # Normalise protocols - protocol_map = { - "s3a": "s3", - "s3n": "s3", # Spark S3 variants - "dbfs": "file", # DBFS is mounted as local - "": "file", # Default to local - } - - fsspec_protocol = protocol_map.get(self.protocol, self.protocol) - - try: - return fsspec.filesystem(fsspec_protocol, **self.credentials) - except ImportError as e: - # Provide targeted help - if "s3fs" in str(e): - msg = "s3fs not installed. Install with: pip install 'kedro-datasets[spark-s3]'" - elif "gcsfs" in str(e): - msg = "gcsfs not installed. Install with: pip install gcsfs" - elif "adlfs" in str(e): - msg = "adlfs not installed. 
Install with: pip install adlfs" - else: - msg = str(e) - raise ImportError(msg) from e - - def _to_spark_path(self, filepath: str) -> str: - """Convert to Spark-compatible path format""" - filepath = str(filepath) # Convert PosixPath to string - protocol, path = get_protocol_and_path(filepath) - - # Handle special cases - if filepath.startswith("/dbfs/"): - # Databricks: /dbfs/path -> dbfs:/path - if "DATABRICKS_RUNTIME_VERSION" in os.environ: - return "dbfs:/" + filepath[6:] - return filepath - - # Map to Spark protocols - spark_protocols = { - "s3": "s3a", # Critical: Spark prefers s3a:// - "gs": "gs", - "abfs": "abfs", - "file": "", # Local paths don't need protocol - "": "", - } - - spark_protocol = spark_protocols.get(protocol, protocol) + # Handle default load and save arguments + self._load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} + self._save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} - if not spark_protocol: - return path - return f"{spark_protocol}://{path}" + # Handle schema load argument + self._schema = self._load_args.pop("schema", None) + if self._schema is not None: + if isinstance(self._schema, dict): + self._schema = self._load_schema_from_file(self._schema) - def _get_spark(self) -> SparkSession: - """Lazy load Spark with environment specific guidance""" - try: - from pyspark.sql import SparkSession # noqa: PLC0415 - - return SparkSession.builder.getOrCreate() - except ImportError as e: - # Detect environment and provide specific help - if "DATABRICKS_RUNTIME_VERSION" in os.environ: - msg = ( - "Cannot import PySpark on Databricks. This is usually a " - "databricks-connect conflict. Try:\n" - " pip uninstall pyspark\n" - " pip install databricks-connect" - ) - elif "EMR_RELEASE_LABEL" in os.environ: - msg = "PySpark should be pre-installed on EMR. Check your cluster configuration." - else: - msg = ( - "PySpark not installed. Install based on your environment:\n" - " Local: pip install 'kedro-datasets[spark-local]'\n" - " Databricks: Use pre-installed Spark or databricks-connect\n" - " Cloud: Check your platform's Spark setup" - ) - raise ImportError(msg) from e + self._file_format = file_format + self._fs_prefix = fs_prefix + self._handle_delta_format() @staticmethod - def _load_schema_from_file(schema: Any) -> Any: - """Process schema argument if provided""" - if schema is None: - return None - - if isinstance(schema, dict): - # Load from file - schema_path = schema.get("filepath") - if not schema_path: - raise DatasetError("Schema dict must have 'filepath'") - - # Use fsspec to load - import json # noqa: PLC0415 - - protocol, path = get_protocol_and_path(schema_path) - + def _load_schema_from_file(schema: dict[str, Any]) -> StructType: + filepath = schema.get("filepath") + if not filepath: + raise DatasetError( + "Schema load argument does not specify a 'filepath' attribute. Please" + "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." 
+ ) + + credentials = deepcopy(schema.get("credentials")) or {} + protocol, schema_path = get_protocol_and_path(filepath) + file_system = fsspec.filesystem(protocol, **credentials) + pure_posix_path = PurePosixPath(schema_path) + load_path = get_filepath_str(pure_posix_path, protocol) + + # Open schema file + with file_system.open(load_path) as fs_file: try: - import fsspec # noqa: PLC0415 - - fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) - with fs.open(path, "r") as f: - schema_json = json.load(f) - - # Lazy import StructType - from pyspark.sql.types import StructType # noqa: PLC0415 - - return StructType.fromJson(schema_json) - except ImportError as e: - if "pyspark" in str(e): - raise ImportError("PySpark required to process schema") from e - raise - except Exception as e: - raise DatasetError(f"Failed to load schema from {schema_path}") from e + return StructType.fromJson(json.loads(fs_file.read())) + except Exception as exc: + raise DatasetError( + f"Contents of 'schema.filepath' ({schema_path}) are invalid. Please" + f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." + ) from exc - return schema + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._fs_prefix + str(self._filepath), + "file_format": self._file_format, + "load_args": self._load_args, + "save_args": self._save_args, + "version": self._version, + } def load(self) -> DataFrame: - """Load data using Spark""" - spark = self._get_spark() + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) + read_obj = get_spark().read - reader = spark.read + # Pass schema if defined if self._schema: - reader = reader.schema(self._schema) + read_obj = read_obj.schema(self._schema) - return ( - reader.format(self.file_format) - .options(**self.load_args) - .load(self._spark_path) - ) + return read_obj.load(load_path, self._file_format, **self._load_args) def save(self, data: DataFrame) -> None: - """Save data using Spark""" - writer = data.write - - if mode := self.save_args.pop("mode", None): - writer = writer.mode(mode) - - if partition_by := self.save_args.pop("partitionBy", None): - writer = writer.partitionBy(partition_by) - - writer.format(self.file_format).options(**self.save_args).save(self._spark_path) + save_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_save_path())) + data.write.save(save_path, self._file_format, **self._save_args) def _exists(self) -> bool: - """Existence check using fsspec""" - try: - return self._fs.exists(self.path) - except Exception: - # Fallback to Spark check for special cases (e.g., Delta tables) - if self.file_format == "delta": - try: - spark = self._get_spark() - spark.read.format("delta").load(self._spark_path) - return True - except Exception: - return False - return False - - def _validate_delta_format(self): - """Validate Delta-specific configurations""" - if self.file_format == "delta": - mode = self.save_args.get("mode") - supported = {"append", "overwrite", "error", "errorifexists", "ignore"} - if mode and mode not in supported: - raise DatasetError( - f"Delta format doesn't support mode '{mode}'. " - f"Use one of {supported} or DeltaTableDataset for advanced operations." 
- ) + load_path = strip_dbfs_prefix(self._fs_prefix + str(self._get_load_path())) - def _describe(self) -> dict[str, Any]: - return { - "filepath": self._spark_path, - "file_format": self.file_format, - "load_args": self.load_args, - "save_args": self.save_args, - "version": self._version, - } + try: + get_spark().read.load(load_path, self._file_format) + except AnalysisException as exception: + # `AnalysisException.desc` is deprecated with pyspark >= 3.4 + message = exception.desc if hasattr(exception, "desc") else str(exception) + if "Path does not exist:" in message or "is not a Delta table" in message: + return False + raise + return True + + def _handle_delta_format(self) -> None: + supported_modes = {"append", "overwrite", "error", "errorifexists", "ignore"} + write_mode = self._save_args.get("mode") + if ( + write_mode + and self._file_format == "delta" + and write_mode not in supported_modes + ): + raise DatasetError( + f"It is not possible to perform 'save()' for file format 'delta' " + f"with mode '{write_mode}' on 'SparkDataset'. " + f"Please use 'spark.DeltaTableDataset' instead." + ) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py new file mode 100644 index 000000000..dc225de69 --- /dev/null +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -0,0 +1,309 @@ +"""``AbstractVersionedDataset`` implementation to access Spark dataframes using +``pyspark``. +""" + +from __future__ import annotations + +import logging +import os +from pathlib import PurePosixPath +from typing import TYPE_CHECKING, Any + +from kedro.io.core import ( + AbstractVersionedDataset, + DatasetError, + Version, + get_protocol_and_path, +) + +if TYPE_CHECKING: + from pyspark.sql import DataFrame, SparkSession + from pyspark.sql.types import StructType + +logger = logging.getLogger(__name__) + + +class SparkDatasetV2(AbstractVersionedDataset): + """``SparkDatasetV2`` loads and saves Spark dataframes. + + Examples: + Using the [YAML API](https://docs.kedro.org/en/stable/catalog-data/data_catalog_yaml_examples/): + + ```yaml + weather: + type: spark.SparkDatasetV2 + filepath: s3a://your_bucket/data/01_raw/weather/* + file_format: csv + load_args: + header: True + inferSchema: True + save_args: + sep: '|' + header: True + + weather_with_schema: + type: spark.SparkDatasetV2 + filepath: s3a://your_bucket/data/01_raw/weather/* + file_format: csv + load_args: + header: True + schema: + filepath: path/to/schema.json + save_args: + sep: '|' + header: True + + weather_cleaned: + type: spark.SparkDatasetV2 + filepath: data/02_intermediate/data.parquet + file_format: parquet + ``` + + Using the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): + + >>> from kedro_datasets.spark import SparkDatasetV2 + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType + >>> + >>> schema = StructType( + ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] + ... 
) + >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) + >>> + >>> dataset = SparkDatasetV2(filepath="tmp_path/test_data") + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() + >>> assert Row(name="Bob", age=12) in reloaded.take(4) + + """ + + # this dataset cannot be used with ``ParallelRunner``, + # therefore it has the attribute ``_SINGLE_PROCESS = True`` + # for parallelism within a Spark pipeline please consider + # ``ThreadRunner`` instead + _SINGLE_PROCESS = True + DEFAULT_LOAD_ARGS: dict[str, Any] = {} + DEFAULT_SAVE_ARGS: dict[str, Any] = {} + + def __init__( # noqa: PLR0913 + self, + *, + filepath: str, + file_format: str = "parquet", + load_args: dict[str, Any] | None = None, + save_args: dict[str, Any] | None = None, + version: Version | None = None, + credentials: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ): + self.file_format = file_format + self.load_args = load_args or {} + self.save_args = save_args or {} + self.credentials = credentials or {} + self.metadata = metadata + + # Parse filepath + self.protocol, self.path = get_protocol_and_path(filepath) + + # Get filesystem for metadata operations (exists, glob) + self._fs = self._get_filesystem() + + # Store Spark compatible path for I/O + self._spark_path = self._to_spark_path(filepath) + + # Handle schema if provided + self._schema = SparkDatasetV2._load_schema_from_file( + self.load_args.pop("schema", None) + ) + + super().__init__( + filepath=PurePosixPath(self.path), + version=version, + exists_function=self._fs.exists, + glob_function=self._fs.glob, + ) + + self._validate_delta_format() + + def _get_filesystem(self): + """Get fsspec filesystem with helpful errors for missing deps""" + try: + import fsspec # noqa: PLC0415 + except ImportError: + raise ImportError("fsspec is required") + + # Normalise protocols + protocol_map = { + "s3a": "s3", + "s3n": "s3", # Spark S3 variants + "dbfs": "file", # DBFS is mounted as local + "": "file", # Default to local + } + + fsspec_protocol = protocol_map.get(self.protocol, self.protocol) + + try: + return fsspec.filesystem(fsspec_protocol, **self.credentials) + except ImportError as e: + # Provide targeted help + if "s3fs" in str(e): + msg = "s3fs not installed. Install with: pip install 'kedro-datasets[spark-s3]'" + elif "gcsfs" in str(e): + msg = "gcsfs not installed. Install with: pip install gcsfs" + elif "adlfs" in str(e): + msg = "adlfs not installed. 
Install with: pip install adlfs" + else: + msg = str(e) + raise ImportError(msg) from e + + def _to_spark_path(self, filepath: str) -> str: + """Convert to Spark-compatible path format""" + filepath = str(filepath) # Convert PosixPath to string + protocol, path = get_protocol_and_path(filepath) + + # Handle special cases + if filepath.startswith("/dbfs/"): + # Databricks: /dbfs/path -> dbfs:/path + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + return "dbfs:/" + filepath[6:] + return filepath + + # Map to Spark protocols + spark_protocols = { + "s3": "s3a", # Critical: Spark prefers s3a:// + "gs": "gs", + "abfs": "abfs", + "file": "", # Local paths don't need protocol + "": "", + } + + spark_protocol = spark_protocols.get(protocol, protocol) + + if not spark_protocol: + return path + return f"{spark_protocol}://{path}" + + def _get_spark(self) -> SparkSession: + """Lazy load Spark with environment specific guidance""" + try: + from pyspark.sql import SparkSession # noqa: PLC0415 + + return SparkSession.builder.getOrCreate() + except ImportError as e: + # Detect environment and provide specific help + if "DATABRICKS_RUNTIME_VERSION" in os.environ: + msg = ( + "Cannot import PySpark on Databricks. This is usually a " + "databricks-connect conflict. Try:\n" + " pip uninstall pyspark\n" + " pip install databricks-connect" + ) + elif "EMR_RELEASE_LABEL" in os.environ: + msg = "PySpark should be pre-installed on EMR. Check your cluster configuration." + else: + msg = ( + "PySpark not installed. Install based on your environment:\n" + " Local: pip install 'kedro-datasets[spark-local]'\n" + " Databricks: Use pre-installed Spark or databricks-connect\n" + " Cloud: Check your platform's Spark setup" + ) + raise ImportError(msg) from e + + @staticmethod + def _load_schema_from_file(schema: Any) -> Any: + """Process schema argument if provided""" + if schema is None: + return None + + if isinstance(schema, dict): + # Load from file + schema_path = schema.get("filepath") + if not schema_path: + raise DatasetError("Schema dict must have 'filepath'") + + # Use fsspec to load + import json # noqa: PLC0415 + + protocol, path = get_protocol_and_path(schema_path) + + try: + import fsspec # noqa: PLC0415 + + fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) + with fs.open(path, "r") as f: + schema_json = json.load(f) + + # Lazy import StructType + from pyspark.sql.types import StructType # noqa: PLC0415 + + return StructType.fromJson(schema_json) + except ImportError as e: + if "pyspark" in str(e): + raise ImportError("PySpark required to process schema") from e + raise + except Exception as e: + raise DatasetError(f"Failed to load schema from {schema_path}") from e + + return schema + + def load(self) -> DataFrame: + """Load data using Spark""" + spark = self._get_spark() + + reader = spark.read + if self._schema: + reader = reader.schema(self._schema) + + return ( + reader.format(self.file_format) + .options(**self.load_args) + .load(self._spark_path) + ) + + def save(self, data: DataFrame) -> None: + """Save data using Spark""" + writer = data.write + + if mode := self.save_args.pop("mode", None): + writer = writer.mode(mode) + + if partition_by := self.save_args.pop("partitionBy", None): + writer = writer.partitionBy(partition_by) + + writer.format(self.file_format).options(**self.save_args).save(self._spark_path) + + def _exists(self) -> bool: + """Existence check using fsspec""" + try: + return self._fs.exists(self.path) + except Exception: + # Fallback to Spark check for special cases 
(e.g., Delta tables) + if self.file_format == "delta": + try: + spark = self._get_spark() + spark.read.format("delta").load(self._spark_path) + return True + except Exception: + return False + return False + + def _validate_delta_format(self): + """Validate Delta-specific configurations""" + if self.file_format == "delta": + mode = self.save_args.get("mode") + supported = {"append", "overwrite", "error", "errorifexists", "ignore"} + if mode and mode not in supported: + raise DatasetError( + f"Delta format doesn't support mode '{mode}'. " + f"Use one of {supported} or DeltaTableDataset for advanced operations." + ) + + def _describe(self) -> dict[str, Any]: + return { + "filepath": self._spark_path, + "file_format": self.file_format, + "load_args": self.load_args, + "save_args": self.save_args, + "version": self._version, + } diff --git a/kedro-datasets/tests/spark/test_spark_dataset.py b/kedro-datasets/tests/spark/test_spark_dataset.py index 61c921aef..18bd1066f 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset.py +++ b/kedro-datasets/tests/spark/test_spark_dataset.py @@ -26,7 +26,11 @@ ) from pyspark.sql.utils import AnalysisException -import kedro_datasets._utils.databricks_utils +from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + get_dbutils, +) from kedro_datasets.pandas import CSVDataset, ParquetDataset from kedro_datasets.pickle import PickleDataset from kedro_datasets.spark import SparkDataset @@ -37,6 +41,28 @@ SCHEMA_FILE_NAME = "schema.json" AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} +HDFS_PREFIX = f"{FOLDER_NAME}/{FILENAME}" +HDFS_FOLDER_STRUCTURE = [ + ( + HDFS_PREFIX, + [ + "2019-01-01T23.59.59.999Z", + "2019-01-02T00.00.00.000Z", + "2019-01-02T00.00.00.001Z", + "2019-01-02T01.00.00.000Z", + "2019-02-01T00.00.00.000Z", + ], + [], + ), + (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z", [FILENAME], []), + (HDFS_PREFIX + "/2019-01-01T23.59.59.999Z/" + FILENAME, [], ["part1", "part2"]), + (HDFS_PREFIX + "/2019-01-02T00.00.00.000Z", [], ["other_file"]), + (HDFS_PREFIX + "/2019-01-02T00.00.00.001Z", [], []), + (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z", [FILENAME], []), + (HDFS_PREFIX + "/2019-01-02T01.00.00.000Z/" + FILENAME, [], ["part1"]), + (HDFS_PREFIX + "/2019-02-01T00.00.00.000Z", [], ["other_file"]), +] + SPARK_VERSION = PackagingVersion(__version__) @@ -136,6 +162,14 @@ def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructT return f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}" +class FileInfo: + def __init__(self, path): + self.path = "dbfs:" + path + + def isDir(self): + return "." not in self.path.split("/")[-1] + + class TestSparkDataset: def test_load_parquet(self, tmp_path, sample_pandas_df): temp_path = (tmp_path / "data").as_posix() @@ -252,9 +286,12 @@ def test_load_options_invalid_schema_file(self, tmp_path): schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() Path(schemapath).write_text("dummy", encoding="utf-8") - pattern = f"Failed to load schema from {schemapath}" + pattern = ( + f"Contents of 'schema.filepath' ({schemapath}) are invalid. Please" + f"provide a valid JSON-serialised 'pyspark.sql.types.StructType'." 
+ ) - with pytest.raises(DatasetError, match=pattern): + with pytest.raises(DatasetError, match=re.escape(pattern)): SparkDataset( filepath=filepath, file_format="csv", @@ -264,7 +301,10 @@ def test_load_options_invalid_schema_file(self, tmp_path): def test_load_options_invalid_schema(self, tmp_path): filepath = (tmp_path / "data").as_posix() - pattern = "Schema dict must have 'filepath'" + pattern = ( + "Schema load argument does not specify a 'filepath' attribute. Please" + "include a path to a JSON-serialised 'pyspark.sql.types.StructType'." + ) with pytest.raises(DatasetError, match=pattern): SparkDataset( @@ -328,9 +368,13 @@ def test_save_overwrite_mode(self, tmp_path, sample_spark_df): @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): filepath = (tmp_path / "test_data").as_posix() - pattern = f"Delta format doesn't support mode '{mode}'. " f"Use one of" + pattern = ( + f"It is not possible to perform 'save()' for file format 'delta' " + f"with mode '{mode}' on 'SparkDataset'. " + f"Please use 'spark.DeltaTableDataset' instead." + ) - with pytest.raises(DatasetError, match=pattern): + with pytest.raises(DatasetError, match=re.escape(pattern)): _ = SparkDataset( filepath=filepath, file_format="delta", save_args={"mode": mode} ) @@ -367,15 +411,13 @@ def test_exists_raises_error(self, mocker): # AnalysisExceptions clearly indicating a missing file spark_dataset = SparkDataset(filepath="") if SPARK_VERSION >= PackagingVersion("3.4.0"): - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", side_effect=AnalysisException("Other Exception"), ) else: - mocker.patch.object( - spark_dataset, - "_get_spark", + mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", side_effect=AnalysisException("Other Exception", []), ) with pytest.raises(DatasetError, match="Other Exception"): @@ -397,10 +439,26 @@ def test_s3_glob_refresh(self): spark_dataset = SparkDataset(filepath="s3a://bucket/data") assert spark_dataset._glob_function.keywords == {"refresh": True} + def test_copy(self): + spark_dataset = SparkDataset( + filepath="/tmp/data", save_args={"mode": "overwrite"} + ) + assert spark_dataset._file_format == "parquet" + + spark_dataset_copy = spark_dataset._copy(_file_format="csv") + + assert spark_dataset is not spark_dataset_copy + assert spark_dataset._file_format == "parquet" + assert spark_dataset._save_args == {"mode": "overwrite"} + assert spark_dataset_copy._file_format == "csv" + assert spark_dataset_copy._save_args == {"mode": "overwrite"} + def test_dbfs_prefix_warning_no_databricks(self, caplog): # test that warning is not raised when not on Databricks + filepath = "my_project/data/02_intermediate/processed_data" expected_message = ( - "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" + "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix in the " + f"filepath is a known source of error. You must add this prefix to {filepath}." 
) SparkDataset(filepath="my_project/data/02_intermediate/processed_data") assert expected_message not in caplog.text @@ -419,13 +477,6 @@ def test_prefix_warning_on_databricks( ): monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") - # Mock deployed_on_databricks to return True - monkeypatch.setattr( - kedro_datasets._utils.databricks_utils, - "deployed_on_databricks", - lambda: True, - ) - SparkDataset(filepath=filepath) warning_msg = ( @@ -576,6 +627,91 @@ def test_exists( ] * 2 assert mocked_glob.call_args_list == expected_calls + def test_dbfs_glob(self, mocker): + dbutils_mock = mocker.Mock() + dbutils_mock.fs.ls.return_value = [ + FileInfo("/tmp/file/date1"), + FileInfo("/tmp/file/date2"), + FileInfo("/tmp/file/file.csv"), + FileInfo("/tmp/file/"), + ] + pattern = "/tmp/file/*/file" + expected = ["/dbfs/tmp/file/date1/file", "/dbfs/tmp/file/date2/file"] + + result = dbfs_glob(pattern, dbutils_mock) + assert result == expected + dbutils_mock.fs.ls.assert_called_once_with("/tmp/file") + + def test_dbfs_exists(self, mocker): + dbutils_mock = mocker.Mock() + test_path = "/dbfs/tmp/file/date1/file" + dbutils_mock.fs.ls.return_value = [ + FileInfo("/tmp/file/date1"), + FileInfo("/tmp/file/date2"), + FileInfo("/tmp/file/file.csv"), + FileInfo("/tmp/file/"), + ] + + assert dbfs_exists(test_path, dbutils_mock) + + # add side effect to test that non-existence is handled + dbutils_mock.fs.ls.side_effect = Exception() + assert not dbfs_exists(test_path, dbutils_mock) + + def test_ds_init_no_dbutils(self, mocker): + get_dbutils_mock = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_dbutils", + return_value=None, + ) + + dataset = SparkDataset(filepath="/dbfs/tmp/data") + + get_dbutils_mock.assert_called_once() + assert dataset._glob_function.__name__ == "iglob" + + def test_ds_init_dbutils_available(self, mocker): + get_dbutils_mock = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_dbutils", + return_value="mock", + ) + + dataset = SparkDataset(filepath="/dbfs/tmp/data") + + get_dbutils_mock.assert_called_once() + assert dataset._glob_function.__class__.__name__ == "partial" + assert dataset._glob_function.func.__name__ == "dbfs_glob" + assert dataset._glob_function.keywords == { + "dbutils": get_dbutils_mock.return_value + } + + def test_get_dbutils_from_globals(self, mocker): + mocker.patch( + "kedro_datasets._utils.databricks_utils.globals", + return_value={"dbutils": "dbutils_from_globals"}, + ) + assert get_dbutils("spark") == "dbutils_from_globals" + + def test_get_dbutils_from_pyspark(self, mocker): + dbutils_mock = mocker.Mock() + dbutils_mock.DBUtils.return_value = "dbutils_from_pyspark" + mocker.patch.dict("sys.modules", {"pyspark.dbutils": dbutils_mock}) + assert get_dbutils("spark") == "dbutils_from_pyspark" + dbutils_mock.DBUtils.assert_called_once_with("spark") + + def test_get_dbutils_from_ipython(self, mocker): + ipython_mock = mocker.Mock() + ipython_mock.get_ipython.return_value.user_ns = { + "dbutils": "dbutils_from_ipython" + } + mocker.patch.dict("sys.modules", {"IPython": ipython_mock}) + assert get_dbutils("spark") == "dbutils_from_ipython" + ipython_mock.get_ipython.assert_called_once_with() + + def test_get_dbutils_no_modules(self, mocker): + mocker.patch("kedro_datasets.spark.spark_dataset.globals", return_value={}) + mocker.patch.dict("sys.modules", {"pyspark": None, "IPython": None}) + assert get_dbutils("spark") is None + @pytest.mark.parametrize("os_name", ["nt", "posix"]) def test_regular_path_in_different_os(self, os_name, mocker): """Check 
that class of filepath depends on OS for regular path.""" @@ -602,21 +738,24 @@ def test_no_version(self, versioned_dataset_s3): versioned_dataset_s3.load() def test_load_latest(self, mocker, versioned_dataset_s3): - mocker.patch.object(versioned_dataset_s3, "_get_spark") + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", + ) mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") mocked_glob.return_value = [ "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") ] mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) - # Mock the actual Spark read - mock_spark = mocker.MagicMock() - mocker.patch.object(versioned_dataset_s3, "_get_spark", return_value=mock_spark) - versioned_dataset_s3.load() mocked_glob.assert_called_once_with(f"{BUCKET_NAME}/{FILENAME}/*/{FILENAME}") - mock_spark.read.load.assert_called_once() + get_spark.return_value.read.load.assert_called_once_with( + "s3a://{b}/{f}/{v}/{f}".format( + b=BUCKET_NAME, f=FILENAME, v="mocked_version" + ), + "parquet", + ) def test_load_exact(self, mocker): ts = generate_timestamp() @@ -624,13 +763,14 @@ def test_load_exact(self, mocker): filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=Version(ts, None), ) - - mock_spark = mocker.MagicMock() - mocker.patch.object(ds_s3, "_get_spark", return_value=mock_spark) - + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", + ) ds_s3.load() - mock_spark.read.load.assert_called_once() + get_spark.return_value.read.load.assert_called_once_with( + f"s3a://{BUCKET_NAME}/{FILENAME}/{ts}/{FILENAME}", "parquet" + ) def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): mocked_spark_df = mocker.Mock() @@ -643,7 +783,10 @@ def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): # matches save version due to consistency check in versioned_dataset_s3.save() mocker.patch.object(ds_s3, "resolve_load_version", return_value=version.save) ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once() + mocked_spark_df.write.save.assert_called_once_with( + f"s3a://{BUCKET_NAME}/{FILENAME}/{version.save}/{FILENAME}", + "parquet", + ) def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mocker): exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") @@ -659,7 +802,10 @@ def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mock ) with pytest.warns(UserWarning, match=pattern): ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once() + mocked_spark_df.write.save.assert_called_once_with( + f"s3a://{BUCKET_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}", + "parquet", + ) def test_prevent_overwrite(self, mocker, versioned_dataset_s3): mocked_spark_df = mocker.Mock() @@ -675,16 +821,162 @@ def test_prevent_overwrite(self, mocker, versioned_dataset_s3): mocked_spark_df.write.save.assert_not_called() def test_repr(self, versioned_dataset_s3, version): - assert "filepath=" in str(versioned_dataset_s3) + assert "filepath='s3a://" in str(versioned_dataset_s3) assert f"version=Version(load=None, save='{version.save}')" in str( versioned_dataset_s3 ) dataset_s3 = SparkDataset(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") - assert "filepath=" in str(dataset_s3) + assert "filepath='s3a://" in str(dataset_s3) assert "version=" not in str(dataset_s3) +class TestSparkDatasetVersionedHdfs: + def test_no_version(self, mocker, version): + hdfs_walk = mocker.patch( + 
"kedro_datasets.spark.spark_dataset.InsecureClient.walk" + ) + hdfs_walk.return_value = [] + + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + + pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_hdfs.load() + + hdfs_walk.assert_called_once_with(HDFS_PREFIX) + + def test_load_latest(self, mocker, version): + mocker.patch( + "kedro_datasets.spark.spark_dataset.InsecureClient.status", + return_value=True, + ) + hdfs_walk = mocker.patch( + "kedro_datasets.spark.spark_dataset.InsecureClient.walk" + ) + hdfs_walk.return_value = HDFS_FOLDER_STRUCTURE + + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", + ) + + versioned_hdfs.load() + + hdfs_walk.assert_called_once_with(HDFS_PREFIX) + get_spark.return_value.read.load.assert_called_once_with( + "hdfs://{fn}/{f}/{v}/{f}".format( + fn=FOLDER_NAME, v="2019-01-02T01.00.00.000Z", f=FILENAME + ), + "parquet", + ) + + def test_load_exact(self, mocker): + ts = generate_timestamp() + versioned_hdfs = SparkDataset( + filepath=f"hdfs://{HDFS_PREFIX}", version=Version(ts, None) + ) + get_spark = mocker.patch( + "kedro_datasets.spark.spark_dataset.get_spark", + ) + + versioned_hdfs.load() + + get_spark.return_value.read.load.assert_called_once_with( + f"hdfs://{FOLDER_NAME}/{FILENAME}/{ts}/{FILENAME}", + "parquet", + ) + + def test_save(self, mocker, version): + hdfs_status = mocker.patch( + "kedro_datasets.spark.spark_dataset.InsecureClient.status" + ) + hdfs_status.return_value = None + + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + + # need resolve_load_version() call to return a load version that + # matches save version due to consistency check in versioned_hdfs.save() + mocker.patch.object( + versioned_hdfs, "resolve_load_version", return_value=version.save + ) + + mocked_spark_df = mocker.Mock() + versioned_hdfs.save(mocked_spark_df) + + hdfs_status.assert_called_once_with( + f"{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", + strict=False, + ) + mocked_spark_df.write.save.assert_called_once_with( + f"hdfs://{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", + "parquet", + ) + + def test_save_version_warning(self, mocker): + exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") + versioned_hdfs = SparkDataset( + filepath=f"hdfs://{HDFS_PREFIX}", version=exact_version + ) + mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False) + mocked_spark_df = mocker.Mock() + + pattern = ( + rf"Save version '{exact_version.save}' did not match load version " + rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + ) + + with pytest.warns(UserWarning, match=pattern): + versioned_hdfs.save(mocked_spark_df) + mocked_spark_df.write.save.assert_called_once_with( + f"hdfs://{FOLDER_NAME}/{FILENAME}/{exact_version.save}/{FILENAME}", + "parquet", + ) + + def test_prevent_overwrite(self, mocker, version): + hdfs_status = mocker.patch( + "kedro_datasets.spark.spark_dataset.InsecureClient.status" + ) + hdfs_status.return_value = True + + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + + mocked_spark_df = mocker.Mock() + + pattern = ( + r"Save path '.+' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " + r"if versioning is enabled" + ) + with 
pytest.raises(DatasetError, match=pattern): + versioned_hdfs.save(mocked_spark_df) + + hdfs_status.assert_called_once_with( + f"{FOLDER_NAME}/{FILENAME}/{version.save}/{FILENAME}", + strict=False, + ) + mocked_spark_df.write.save.assert_not_called() + + def test_hdfs_warning(self, version): + pattern = ( + "HDFS filesystem support for versioned SparkDataset is in beta " + "and uses 'hdfs.client.InsecureClient', please use with caution" + ) + with pytest.warns(UserWarning, match=pattern): + SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + + def test_repr(self, version): + versioned_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}", version=version) + assert "filepath='hdfs://" in str(versioned_hdfs) + assert f"version=Version(load=None, save='{version.save}')" in str( + versioned_hdfs + ) + + dataset_hdfs = SparkDataset(filepath=f"hdfs://{HDFS_PREFIX}") + assert "filepath='hdfs://" in str(dataset_hdfs) + assert "version=" not in str(dataset_hdfs) + + @pytest.fixture def data_catalog(tmp_path): source_path = Path(__file__).parent / "data/test.parquet" diff --git a/kedro-datasets/tests/spark/test_spark_dataset_v2.py b/kedro-datasets/tests/spark/test_spark_dataset_v2.py new file mode 100644 index 000000000..61c921aef --- /dev/null +++ b/kedro-datasets/tests/spark/test_spark_dataset_v2.py @@ -0,0 +1,731 @@ +import os +import re +import sys +import tempfile +from pathlib import Path, PurePosixPath + +import boto3 +import pandas as pd +import pytest +from kedro.io import DataCatalog, Version +from kedro.io.core import DatasetError, generate_timestamp +from kedro.io.data_catalog import SharedMemoryDataCatalog +from kedro.pipeline import node, pipeline +from kedro.runner import ParallelRunner, SequentialRunner +from moto import mock_aws +from packaging.version import Version as PackagingVersion +from pyspark import __version__ +from pyspark.sql import SparkSession +from pyspark.sql.functions import col +from pyspark.sql.types import ( + FloatType, + IntegerType, + StringType, + StructField, + StructType, +) +from pyspark.sql.utils import AnalysisException + +import kedro_datasets._utils.databricks_utils +from kedro_datasets.pandas import CSVDataset, ParquetDataset +from kedro_datasets.pickle import PickleDataset +from kedro_datasets.spark import SparkDataset + +FOLDER_NAME = "fake_folder" +FILENAME = "test.parquet" +BUCKET_NAME = "test_bucket" +SCHEMA_FILE_NAME = "schema.json" +AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} + +SPARK_VERSION = PackagingVersion(__version__) + + +@pytest.fixture +def sample_pandas_df() -> pd.DataFrame: + return pd.DataFrame( + {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]} + ) + + +@pytest.fixture +def version(): + load_version = None # use latest + save_version = generate_timestamp() # freeze save version + return Version(load_version, save_version) + + +@pytest.fixture +def versioned_dataset_local(tmp_path, version): + return SparkDataset(filepath=(tmp_path / FILENAME).as_posix(), version=version) + + +@pytest.fixture +def versioned_dataset_dbfs(tmp_path, version): + return SparkDataset( + filepath="/dbfs" + (tmp_path / FILENAME).as_posix(), version=version + ) + + +@pytest.fixture +def versioned_dataset_s3(version): + return SparkDataset( + filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", + version=version, + credentials=AWS_CREDENTIALS, + ) + + +@pytest.fixture +def sample_spark_df(): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + 
) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return SparkSession.builder.getOrCreate().createDataFrame(data, schema) + + +@pytest.fixture +def sample_spark_df_schema() -> StructType: + return StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", FloatType(), True), + ] + ) + + +def identity(arg): + return arg # pragma: no cover + + +@pytest.fixture +def spark_in(tmp_path, sample_spark_df): + spark_in = SparkDataset(filepath=(tmp_path / "input").as_posix()) + spark_in.save(sample_spark_df) + return spark_in + + +@pytest.fixture +def mocked_s3_bucket(): + """Create a bucket for testing using moto.""" + with mock_aws(): + conn = boto3.client( + "s3", + aws_access_key_id=AWS_CREDENTIALS["key"], + aws_secret_access_key=AWS_CREDENTIALS["secret"], + ) + conn.create_bucket(Bucket=BUCKET_NAME) + yield conn + + +@pytest.fixture +def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructType): + """Creates schema file and adds it to mocked S3 bucket.""" + temporary_path = tmp_path / SCHEMA_FILE_NAME + temporary_path.write_text(sample_spark_df_schema.json(), encoding="utf-8") + + mocked_s3_bucket.put_object( + Bucket=BUCKET_NAME, Key=SCHEMA_FILE_NAME, Body=temporary_path.read_bytes() + ) + return f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}" + + +class TestSparkDataset: + def test_load_parquet(self, tmp_path, sample_pandas_df): + temp_path = (tmp_path / "data").as_posix() + local_parquet_set = ParquetDataset(filepath=temp_path) + local_parquet_set.save(sample_pandas_df) + spark_dataset = SparkDataset(filepath=temp_path) + spark_df = spark_dataset.load() + assert spark_df.count() == 4 + + def test_save_parquet(self, tmp_path, sample_spark_df): + # To cross check the correct Spark save operation we save to + # a single spark partition and retrieve it with Kedro + # ParquetDataset + temp_dir = Path(str(tmp_path / "test_data")) + spark_dataset = SparkDataset( + filepath=temp_dir.as_posix(), save_args={"compression": "none"} + ) + spark_df = sample_spark_df.coalesce(1) + spark_dataset.save(spark_df) + + single_parquet = [ + f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part") + ][0] + + local_parquet_dataset = ParquetDataset(filepath=single_parquet.as_posix()) + + pandas_df = local_parquet_dataset.load() + + assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12 + + def test_load_options_csv(self, tmp_path, sample_pandas_df): + filepath = (tmp_path / "data").as_posix() + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + spark_dataset = SparkDataset( + filepath=filepath, file_format="csv", load_args={"header": True} + ) + spark_df = spark_dataset.load() + assert spark_df.filter(col("Name") == "Alex").count() == 1 + + def test_load_options_schema_ddl_string( + self, tmp_path, sample_pandas_df, sample_spark_df_schema + ): + filepath = (tmp_path / "data").as_posix() + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + spark_dataset = SparkDataset( + filepath=filepath, + file_format="csv", + load_args={"header": True, "schema": "name STRING, age INT, height FLOAT"}, + ) + spark_df = spark_dataset.load() + assert spark_df.schema == sample_spark_df_schema + + def test_load_options_schema_obj( + self, tmp_path, sample_pandas_df, sample_spark_df_schema + ): + filepath = (tmp_path / "data").as_posix() + local_csv_dataset = CSVDataset(filepath=filepath) + 
local_csv_dataset.save(sample_pandas_df) + + spark_dataset = SparkDataset( + filepath=filepath, + file_format="csv", + load_args={"header": True, "schema": sample_spark_df_schema}, + ) + + spark_df = spark_dataset.load() + assert spark_df.schema == sample_spark_df_schema + + def test_load_options_schema_path( + self, tmp_path, sample_pandas_df, sample_spark_df_schema + ): + filepath = (tmp_path / "data").as_posix() + schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + Path(schemapath).write_text(sample_spark_df_schema.json(), encoding="utf-8") + + spark_dataset = SparkDataset( + filepath=filepath, + file_format="csv", + load_args={"header": True, "schema": {"filepath": schemapath}}, + ) + + spark_df = spark_dataset.load() + assert spark_df.schema == sample_spark_df_schema + + @pytest.mark.usefixtures("mocked_s3_schema") + def test_load_options_schema_path_with_credentials( + self, tmp_path, sample_pandas_df, sample_spark_df_schema + ): + filepath = (tmp_path / "data").as_posix() + local_csv_dataset = CSVDataset(filepath=filepath) + local_csv_dataset.save(sample_pandas_df) + + spark_dataset = SparkDataset( + filepath=filepath, + file_format="csv", + load_args={ + "header": True, + "schema": { + "filepath": f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}", + "credentials": AWS_CREDENTIALS, + }, + }, + ) + + spark_df = spark_dataset.load() + assert spark_df.schema == sample_spark_df_schema + + def test_load_options_invalid_schema_file(self, tmp_path): + filepath = (tmp_path / "data").as_posix() + schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() + Path(schemapath).write_text("dummy", encoding="utf-8") + + pattern = f"Failed to load schema from {schemapath}" + + with pytest.raises(DatasetError, match=pattern): + SparkDataset( + filepath=filepath, + file_format="csv", + load_args={"header": True, "schema": {"filepath": schemapath}}, + ) + + def test_load_options_invalid_schema(self, tmp_path): + filepath = (tmp_path / "data").as_posix() + + pattern = "Schema dict must have 'filepath'" + + with pytest.raises(DatasetError, match=pattern): + SparkDataset( + filepath=filepath, + file_format="csv", + load_args={"header": True, "schema": {}}, + ) + + def test_save_options_csv(self, tmp_path, sample_spark_df): + # To cross check the correct Spark save operation we save to + # a single spark partition with csv format and retrieve it with Kedro + # CSVDataset + temp_dir = Path(str(tmp_path / "test_data")) + spark_dataset = SparkDataset( + filepath=temp_dir.as_posix(), + file_format="csv", + save_args={"sep": "|", "header": True}, + ) + spark_df = sample_spark_df.coalesce(1) + spark_dataset.save(spark_df) + + single_csv_file = [ + f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv" + ][0] + + csv_local_dataset = CSVDataset( + filepath=single_csv_file.as_posix(), load_args={"sep": "|"} + ) + pandas_df = csv_local_dataset.load() + + assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31 + + def test_str_representation(self): + with tempfile.NamedTemporaryFile() as temp_data_file: + filepath = Path(temp_data_file.name).as_posix() + spark_dataset = SparkDataset( + filepath=filepath, file_format="csv", load_args={"header": True} + ) + assert "kedro_datasets.spark.spark_dataset.SparkDataset" in str( + spark_dataset + ) + assert f"filepath='{filepath}" in str(spark_dataset) + + def test_save_overwrite_fail(self, tmp_path, sample_spark_df): + # Writes a data frame twice and expects it to fail. 
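+        # Spark's DataFrameWriter defaults to mode "errorifexists", so the second
+        # save to the same path raises, which Kedro surfaces as a DatasetError.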
+ filepath = (tmp_path / "test_data").as_posix() + spark_dataset = SparkDataset(filepath=filepath) + spark_dataset.save(sample_spark_df) + + with pytest.raises(DatasetError): + spark_dataset.save(sample_spark_df) + + def test_save_overwrite_mode(self, tmp_path, sample_spark_df): + # Writes a data frame in overwrite mode. + filepath = (tmp_path / "test_data").as_posix() + spark_dataset = SparkDataset(filepath=filepath, save_args={"mode": "overwrite"}) + + spark_dataset.save(sample_spark_df) + spark_dataset.save(sample_spark_df) + + @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) + def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): + filepath = (tmp_path / "test_data").as_posix() + pattern = f"Delta format doesn't support mode '{mode}'. " f"Use one of" + + with pytest.raises(DatasetError, match=pattern): + _ = SparkDataset( + filepath=filepath, file_format="delta", save_args={"mode": mode} + ) + + def test_save_partition(self, tmp_path, sample_spark_df): + # To verify partitioning this test will partition the data by one + # of the columns and then check whether partitioned column is added + # to the save path + + filepath = Path(str(tmp_path / "test_data")) + spark_dataset = SparkDataset( + filepath=filepath.as_posix(), + save_args={"mode": "overwrite", "partitionBy": ["name"]}, + ) + + spark_dataset.save(sample_spark_df) + + expected_path = filepath / "name=Alex" + + assert expected_path.exists() + + @pytest.mark.parametrize("file_format", ["csv", "parquet", "delta"]) + def test_exists(self, file_format, tmp_path, sample_spark_df): + filepath = (tmp_path / "test_data").as_posix() + spark_dataset = SparkDataset(filepath=filepath, file_format=file_format) + + assert not spark_dataset.exists() + + spark_dataset.save(sample_spark_df) + assert spark_dataset.exists() + + def test_exists_raises_error(self, mocker): + # exists should raise all errors except for + # AnalysisExceptions clearly indicating a missing file + spark_dataset = SparkDataset(filepath="") + if SPARK_VERSION >= PackagingVersion("3.4.0"): + mocker.patch.object( + spark_dataset, + "_get_spark", + side_effect=AnalysisException("Other Exception"), + ) + else: + mocker.patch.object( + spark_dataset, + "_get_spark", + side_effect=AnalysisException("Other Exception", []), + ) + with pytest.raises(DatasetError, match="Other Exception"): + spark_dataset.exists() + + @pytest.mark.parametrize("is_async", [False, True]) + def test_parallel_runner(self, is_async, spark_in): + """Test ParallelRunner with SparkDataset fails.""" + catalog = SharedMemoryDataCatalog({"spark_in": spark_in}) + test_pipeline = pipeline([node(identity, "spark_in", "spark_out")]) + pattern = ( + r"The following datasets cannot be used with " + r"multiprocessing: \['spark_in'\]" + ) + with pytest.raises(AttributeError, match=pattern): + ParallelRunner(is_async=is_async).run(test_pipeline, catalog) + + def test_s3_glob_refresh(self): + spark_dataset = SparkDataset(filepath="s3a://bucket/data") + assert spark_dataset._glob_function.keywords == {"refresh": True} + + def test_dbfs_prefix_warning_no_databricks(self, caplog): + # test that warning is not raised when not on Databricks + expected_message = ( + "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" + ) + SparkDataset(filepath="my_project/data/02_intermediate/processed_data") + assert expected_message not in caplog.text + + @pytest.mark.parametrize( + "filepath,should_warn", + [ + ("/dbfs/my_project/data/02_intermediate/processed_data", False), + 
("my_project/data/02_intermediate/processed_data", True), + ("s3://my_project/data/02_intermediate/processed_data", False), + ("/Volumes/catalog/schema/table", False), + ], + ) + def test_prefix_warning_on_databricks( + self, filepath, should_warn, monkeypatch, caplog + ): + monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + + # Mock deployed_on_databricks to return True + monkeypatch.setattr( + kedro_datasets._utils.databricks_utils, + "deployed_on_databricks", + lambda: True, + ) + + SparkDataset(filepath=filepath) + + warning_msg = ( + "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" + ) + if should_warn: + assert warning_msg in caplog.text + else: + assert warning_msg not in caplog.text + + +class TestSparkDatasetVersionedLocal: + def test_no_version(self, versioned_dataset_local): + pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_dataset_local.load() + + def test_load_latest(self, versioned_dataset_local, sample_spark_df): + versioned_dataset_local.save(sample_spark_df) + reloaded = versioned_dataset_local.load() + + assert reloaded.exceptAll(sample_spark_df).count() == 0 + + def test_load_exact(self, tmp_path, sample_spark_df): + ts = generate_timestamp() + ds_local = SparkDataset( + filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts) + ) + + ds_local.save(sample_spark_df) + reloaded = ds_local.load() + + assert reloaded.exceptAll(sample_spark_df).count() == 0 + + def test_save(self, versioned_dataset_local, version, tmp_path, sample_spark_df): + versioned_dataset_local.save(sample_spark_df) + assert (tmp_path / FILENAME / version.save / FILENAME).exists() + + def test_repr(self, versioned_dataset_local, tmp_path, version): + assert f"version=Version(load=None, save='{version.save}')" in str( + versioned_dataset_local + ) + + dataset_local = SparkDataset(filepath=(tmp_path / FILENAME).as_posix()) + assert "version=" not in str(dataset_local) + + def test_save_version_warning(self, tmp_path, sample_spark_df): + exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") + ds_local = SparkDataset( + filepath=(tmp_path / FILENAME).as_posix(), version=exact_version + ) + + pattern = ( + rf"Save version '{exact_version.save}' did not match load version " + rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + ds_local.save(sample_spark_df) + + def test_prevent_overwrite(self, tmp_path, version, sample_spark_df): + versioned_local = SparkDataset( + filepath=(tmp_path / FILENAME).as_posix(), + version=version, + # second save should fail even in overwrite mode + save_args={"mode": "overwrite"}, + ) + versioned_local.save(sample_spark_df) + + pattern = ( + r"Save path '.+' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " + r"if versioning is enabled" + ) + with pytest.raises(DatasetError, match=pattern): + versioned_local.save(sample_spark_df) + + def test_versioning_existing_dataset( + self, versioned_dataset_local, sample_spark_df + ): + """Check behavior when attempting to save a versioned dataset on top of an + already existing (non-versioned) dataset. 
Note: because SparkDataset saves to a + directory even if non-versioned, an error is not expected.""" + spark_dataset = SparkDataset( + filepath=versioned_dataset_local._filepath.as_posix() + ) + spark_dataset.save(sample_spark_df) + assert spark_dataset.exists() + versioned_dataset_local.save(sample_spark_df) + assert versioned_dataset_local.exists() + + +@pytest.mark.skipif( + sys.platform.startswith("win"), reason="DBFS doesn't work on Windows" +) +class TestSparkDatasetVersionedDBFS: + def test_load_latest( + self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df + ): + mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") + mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] + + versioned_dataset_dbfs.save(sample_spark_df) + reloaded = versioned_dataset_dbfs.load() + + expected_calls = [ + mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) + ] + assert mocked_glob.call_args_list == expected_calls + + assert reloaded.exceptAll(sample_spark_df).count() == 0 + + def test_load_exact(self, tmp_path, sample_spark_df): + ts = generate_timestamp() + ds_dbfs = SparkDataset( + filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts) + ) + + ds_dbfs.save(sample_spark_df) + reloaded = ds_dbfs.load() + + assert reloaded.exceptAll(sample_spark_df).count() == 0 + + def test_save( + self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df + ): + mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") + mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] + + versioned_dataset_dbfs.save(sample_spark_df) + + mocked_glob.assert_called_once_with( + "/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME) + ) + assert (tmp_path / FILENAME / version.save / FILENAME).exists() + + def test_exists( + self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df + ): + mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") + mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] + + assert not versioned_dataset_dbfs.exists() + + versioned_dataset_dbfs.save(sample_spark_df) + assert versioned_dataset_dbfs.exists() + + expected_calls = [ + mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) + ] * 2 + assert mocked_glob.call_args_list == expected_calls + + @pytest.mark.parametrize("os_name", ["nt", "posix"]) + def test_regular_path_in_different_os(self, os_name, mocker): + """Check that class of filepath depends on OS for regular path.""" + mocker.patch("os.name", os_name) + dataset = SparkDataset(filepath="/some/path") + assert isinstance(dataset._filepath, PurePosixPath) + + @pytest.mark.parametrize("os_name", ["nt", "posix"]) + def test_dbfs_path_in_different_os(self, os_name, mocker): + """Check that class of filepath doesn't depend on OS if it references DBFS.""" + mocker.patch("os.name", os_name) + dataset = SparkDataset(filepath="/dbfs/some/path") + assert isinstance(dataset._filepath, PurePosixPath) + + +class TestSparkDatasetVersionedS3: + os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" + os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" + + @pytest.mark.xfail + def test_no_version(self, versioned_dataset_s3): + pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + with pytest.raises(DatasetError, match=pattern): + versioned_dataset_s3.load() + + def test_load_latest(self, mocker, versioned_dataset_s3): + 
mocker.patch.object(versioned_dataset_s3, "_get_spark") + mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") + mocked_glob.return_value = [ + "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") + ] + mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) + + # Mock the actual Spark read + mock_spark = mocker.MagicMock() + mocker.patch.object(versioned_dataset_s3, "_get_spark", return_value=mock_spark) + + versioned_dataset_s3.load() + + mocked_glob.assert_called_once_with(f"{BUCKET_NAME}/{FILENAME}/*/{FILENAME}") + mock_spark.read.load.assert_called_once() + + def test_load_exact(self, mocker): + ts = generate_timestamp() + ds_s3 = SparkDataset( + filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", + version=Version(ts, None), + ) + + mock_spark = mocker.MagicMock() + mocker.patch.object(ds_s3, "_get_spark", return_value=mock_spark) + + ds_s3.load() + + mock_spark.read.load.assert_called_once() + + def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): + mocked_spark_df = mocker.Mock() + + ds_s3 = SparkDataset( + filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=version + ) + + # need resolve_load_version() call to return a load version that + # matches save version due to consistency check in versioned_dataset_s3.save() + mocker.patch.object(ds_s3, "resolve_load_version", return_value=version.save) + ds_s3.save(mocked_spark_df) + mocked_spark_df.write.save.assert_called_once() + + def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mocker): + exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") + ds_s3 = SparkDataset( + filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", + version=exact_version, + ) + mocked_spark_df = mocker.Mock() + + pattern = ( + rf"Save version '{exact_version.save}' did not match load version " + rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" + ) + with pytest.warns(UserWarning, match=pattern): + ds_s3.save(mocked_spark_df) + mocked_spark_df.write.save.assert_called_once() + + def test_prevent_overwrite(self, mocker, versioned_dataset_s3): + mocked_spark_df = mocker.Mock() + mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) + + pattern = ( + r"Save path '.+' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " + r"if versioning is enabled" + ) + with pytest.raises(DatasetError, match=pattern): + versioned_dataset_s3.save(mocked_spark_df) + + mocked_spark_df.write.save.assert_not_called() + + def test_repr(self, versioned_dataset_s3, version): + assert "filepath=" in str(versioned_dataset_s3) + assert f"version=Version(load=None, save='{version.save}')" in str( + versioned_dataset_s3 + ) + + dataset_s3 = SparkDataset(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") + assert "filepath=" in str(dataset_s3) + assert "version=" not in str(dataset_s3) + + +@pytest.fixture +def data_catalog(tmp_path): + source_path = Path(__file__).parent / "data/test.parquet" + spark_in = SparkDataset(filepath=source_path.as_posix()) + spark_out = SparkDataset(filepath=(tmp_path / "spark_data").as_posix()) + pickle_ds = PickleDataset(filepath=(tmp_path / "pickle/test.pkl").as_posix()) + + return DataCatalog( + {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds} + ) + + +@pytest.mark.parametrize("is_async", [False, True]) +class TestDataFlowSequentialRunner: + def test_spark_load_save(self, is_async, data_catalog): + """SparkDataset(load) -> node -> Spark 
(save).""" + test_pipeline = pipeline([node(identity, "spark_in", "spark_out")]) + SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) + + save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix()) + files = list(save_path.glob("*.parquet")) + assert len(files) > 0 + + def test_spark_pickle(self, is_async, data_catalog): + """SparkDataset(load) -> node -> PickleDataset (save)""" + test_pipeline = pipeline([node(identity, "spark_in", "pickle_ds")]) + pattern = ".* was not serialised due to.*" + with pytest.raises(DatasetError, match=pattern): + SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) + + def test_spark_memory_spark(self, is_async, data_catalog): + """SparkDataset(load) -> node -> MemoryDataset (save and then load) -> + node -> SparkDataset (save)""" + test_pipeline = pipeline( + [ + node(identity, "spark_in", "memory_ds"), + node(identity, "memory_ds", "spark_out"), + ] + ) + SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) + + save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix()) + files = list(save_path.glob("*.parquet")) + assert len(files) > 0 From 5ebd4f1d15d056b332b1ae1b4fdfd5e575ac64db Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 24 Sep 2025 15:39:13 +0100 Subject: [PATCH 10/17] Update test_spark_dataset_v2.py Signed-off-by: Sajid Alam --- .../tests/spark/test_spark_dataset_v2.py | 969 +++++++----------- 1 file changed, 393 insertions(+), 576 deletions(-) diff --git a/kedro-datasets/tests/spark/test_spark_dataset_v2.py b/kedro-datasets/tests/spark/test_spark_dataset_v2.py index 61c921aef..dc1d6e4f4 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset_v2.py +++ b/kedro-datasets/tests/spark/test_spark_dataset_v2.py @@ -1,20 +1,17 @@ +"""Tests for SparkDatasetV2.""" + import os -import re import sys import tempfile -from pathlib import Path, PurePosixPath +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch -import boto3 import pandas as pd import pytest from kedro.io import DataCatalog, Version from kedro.io.core import DatasetError, generate_timestamp -from kedro.io.data_catalog import SharedMemoryDataCatalog from kedro.pipeline import node, pipeline from kedro.runner import ParallelRunner, SequentialRunner -from moto import mock_aws -from packaging.version import Version as PackagingVersion -from pyspark import __version__ from pyspark.sql import SparkSession from pyspark.sql.functions import col from pyspark.sql.types import ( @@ -24,73 +21,55 @@ StructField, StructType, ) -from pyspark.sql.utils import AnalysisException -import kedro_datasets._utils.databricks_utils from kedro_datasets.pandas import CSVDataset, ParquetDataset -from kedro_datasets.pickle import PickleDataset -from kedro_datasets.spark import SparkDataset +from kedro_datasets.spark import SparkDatasetV2 -FOLDER_NAME = "fake_folder" +# Test constants FILENAME = "test.parquet" BUCKET_NAME = "test_bucket" SCHEMA_FILE_NAME = "schema.json" -AWS_CREDENTIALS = {"key": "FAKE_ACCESS_KEY", "secret": "FAKE_SECRET_KEY"} - -SPARK_VERSION = PackagingVersion(__version__) -@pytest.fixture -def sample_pandas_df() -> pd.DataFrame: - return pd.DataFrame( - {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]} +@pytest.fixture(scope="module") +def spark_session(): + """Create a Spark session for testing.""" + spark = ( + SparkSession.builder.master("local[2]") + .appName("TestSparkDatasetV2") + .getOrCreate() ) + yield spark + spark.stop() @pytest.fixture -def version(): - 
load_version = None # use latest - save_version = generate_timestamp() # freeze save version - return Version(load_version, save_version) - - -@pytest.fixture -def versioned_dataset_local(tmp_path, version): - return SparkDataset(filepath=(tmp_path / FILENAME).as_posix(), version=version) - - -@pytest.fixture -def versioned_dataset_dbfs(tmp_path, version): - return SparkDataset( - filepath="/dbfs" + (tmp_path / FILENAME).as_posix(), version=version - ) - - -@pytest.fixture -def versioned_dataset_s3(version): - return SparkDataset( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=version, - credentials=AWS_CREDENTIALS, - ) - - -@pytest.fixture -def sample_spark_df(): +def sample_spark_df(spark_session): + """Create a sample Spark DataFrame.""" schema = StructType( [ StructField("name", StringType(), True), StructField("age", IntegerType(), True), ] ) + data = [("Alice", 30), ("Bob", 25), ("Charlie", 35)] + return spark_session.createDataFrame(data, schema) - data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] - return SparkSession.builder.getOrCreate().createDataFrame(data, schema) +@pytest.fixture +def sample_pandas_df(): + """Create a sample pandas DataFrame.""" + return pd.DataFrame( + { + "name": ["Alice", "Bob", "Charlie"], + "age": [30, 25, 35], + } + ) @pytest.fixture -def sample_spark_df_schema() -> StructType: +def sample_schema(): + """Create a sample schema.""" return StructType( [ StructField("name", StringType(), True), @@ -100,632 +79,470 @@ def sample_spark_df_schema() -> StructType: ) -def identity(arg): - return arg # pragma: no cover +@pytest.fixture +def version(): + """Create a version for testing.""" + return Version(None, generate_timestamp()) -@pytest.fixture -def spark_in(tmp_path, sample_spark_df): - spark_in = SparkDataset(filepath=(tmp_path / "input").as_posix()) - spark_in.save(sample_spark_df) - return spark_in +class TestSparkDatasetV2Basic: + """Test basic functionality of SparkDatasetV2.""" + def test_load_save_parquet(self, tmp_path, sample_spark_df): + """Test basic load and save with parquet format.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) -@pytest.fixture -def mocked_s3_bucket(): - """Create a bucket for testing using moto.""" - with mock_aws(): - conn = boto3.client( - "s3", - aws_access_key_id=AWS_CREDENTIALS["key"], - aws_secret_access_key=AWS_CREDENTIALS["secret"], - ) - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn + # Save + dataset.save(sample_spark_df) + assert Path(filepath).exists() + # Load + loaded_df = dataset.load() + assert loaded_df.count() == sample_spark_df.count() + assert set(loaded_df.columns) == set(sample_spark_df.columns) -@pytest.fixture -def mocked_s3_schema(tmp_path, mocked_s3_bucket, sample_spark_df_schema: StructType): - """Creates schema file and adds it to mocked S3 bucket.""" - temporary_path = tmp_path / SCHEMA_FILE_NAME - temporary_path.write_text(sample_spark_df_schema.json(), encoding="utf-8") + def test_load_save_csv(self, tmp_path, sample_spark_df): + """Test load and save with CSV format.""" + filepath = str(tmp_path / "test.csv") + dataset = SparkDatasetV2( + filepath=filepath, + file_format="csv", + save_args={"header": True}, + load_args={"header": True, "inferSchema": True}, + ) - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=SCHEMA_FILE_NAME, Body=temporary_path.read_bytes() - ) - return f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}" + dataset.save(sample_spark_df) + loaded_df = dataset.load() + assert loaded_df.count() == 
sample_spark_df.count() + assert set(loaded_df.columns) == set(sample_spark_df.columns) -class TestSparkDataset: - def test_load_parquet(self, tmp_path, sample_pandas_df): - temp_path = (tmp_path / "data").as_posix() - local_parquet_set = ParquetDataset(filepath=temp_path) - local_parquet_set.save(sample_pandas_df) - spark_dataset = SparkDataset(filepath=temp_path) - spark_df = spark_dataset.load() - assert spark_df.count() == 4 - - def test_save_parquet(self, tmp_path, sample_spark_df): - # To cross check the correct Spark save operation we save to - # a single spark partition and retrieve it with Kedro - # ParquetDataset - temp_dir = Path(str(tmp_path / "test_data")) - spark_dataset = SparkDataset( - filepath=temp_dir.as_posix(), save_args={"compression": "none"} - ) - spark_df = sample_spark_df.coalesce(1) - spark_dataset.save(spark_df) + def test_load_save_json(self, tmp_path, sample_spark_df): + """Test load and save with JSON format.""" + filepath = str(tmp_path / "test.json") + dataset = SparkDatasetV2(filepath=filepath, file_format="json") - single_parquet = [ - f for f in temp_dir.iterdir() if f.is_file() and f.name.startswith("part") - ][0] + dataset.save(sample_spark_df) + loaded_df = dataset.load() - local_parquet_dataset = ParquetDataset(filepath=single_parquet.as_posix()) + assert loaded_df.count() == sample_spark_df.count() - pandas_df = local_parquet_dataset.load() + def test_save_modes(self, tmp_path, sample_spark_df): + """Test different save modes.""" + filepath = str(tmp_path / "test.parquet") - assert pandas_df[pandas_df["name"] == "Bob"]["age"].iloc[0] == 12 + # Test overwrite mode + dataset = SparkDatasetV2(filepath=filepath, save_args={"mode": "overwrite"}) + dataset.save(sample_spark_df) + dataset.save(sample_spark_df) # Should not fail - def test_load_options_csv(self, tmp_path, sample_pandas_df): - filepath = (tmp_path / "data").as_posix() - local_csv_dataset = CSVDataset(filepath=filepath) - local_csv_dataset.save(sample_pandas_df) - spark_dataset = SparkDataset( - filepath=filepath, file_format="csv", load_args={"header": True} + # Test append mode + dataset_append = SparkDatasetV2( + filepath=str(tmp_path / "test_append.parquet"), save_args={"mode": "append"} ) - spark_df = spark_dataset.load() - assert spark_df.filter(col("Name") == "Alex").count() == 1 - - def test_load_options_schema_ddl_string( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_dataset = CSVDataset(filepath=filepath) - local_csv_dataset.save(sample_pandas_df) - spark_dataset = SparkDataset( + dataset_append.save(sample_spark_df) + dataset_append.save(sample_spark_df) + loaded = dataset_append.load() + assert loaded.count() == sample_spark_df.count() * 2 + + def test_partitioning(self, tmp_path, sample_spark_df): + """Test data partitioning.""" + filepath = str(tmp_path / "test_partitioned.parquet") + dataset = SparkDatasetV2(filepath=filepath, save_args={"partitionBy": ["name"]}) + + dataset.save(sample_spark_df) + + # Check partition directories exist + base_path = Path(filepath) + partitions = [d for d in base_path.iterdir() if d.is_dir()] + assert len(partitions) > 0 + assert any("name=" in d.name for d in partitions) + + def test_exists(self, tmp_path, sample_spark_df): + """Test exists functionality.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + + assert not dataset.exists() + dataset.save(sample_spark_df) + assert dataset.exists() + + def test_describe(self, 
tmp_path): + """Test _describe method.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2( filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": "name STRING, age INT, height FLOAT"}, + file_format="parquet", + load_args={"mergeSchema": True}, + save_args={"compression": "snappy"}, ) - spark_df = spark_dataset.load() - assert spark_df.schema == sample_spark_df_schema - def test_load_options_schema_obj( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_dataset = CSVDataset(filepath=filepath) - local_csv_dataset.save(sample_pandas_df) + description = dataset._describe() + assert description["file_format"] == "parquet" + assert description["load_args"] == {"mergeSchema": True} + assert description["save_args"] == {"compression": "snappy"} - spark_dataset = SparkDataset( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": sample_spark_df_schema}, - ) + def test_str_representation(self, tmp_path): + """Test string representation.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) - spark_df = spark_dataset.load() - assert spark_df.schema == sample_spark_df_schema - - def test_load_options_schema_path( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() - local_csv_dataset = CSVDataset(filepath=filepath) - local_csv_dataset.save(sample_pandas_df) - Path(schemapath).write_text(sample_spark_df_schema.json(), encoding="utf-8") - - spark_dataset = SparkDataset( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": {"filepath": schemapath}}, - ) + assert "SparkDatasetV2" in str(dataset) + assert filepath in str(dataset) - spark_df = spark_dataset.load() - assert spark_df.schema == sample_spark_df_schema - @pytest.mark.usefixtures("mocked_s3_schema") - def test_load_options_schema_path_with_credentials( - self, tmp_path, sample_pandas_df, sample_spark_df_schema - ): - filepath = (tmp_path / "data").as_posix() - local_csv_dataset = CSVDataset(filepath=filepath) - local_csv_dataset.save(sample_pandas_df) +class TestSparkDatasetV2Schema: + """Test schema handling in SparkDatasetV2.""" - spark_dataset = SparkDataset( - filepath=filepath, + def test_schema_from_dict(self, tmp_path, sample_pandas_df, sample_schema): + """Test loading schema from dict.""" + # Save schema to file + schema_path = tmp_path / "schema.json" + schema_path.write_text(sample_schema.json()) + + # Save CSV data + csv_path = str(tmp_path / "test.csv") + sample_pandas_df.to_csv(csv_path, index=False) + + # Load with schema + dataset = SparkDatasetV2( + filepath=csv_path, file_format="csv", - load_args={ - "header": True, - "schema": { - "filepath": f"s3://{BUCKET_NAME}/{SCHEMA_FILE_NAME}", - "credentials": AWS_CREDENTIALS, - }, - }, + load_args={"header": True, "schema": {"filepath": str(schema_path)}}, ) - spark_df = spark_dataset.load() - assert spark_df.schema == sample_spark_df_schema + loaded_df = dataset.load() + assert loaded_df.schema == sample_schema - def test_load_options_invalid_schema_file(self, tmp_path): - filepath = (tmp_path / "data").as_posix() - schemapath = (tmp_path / SCHEMA_FILE_NAME).as_posix() - Path(schemapath).write_text("dummy", encoding="utf-8") + def test_schema_invalid_filepath(self, tmp_path): + """Test error when schema filepath is invalid.""" + csv_path = str(tmp_path / "test.csv") 
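+        # The schema file is parsed eagerly in __init__ via _load_schema_from_file,
+        # so pointing the dataset at malformed JSON should raise a DatasetError
+        # at construction time rather than at load().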
+ schema_path = tmp_path / "bad_schema.json" + schema_path.write_text("invalid json {") - pattern = f"Failed to load schema from {schemapath}" - - with pytest.raises(DatasetError, match=pattern): - SparkDataset( - filepath=filepath, + with pytest.raises(DatasetError, match="Failed to load schema"): + SparkDatasetV2( + filepath=csv_path, file_format="csv", - load_args={"header": True, "schema": {"filepath": schemapath}}, + load_args={"schema": {"filepath": str(schema_path)}}, ) - def test_load_options_invalid_schema(self, tmp_path): - filepath = (tmp_path / "data").as_posix() - - pattern = "Schema dict must have 'filepath'" + def test_schema_missing_filepath(self, tmp_path): + """Test error when schema dict missing filepath.""" + csv_path = str(tmp_path / "test.csv") - with pytest.raises(DatasetError, match=pattern): - SparkDataset( - filepath=filepath, - file_format="csv", - load_args={"header": True, "schema": {}}, + with pytest.raises(DatasetError, match="Schema dict must have 'filepath'"): + SparkDatasetV2( + filepath=csv_path, file_format="csv", load_args={"schema": {}} ) - def test_save_options_csv(self, tmp_path, sample_spark_df): - # To cross check the correct Spark save operation we save to - # a single spark partition with csv format and retrieve it with Kedro - # CSVDataset - temp_dir = Path(str(tmp_path / "test_data")) - spark_dataset = SparkDataset( - filepath=temp_dir.as_posix(), - file_format="csv", - save_args={"sep": "|", "header": True}, - ) - spark_df = sample_spark_df.coalesce(1) - spark_dataset.save(spark_df) - single_csv_file = [ - f for f in temp_dir.iterdir() if f.is_file() and f.suffix == ".csv" - ][0] +class TestSparkDatasetV2PathHandling: + """Test path handling in SparkDatasetV2.""" - csv_local_dataset = CSVDataset( - filepath=single_csv_file.as_posix(), load_args={"sep": "|"} - ) - pandas_df = csv_local_dataset.load() + def test_local_path(self, tmp_path): + """Test local path handling.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) - assert pandas_df[pandas_df["name"] == "Alex"]["age"][0] == 31 + assert dataset.protocol == "" + assert dataset._spark_path == filepath - def test_str_representation(self): - with tempfile.NamedTemporaryFile() as temp_data_file: - filepath = Path(temp_data_file.name).as_posix() - spark_dataset = SparkDataset( - filepath=filepath, file_format="csv", load_args={"header": True} - ) - assert "kedro_datasets.spark.spark_dataset.SparkDataset" in str( - spark_dataset - ) - assert f"filepath='{filepath}" in str(spark_dataset) + def test_s3_path_normalization(self): + """Test S3 path normalization to s3a://.""" + # All S3 variants should normalize to s3a:// + for prefix in ["s3://", "s3n://", "s3a://"]: + filepath = f"{prefix}bucket/path/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path.startswith("s3a://") - def test_save_overwrite_fail(self, tmp_path, sample_spark_df): - # Writes a data frame twice and expects it to fail. 
- filepath = (tmp_path / "test_data").as_posix() - spark_dataset = SparkDataset(filepath=filepath) - spark_dataset.save(sample_spark_df) + @pytest.mark.skipif( + "DATABRICKS_RUNTIME_VERSION" not in os.environ, + reason="Not running on Databricks", + ) + def test_dbfs_path_on_databricks(self): + """Test DBFS path handling on Databricks.""" + filepath = "/dbfs/path/to/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path == "dbfs:/path/to/data.parquet" + + def test_dbfs_path_not_on_databricks(self, monkeypatch): + """Test DBFS path handling when not on Databricks.""" + # Ensure we're not on Databricks + monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False) + + filepath = "/dbfs/path/to/data.parquet" + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path == filepath + + def test_other_protocols(self): + """Test other protocol handling.""" + protocols = { + "gs://bucket/path": "gs://", + "abfs://container@account.dfs.core.windows.net/path": "abfs://", + } + + for filepath, expected_prefix in protocols.items(): + dataset = SparkDatasetV2(filepath=filepath) + assert dataset._spark_path.startswith(expected_prefix) + + +class TestSparkDatasetV2ErrorMessages: + """Test improved error messages in SparkDatasetV2.""" + + def test_missing_s3fs_error(self, mocker): + """Test helpful error for missing s3fs.""" + import_error = ImportError("No module named 's3fs'") + mocker.patch("fsspec.filesystem", side_effect=import_error) + + with pytest.raises( + ImportError, match="pip install 'kedro-datasets\\[spark-s3\\]'" + ): + SparkDatasetV2(filepath="s3://bucket/data.parquet") + + def test_missing_gcsfs_error(self, mocker): + """Test helpful error for missing gcsfs.""" + import_error = ImportError("No module named 'gcsfs'") + mocker.patch("fsspec.filesystem", side_effect=import_error) + + with pytest.raises(ImportError, match="pip install gcsfs"): + SparkDatasetV2(filepath="gs://bucket/data.parquet") + + def test_missing_pyspark_databricks(self, mocker, monkeypatch): + """Test helpful error for PySpark on Databricks.""" + monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") - with pytest.raises(DatasetError): - spark_dataset.save(sample_spark_df) + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") + ) - def test_save_overwrite_mode(self, tmp_path, sample_spark_df): - # Writes a data frame in overwrite mode. - filepath = (tmp_path / "test_data").as_posix() - spark_dataset = SparkDataset(filepath=filepath, save_args={"mode": "overwrite"}) + with pytest.raises(ImportError, match="databricks-connect"): + dataset.load() - spark_dataset.save(sample_spark_df) - spark_dataset.save(sample_spark_df) + def test_missing_pyspark_emr(self, mocker, monkeypatch): + """Test helpful error for PySpark on EMR.""" + monkeypatch.setenv("EMR_RELEASE_LABEL", "emr-7.0.0") - @pytest.mark.parametrize("mode", ["merge", "delete", "update"]) - def test_file_format_delta_and_unsupported_mode(self, tmp_path, mode): - filepath = (tmp_path / "test_data").as_posix() - pattern = f"Delta format doesn't support mode '{mode}'. 
" f"Use one of" + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") + ) - with pytest.raises(DatasetError, match=pattern): - _ = SparkDataset( - filepath=filepath, file_format="delta", save_args={"mode": mode} - ) + with pytest.raises(ImportError, match="should be pre-installed on EMR"): + dataset.load() - def test_save_partition(self, tmp_path, sample_spark_df): - # To verify partitioning this test will partition the data by one - # of the columns and then check whether partitioned column is added - # to the save path + def test_missing_pyspark_local(self, mocker, monkeypatch): + """Test helpful error for PySpark locally.""" + monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False) + monkeypatch.delenv("EMR_RELEASE_LABEL", raising=False) - filepath = Path(str(tmp_path / "test_data")) - spark_dataset = SparkDataset( - filepath=filepath.as_posix(), - save_args={"mode": "overwrite", "partitionBy": ["name"]}, + dataset = SparkDatasetV2(filepath="test.parquet") + mocker.patch.object( + dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") ) - spark_dataset.save(sample_spark_df) - - expected_path = filepath / "name=Alex" + with pytest.raises( + ImportError, match="pip install 'kedro-datasets\\[spark-local\\]'" + ): + dataset.load() - assert expected_path.exists() - @pytest.mark.parametrize("file_format", ["csv", "parquet", "delta"]) - def test_exists(self, file_format, tmp_path, sample_spark_df): - filepath = (tmp_path / "test_data").as_posix() - spark_dataset = SparkDataset(filepath=filepath, file_format=file_format) +class TestSparkDatasetV2Delta: + """Test Delta format handling in SparkDatasetV2.""" - assert not spark_dataset.exists() + @pytest.mark.parametrize("mode", ["merge", "update", "delete"]) + def test_delta_unsupported_modes(self, tmp_path, mode): + """Test that unsupported Delta modes raise errors.""" + filepath = str(tmp_path / "test.delta") - spark_dataset.save(sample_spark_df) - assert spark_dataset.exists() - - def test_exists_raises_error(self, mocker): - # exists should raise all errors except for - # AnalysisExceptions clearly indicating a missing file - spark_dataset = SparkDataset(filepath="") - if SPARK_VERSION >= PackagingVersion("3.4.0"): - mocker.patch.object( - spark_dataset, - "_get_spark", - side_effect=AnalysisException("Other Exception"), - ) - else: - mocker.patch.object( - spark_dataset, - "_get_spark", - side_effect=AnalysisException("Other Exception", []), + with pytest.raises( + DatasetError, match=f"Delta format doesn't support mode '{mode}'" + ): + SparkDatasetV2( + filepath=filepath, file_format="delta", save_args={"mode": mode} ) - with pytest.raises(DatasetError, match="Other Exception"): - spark_dataset.exists() - - @pytest.mark.parametrize("is_async", [False, True]) - def test_parallel_runner(self, is_async, spark_in): - """Test ParallelRunner with SparkDataset fails.""" - catalog = SharedMemoryDataCatalog({"spark_in": spark_in}) - test_pipeline = pipeline([node(identity, "spark_in", "spark_out")]) - pattern = ( - r"The following datasets cannot be used with " - r"multiprocessing: \['spark_in'\]" - ) - with pytest.raises(AttributeError, match=pattern): - ParallelRunner(is_async=is_async).run(test_pipeline, catalog) - - def test_s3_glob_refresh(self): - spark_dataset = SparkDataset(filepath="s3a://bucket/data") - assert spark_dataset._glob_function.keywords == {"refresh": True} - - def 
test_dbfs_prefix_warning_no_databricks(self, caplog): - # test that warning is not raised when not on Databricks - expected_message = ( - "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" - ) - SparkDataset(filepath="my_project/data/02_intermediate/processed_data") - assert expected_message not in caplog.text @pytest.mark.parametrize( - "filepath,should_warn", - [ - ("/dbfs/my_project/data/02_intermediate/processed_data", False), - ("my_project/data/02_intermediate/processed_data", True), - ("s3://my_project/data/02_intermediate/processed_data", False), - ("/Volumes/catalog/schema/table", False), - ], + "mode", ["append", "overwrite", "error", "errorifexists", "ignore"] ) - def test_prefix_warning_on_databricks( - self, filepath, should_warn, monkeypatch, caplog - ): - monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") + def test_delta_supported_modes(self, tmp_path, mode): + """Test that supported Delta modes work.""" + filepath = str(tmp_path / "test.delta") - # Mock deployed_on_databricks to return True - monkeypatch.setattr( - kedro_datasets._utils.databricks_utils, - "deployed_on_databricks", - lambda: True, + # Should not raise + dataset = SparkDatasetV2( + filepath=filepath, file_format="delta", save_args={"mode": mode} ) + assert dataset.file_format == "delta" - SparkDataset(filepath=filepath) - warning_msg = ( - "Using SparkDataset on Databricks without the `/dbfs/` or `/Volumes` prefix" - ) - if should_warn: - assert warning_msg in caplog.text - else: - assert warning_msg not in caplog.text +class TestSparkDatasetV2Versioning: + """Test versioning functionality in SparkDatasetV2.""" + def test_versioned_save_and_load(self, tmp_path, sample_spark_df, version): + """Test versioned save and load.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath, version=version) -class TestSparkDatasetVersionedLocal: - def test_no_version(self, versioned_dataset_local): - pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_dataset_local.load() + # Save versioned + dataset.save(sample_spark_df) - def test_load_latest(self, versioned_dataset_local, sample_spark_df): - versioned_dataset_local.save(sample_spark_df) - reloaded = versioned_dataset_local.load() + # Check versioned path exists + versioned_path = tmp_path / "test.parquet" / version.save / "test.parquet" + assert versioned_path.exists() - assert reloaded.exceptAll(sample_spark_df).count() == 0 + # Load versioned + loaded_df = dataset.load() + assert loaded_df.count() == sample_spark_df.count() - def test_load_exact(self, tmp_path, sample_spark_df): - ts = generate_timestamp() - ds_local = SparkDataset( - filepath=(tmp_path / FILENAME).as_posix(), version=Version(ts, ts) - ) + def test_no_version_error(self, tmp_path): + """Test error when no versions exist.""" + filepath = str(tmp_path / "test.parquet") + version = Version(None, None) # Load latest + dataset = SparkDatasetV2(filepath=filepath, version=version) - ds_local.save(sample_spark_df) - reloaded = ds_local.load() + with pytest.raises(DatasetError, match="Did not find any versions"): + dataset.load() - assert reloaded.exceptAll(sample_spark_df).count() == 0 + def test_version_str_representation(self, tmp_path, version): + """Test version in string representation.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath, version=version) - def test_save(self, 
versioned_dataset_local, version, tmp_path, sample_spark_df): - versioned_dataset_local.save(sample_spark_df) - assert (tmp_path / FILENAME / version.save / FILENAME).exists() + assert "version=" in str(dataset._describe()) - def test_repr(self, versioned_dataset_local, tmp_path, version): - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_dataset_local - ) - dataset_local = SparkDataset(filepath=(tmp_path / FILENAME).as_posix()) - assert "version=" not in str(dataset_local) +class TestSparkDatasetV2Integration: + """Integration tests for SparkDatasetV2.""" - def test_save_version_warning(self, tmp_path, sample_spark_df): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_local = SparkDataset( - filepath=(tmp_path / FILENAME).as_posix(), version=exact_version - ) + def test_parallel_runner_restriction(self, tmp_path, sample_spark_df): + """Test that ParallelRunner is restricted.""" + filepath = str(tmp_path / "test.parquet") + dataset = SparkDatasetV2(filepath=filepath) + dataset.save(sample_spark_df) - pattern = ( - rf"Save version '{exact_version.save}' did not match load version " - rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - ds_local.save(sample_spark_df) - - def test_prevent_overwrite(self, tmp_path, version, sample_spark_df): - versioned_local = SparkDataset( - filepath=(tmp_path / FILENAME).as_posix(), - version=version, - # second save should fail even in overwrite mode - save_args={"mode": "overwrite"}, - ) - versioned_local.save(sample_spark_df) + catalog = DataCatalog({"spark_data": dataset}) + test_pipeline = pipeline([node(lambda x: x, "spark_data", "output")]) - pattern = ( - r"Save path '.+' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " - r"if versioning is enabled" - ) - with pytest.raises(DatasetError, match=pattern): - versioned_local.save(sample_spark_df) - - def test_versioning_existing_dataset( - self, versioned_dataset_local, sample_spark_df - ): - """Check behavior when attempting to save a versioned dataset on top of an - already existing (non-versioned) dataset. 
Note: because SparkDataset saves to a - directory even if non-versioned, an error is not expected.""" - spark_dataset = SparkDataset( - filepath=versioned_dataset_local._filepath.as_posix() - ) - spark_dataset.save(sample_spark_df) - assert spark_dataset.exists() - versioned_dataset_local.save(sample_spark_df) - assert versioned_dataset_local.exists() + with pytest.raises(AttributeError, match="cannot be used with multiprocessing"): + ParallelRunner().run(test_pipeline, catalog) + def test_sequential_runner(self, tmp_path, sample_spark_df): + """Test that SequentialRunner works.""" + filepath_in = str(tmp_path / "input.parquet") + filepath_out = str(tmp_path / "output.parquet") -@pytest.mark.skipif( - sys.platform.startswith("win"), reason="DBFS doesn't work on Windows" -) -class TestSparkDatasetVersionedDBFS: - def test_load_latest( - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] - - versioned_dataset_dbfs.save(sample_spark_df) - reloaded = versioned_dataset_dbfs.load() - - expected_calls = [ - mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) - ] - assert mocked_glob.call_args_list == expected_calls + dataset_in = SparkDatasetV2(filepath=filepath_in) + dataset_out = SparkDatasetV2(filepath=filepath_out) + dataset_in.save(sample_spark_df) - assert reloaded.exceptAll(sample_spark_df).count() == 0 + catalog = DataCatalog({"input": dataset_in, "output": dataset_out}) - def test_load_exact(self, tmp_path, sample_spark_df): - ts = generate_timestamp() - ds_dbfs = SparkDataset( - filepath="/dbfs" + str(tmp_path / FILENAME), version=Version(ts, ts) - ) + test_pipeline = pipeline([node(lambda x: x, "input", "output")]) + SequentialRunner().run(test_pipeline, catalog) + + assert Path(filepath_out).exists() - ds_dbfs.save(sample_spark_df) - reloaded = ds_dbfs.load() + def test_interop_with_pandas_dataset(self, tmp_path, sample_pandas_df): + """Test interoperability with pandas datasets.""" + # Save with pandas + pandas_path = str(tmp_path / "pandas.parquet") + pandas_dataset = ParquetDataset(filepath=pandas_path) + pandas_dataset.save(sample_pandas_df) - assert reloaded.exceptAll(sample_spark_df).count() == 0 + # Load with SparkDatasetV2 + spark_dataset = SparkDatasetV2(filepath=pandas_path) + spark_df = spark_dataset.load() - def test_save( - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] + assert spark_df.count() == len(sample_pandas_df) + assert set(spark_df.columns) == set(sample_pandas_df.columns) - versioned_dataset_dbfs.save(sample_spark_df) - mocked_glob.assert_called_once_with( - "/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME) - ) - assert (tmp_path / FILENAME / version.save / FILENAME).exists() - - def test_exists( - self, mocker, versioned_dataset_dbfs, version, tmp_path, sample_spark_df - ): - mocked_glob = mocker.patch.object(versioned_dataset_dbfs, "_glob_function") - mocked_glob.return_value = [str(tmp_path / FILENAME / version.save / FILENAME)] - - assert not versioned_dataset_dbfs.exists() - - versioned_dataset_dbfs.save(sample_spark_df) - assert versioned_dataset_dbfs.exists() - - expected_calls = [ - mocker.call("/dbfs" + str(tmp_path / FILENAME / "*" / FILENAME)) - ] * 2 - assert 
mocked_glob.call_args_list == expected_calls - - @pytest.mark.parametrize("os_name", ["nt", "posix"]) - def test_regular_path_in_different_os(self, os_name, mocker): - """Check that class of filepath depends on OS for regular path.""" - mocker.patch("os.name", os_name) - dataset = SparkDataset(filepath="/some/path") - assert isinstance(dataset._filepath, PurePosixPath) - - @pytest.mark.parametrize("os_name", ["nt", "posix"]) - def test_dbfs_path_in_different_os(self, os_name, mocker): - """Check that class of filepath doesn't depend on OS if it references DBFS.""" - mocker.patch("os.name", os_name) - dataset = SparkDataset(filepath="/dbfs/some/path") - assert isinstance(dataset._filepath, PurePosixPath) - - -class TestSparkDatasetVersionedS3: - os.environ["AWS_ACCESS_KEY_ID"] = "FAKE_ACCESS_KEY" - os.environ["AWS_SECRET_ACCESS_KEY"] = "FAKE_SECRET_KEY" - - @pytest.mark.xfail - def test_no_version(self, versioned_dataset_s3): - pattern = r"Did not find any versions for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - with pytest.raises(DatasetError, match=pattern): - versioned_dataset_s3.load() - - def test_load_latest(self, mocker, versioned_dataset_s3): - mocker.patch.object(versioned_dataset_s3, "_get_spark") - mocked_glob = mocker.patch.object(versioned_dataset_s3, "_glob_function") - mocked_glob.return_value = [ - "{b}/{f}/{v}/{f}".format(b=BUCKET_NAME, f=FILENAME, v="mocked_version") - ] - mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) +class TestSparkDatasetV2Compatibility: + """Test compatibility between V1 and V2.""" - # Mock the actual Spark read - mock_spark = mocker.MagicMock() - mocker.patch.object(versioned_dataset_s3, "_get_spark", return_value=mock_spark) + def test_v1_v2_coexistence(self, tmp_path, sample_spark_df): + """Test that V1 and V2 can coexist in the same catalog.""" + from kedro_datasets.spark import SparkDataset # noqa: PLC0415 - versioned_dataset_s3.load() + filepath_v1 = str(tmp_path / "v1.parquet") + filepath_v2 = str(tmp_path / "v2.parquet") - mocked_glob.assert_called_once_with(f"{BUCKET_NAME}/{FILENAME}/*/{FILENAME}") - mock_spark.read.load.assert_called_once() + dataset_v1 = SparkDataset(filepath=filepath_v1) + dataset_v2 = SparkDatasetV2(filepath=filepath_v2) - def test_load_exact(self, mocker): - ts = generate_timestamp() - ds_s3 = SparkDataset( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=Version(ts, None), + catalog = DataCatalog( + { + "v1_data": dataset_v1, + "v2_data": dataset_v2, + } ) - mock_spark = mocker.MagicMock() - mocker.patch.object(ds_s3, "_get_spark", return_value=mock_spark) + # Both should work + dataset_v1.save(sample_spark_df) + dataset_v2.save(sample_spark_df) - ds_s3.load() + assert catalog.exists("v1_data") + assert catalog.exists("v2_data") - mock_spark.read.load.assert_called_once() + def test_v2_reads_v1_data(self, tmp_path, sample_spark_df): + """Test that V2 can read data saved by V1.""" + from kedro_datasets.spark import SparkDataset # noqa: PLC0415 - def test_save(self, mocked_s3_schema, versioned_dataset_s3, version, mocker): - mocked_spark_df = mocker.Mock() + filepath = str(tmp_path / "shared.parquet") - ds_s3 = SparkDataset( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", version=version - ) + # V1 saves + dataset_v1 = SparkDataset(filepath=filepath) + dataset_v1.save(sample_spark_df) - # need resolve_load_version() call to return a load version that - # matches save version due to consistency check in versioned_dataset_s3.save() - mocker.patch.object(ds_s3, 
"resolve_load_version", return_value=version.save) - ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once() - - def test_save_version_warning(self, mocked_s3_schema, versioned_dataset_s3, mocker): - exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z") - ds_s3 = SparkDataset( - filepath=f"s3a://{BUCKET_NAME}/{FILENAME}", - version=exact_version, - ) - mocked_spark_df = mocker.Mock() + # V2 loads + dataset_v2 = SparkDatasetV2(filepath=filepath) + loaded_df = dataset_v2.load() - pattern = ( - rf"Save version '{exact_version.save}' did not match load version " - rf"'{exact_version.load}' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\)" - ) - with pytest.warns(UserWarning, match=pattern): - ds_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_called_once() + assert loaded_df.count() == sample_spark_df.count() - def test_prevent_overwrite(self, mocker, versioned_dataset_s3): - mocked_spark_df = mocker.Mock() - mocker.patch.object(versioned_dataset_s3, "_exists_function", return_value=True) - pattern = ( - r"Save path '.+' for kedro_datasets.spark.spark_dataset.SparkDataset\(.+\) must not exist " - r"if versioning is enabled" +# Fixtures for mocking cloud storage +@pytest.fixture +def mock_s3_filesystem(): + """Mock S3 filesystem.""" + with patch("fsspec.filesystem") as mock_fs: + mock_filesystem = MagicMock() + mock_filesystem.exists.return_value = True + mock_filesystem.glob.return_value = [] + mock_fs.return_value = mock_filesystem + yield mock_filesystem + + +class TestSparkDatasetV2CloudStorage: + """Test cloud storage handling.""" + + def test_s3_credentials(self, mock_s3_filesystem): + """Test S3 with credentials.""" + dataset = SparkDatasetV2( + filepath="s3://bucket/data.parquet", + credentials={"key": "test_key", "secret": "test_secret"}, ) - with pytest.raises(DatasetError, match=pattern): - versioned_dataset_s3.save(mocked_spark_df) - mocked_spark_df.write.save.assert_not_called() + # Verify s3a:// normalization + assert dataset._spark_path.startswith("s3a://") - def test_repr(self, versioned_dataset_s3, version): - assert "filepath=" in str(versioned_dataset_s3) - assert f"version=Version(load=None, save='{version.save}')" in str( - versioned_dataset_s3 + def test_gcs_handling(self, mock_s3_filesystem): + """Test GCS handling.""" + dataset = SparkDatasetV2( + filepath="gs://bucket/data.parquet", + credentials={"token": "path/to/token.json"}, ) - dataset_s3 = SparkDataset(filepath=f"s3a://{BUCKET_NAME}/{FILENAME}") - assert "filepath=" in str(dataset_s3) - assert "version=" not in str(dataset_s3) - - -@pytest.fixture -def data_catalog(tmp_path): - source_path = Path(__file__).parent / "data/test.parquet" - spark_in = SparkDataset(filepath=source_path.as_posix()) - spark_out = SparkDataset(filepath=(tmp_path / "spark_data").as_posix()) - pickle_ds = PickleDataset(filepath=(tmp_path / "pickle/test.pkl").as_posix()) - - return DataCatalog( - {"spark_in": spark_in, "spark_out": spark_out, "pickle_ds": pickle_ds} - ) - + assert dataset._spark_path.startswith("gs://") -@pytest.mark.parametrize("is_async", [False, True]) -class TestDataFlowSequentialRunner: - def test_spark_load_save(self, is_async, data_catalog): - """SparkDataset(load) -> node -> Spark (save).""" - test_pipeline = pipeline([node(identity, "spark_in", "spark_out")]) - SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) - - save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix()) - files = 
list(save_path.glob("*.parquet")) - assert len(files) > 0 - - def test_spark_pickle(self, is_async, data_catalog): - """SparkDataset(load) -> node -> PickleDataset (save)""" - test_pipeline = pipeline([node(identity, "spark_in", "pickle_ds")]) - pattern = ".* was not serialised due to.*" - with pytest.raises(DatasetError, match=pattern): - SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) - - def test_spark_memory_spark(self, is_async, data_catalog): - """SparkDataset(load) -> node -> MemoryDataset (save and then load) -> - node -> SparkDataset (save)""" - test_pipeline = pipeline( - [ - node(identity, "spark_in", "memory_ds"), - node(identity, "memory_ds", "spark_out"), - ] + def test_azure_handling(self, mock_s3_filesystem): + """Test Azure Blob Storage handling.""" + dataset = SparkDatasetV2( + filepath="abfs://container@account.dfs.core.windows.net/data.parquet", + credentials={"account_key": "test_key"}, ) - SequentialRunner(is_async=is_async).run(test_pipeline, data_catalog) - save_path = Path(data_catalog._datasets["spark_out"]._filepath.as_posix()) - files = list(save_path.glob("*.parquet")) - assert len(files) > 0 + assert dataset._spark_path.startswith("abfs://") From d06925641d1a0afaf2cf2451ed6d32a634a80081 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 29 Sep 2025 14:13:33 +0100 Subject: [PATCH 11/17] changes based on feedback Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset_v2.py | 309 +++++++++++++----- 1 file changed, 228 insertions(+), 81 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index dc225de69..66afa63dd 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -4,8 +4,11 @@ from __future__ import annotations +import json import logging import os +import warnings +from functools import partial from pathlib import PurePosixPath from typing import TYPE_CHECKING, Any @@ -98,47 +101,113 @@ def __init__( # noqa: PLR0913 metadata: dict[str, Any] | None = None, ): self.file_format = file_format - self.load_args = load_args or {} - self.save_args = save_args or {} + self.load_args = {**self.DEFAULT_LOAD_ARGS, **(load_args or {})} + self.save_args = {**self.DEFAULT_SAVE_ARGS, **(save_args or {})} self.credentials = credentials or {} self.metadata = metadata # Parse filepath - self.protocol, self.path = get_protocol_and_path(filepath) + self.protocol, self.path = self._parse_filepath(filepath) + + # Validate and warn about Databricks paths + self._validate_databricks_path(filepath) # Get filesystem for metadata operations (exists, glob) - self._fs = self._get_filesystem() + exists_function, glob_function = self._get_filesystem_ops() # Store Spark compatible path for I/O self._spark_path = self._to_spark_path(filepath) # Handle schema if provided - self._schema = SparkDatasetV2._load_schema_from_file( + self._schema = self._load_schema_from_file( self.load_args.pop("schema", None) ) super().__init__( filepath=PurePosixPath(self.path), version=version, - exists_function=self._fs.exists, - glob_function=self._fs.glob, + exists_function=exists_function, + glob_function=glob_function, ) self._validate_delta_format() + def _parse_filepath(self, filepath: str) -> tuple[str, str]: + """Parse filepath handling special cases like DBFS.""" + # Handle DBFS paths + if filepath.startswith("/dbfs/"): + # /dbfs/path -> dbfs protocol with /path + return "dbfs", filepath[6:] # Remove 
/dbfs prefix + elif filepath.startswith("dbfs:/"): + # dbfs:/path -> already in correct format + return get_protocol_and_path(filepath) + elif filepath.startswith("/Volumes"): + # Unity Catalog volumes + return "file", filepath + else: + return get_protocol_and_path(filepath) + + def _validate_databricks_path(self, filepath: str) -> None: + """Warn about potential Databricks path issues.""" + from kedro_datasets._utils.databricks_utils import deployed_on_databricks + + if ( + deployed_on_databricks() + and not (filepath.startswith("/dbfs") + or filepath.startswith("dbfs:/") + or filepath.startswith("/Volumes")) + and not any(filepath.startswith(f"{p}://") for p in ["s3", "s3a", "s3n", "gs", "abfs", "wasbs"]) + ): + logger.warning( + "Using SparkDatasetV2 on Databricks without the `/dbfs/`, `dbfs:/`, or `/Volumes` prefix " + "in the filepath is a known source of error. You must add this prefix to %s", + filepath, + ) + + def _get_filesystem_ops(self) -> tuple: + """Get filesystem operations with DBFS optimization.""" + from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + deployed_on_databricks, + get_dbutils, + ) + + # Special handling for DBFS to avoid performance issues + # This addresses the critical performance issue raised by deepyaman + if self.protocol == "dbfs" and deployed_on_databricks(): + try: + spark = self._get_spark() + dbutils = get_dbutils(spark) + if dbutils: + logger.debug("Using optimized DBFS operations via dbutils") + return ( + partial(dbfs_exists, dbutils=dbutils), + partial(dbfs_glob, dbutils=dbutils) + ) + except Exception as e: + logger.warning(f"Failed to get dbutils, falling back to fsspec: {e}") + + # Regular fsspec for everything else + fs = self._get_filesystem() + return fs.exists, fs.glob + def _get_filesystem(self): - """Get fsspec filesystem with helpful errors for missing deps""" + """Get fsspec filesystem with helpful errors for missing deps.""" try: import fsspec # noqa: PLC0415 except ImportError: - raise ImportError("fsspec is required") + raise ImportError( + "fsspec is required for SparkDatasetV2. " + "Install with: pip install fsspec" + ) - # Normalise protocols + # Map Spark protocols to fsspec protocols protocol_map = { "s3a": "s3", - "s3n": "s3", # Spark S3 variants - "dbfs": "file", # DBFS is mounted as local - "": "file", # Default to local + "s3n": "s3", + "dbfs": "file", # DBFS is mounted as local when using fsspec + "": "file", } fsspec_protocol = protocol_map.get(self.protocol, self.protocol) @@ -146,50 +215,73 @@ def _get_filesystem(self): try: return fsspec.filesystem(fsspec_protocol, **self.credentials) except ImportError as e: - # Provide targeted help - if "s3fs" in str(e): - msg = "s3fs not installed. Install with: pip install 'kedro-datasets[spark-s3]'" - elif "gcsfs" in str(e): - msg = "gcsfs not installed. Install with: pip install gcsfs" - elif "adlfs" in str(e): - msg = "adlfs not installed. Install with: pip install adlfs" + # Provide targeted help for missing filesystem implementations + error_msg = str(e) + if "s3fs" in error_msg: + msg = ( + "s3fs not installed. Install with:\n" + " pip install 'kedro-datasets[spark-s3]' or\n" + " pip install s3fs" + ) + elif "gcsfs" in error_msg: + msg = ( + "gcsfs not installed. Install with:\n" + " pip install 'kedro-datasets[spark-gcs]' or\n" + " pip install gcsfs" + ) + elif "adlfs" in error_msg or "azure" in error_msg: + msg = ( + "adlfs not installed for Azure. 
Install with:\n" + " pip install 'kedro-datasets[spark-azure]' or\n" + " pip install adlfs" + ) else: - msg = str(e) + msg = f"Missing filesystem implementation: {error_msg}" raise ImportError(msg) from e def _to_spark_path(self, filepath: str) -> str: - """Convert to Spark-compatible path format""" - filepath = str(filepath) # Convert PosixPath to string + """Convert to Spark-compatible path format.""" + from kedro_datasets._utils.databricks_utils import strip_dbfs_prefix + + filepath = str(filepath) + + # Apply DBFS prefix stripping for consistency + filepath = strip_dbfs_prefix(filepath) + protocol, path = get_protocol_and_path(filepath) - # Handle special cases - if filepath.startswith("/dbfs/"): - # Databricks: /dbfs/path -> dbfs:/path - if "DATABRICKS_RUNTIME_VERSION" in os.environ: - return "dbfs:/" + filepath[6:] - return filepath + # Special handling for Databricks paths + if self.protocol == "dbfs": + # Ensure dbfs:/ format for Spark + return f"dbfs:/{path}" # Map to Spark protocols spark_protocols = { - "s3": "s3a", # Critical: Spark prefers s3a:// + "s3": "s3a", # Spark prefers s3a:// "gs": "gs", "abfs": "abfs", - "file": "", # Local paths don't need protocol - "": "", + "wasbs": "wasbs", + "file": "file", + "": "file", } spark_protocol = spark_protocols.get(protocol, protocol) - if not spark_protocol: + # Handle local paths + if spark_protocol == "file": + # Ensure absolute path for local files + if not path.startswith("/"): + path = f"/{path}" + return f"file://{path}" + elif not spark_protocol: return path - return f"{spark_protocol}://{path}" + else: + return f"{spark_protocol}://{path}" def _get_spark(self) -> SparkSession: - """Lazy load Spark with environment specific guidance""" + """Get Spark session with support for Spark Connect and Databricks Connect.""" try: from pyspark.sql import SparkSession # noqa: PLC0415 - - return SparkSession.builder.getOrCreate() except ImportError as e: # Detect environment and provide specific help if "DATABRICKS_RUNTIME_VERSION" in os.environ: @@ -206,49 +298,79 @@ def _get_spark(self) -> SparkSession: "PySpark not installed. 
Install based on your environment:\n" " Local: pip install 'kedro-datasets[spark-local]'\n" " Databricks: Use pre-installed Spark or databricks-connect\n" + " Spark Connect: pip install 'kedro-datasets[spark-connect]'\n" " Cloud: Check your platform's Spark setup" ) raise ImportError(msg) from e - @staticmethod - def _load_schema_from_file(schema: Any) -> Any: - """Process schema argument if provided""" - if schema is None: - return None + # Try Databricks Connect first (for remote development) + if "DATABRICKS_HOST" in os.environ and "DATABRICKS_TOKEN" in os.environ: + try: + # Databricks Connect configuration + logger.debug("Attempting to use Databricks Connect") + builder = SparkSession.builder + builder.remote( + f"sc://{os.environ['DATABRICKS_HOST']}:443/;token={os.environ['DATABRICKS_TOKEN']}" + ) + return builder.getOrCreate() + except Exception as e: + logger.debug(f"Databricks Connect failed, falling back: {e}") - if isinstance(schema, dict): - # Load from file - schema_path = schema.get("filepath") - if not schema_path: - raise DatasetError("Schema dict must have 'filepath'") + # Try Spark Connect (Spark 3.4+) + if spark_remote := os.environ.get("SPARK_REMOTE"): + try: + logger.debug(f"Using Spark Connect: {spark_remote}") + return SparkSession.builder.remote(spark_remote).getOrCreate() + except Exception as e: + logger.debug(f"Spark Connect failed, falling back: {e}") - # Use fsspec to load - import json # noqa: PLC0415 + # Fall back to classic Spark session + logger.debug("Using classic Spark session") + return SparkSession.builder.getOrCreate() - protocol, path = get_protocol_and_path(schema_path) + @staticmethod + def _load_schema_from_file(schema: dict[str, Any] | None) -> StructType | None: + """Load schema from file if provided.""" + if schema is None: + return None - try: - import fsspec # noqa: PLC0415 + if not isinstance(schema, dict): + # Assume it's already a StructType + return schema - fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) - with fs.open(path, "r") as f: - schema_json = json.load(f) + filepath = schema.get("filepath") + if not filepath: + raise DatasetError( + "Schema dict must have 'filepath' attribute. " + "Please provide a path to a JSON-serialised 'pyspark.sql.types.StructType'." + ) - # Lazy import StructType - from pyspark.sql.types import StructType # noqa: PLC0415 + try: + import fsspec # noqa: PLC0415 + from pyspark.sql.types import StructType # noqa: PLC0415 + except ImportError as e: + if "pyspark" in str(e): + raise ImportError("PySpark required to process schema") from e + raise ImportError("fsspec required for schema loading") from e - return StructType.fromJson(schema_json) - except ImportError as e: - if "pyspark" in str(e): - raise ImportError("PySpark required to process schema") from e - raise - except Exception as e: - raise DatasetError(f"Failed to load schema from {schema_path}") from e + protocol, path = get_protocol_and_path(filepath) + fs = fsspec.filesystem(protocol, **schema.get("credentials", {})) - return schema + try: + with fs.open(path, "r") as f: + schema_json = json.load(f) + return StructType.fromJson(schema_json) + except Exception as e: + raise DatasetError( + f"Failed to load schema from {filepath}. " + f"Ensure it contains valid JSON-serialised StructType." 
+ ) from e def load(self) -> DataFrame: """Load data using Spark""" + load_path = self._get_load_path() + spark_load_path = self._to_spark_path(str(load_path)) + spark = self._get_spark() reader = spark.read @@ -258,35 +380,59 @@ def load(self) -> DataFrame: return ( reader.format(self.file_format) .options(**self.load_args) - .load(self._spark_path) + .load(spark_load_path) ) def save(self, data: DataFrame) -> None: """Save data using Spark""" + save_path = self._get_save_path() + spark_save_path = self._to_spark_path(str(save_path)) + + # Prepare writer writer = data.write - if mode := self.save_args.pop("mode", None): + # Apply mode if specified + mode = self.save_args.pop("mode", None) + if mode: writer = writer.mode(mode) - if partition_by := self.save_args.pop("partitionBy", None): + # Apply partitioning if specified + partition_by = self.save_args.pop("partitionBy", None) + if partition_by: writer = writer.partitionBy(partition_by) - writer.format(self.file_format).options(**self.save_args).save(self._spark_path) + # Save with format and options + writer.format(self.file_format).options(**self.save_args).save(spark_save_path) + + # Restore save_args for potential reuse + if mode: + self.save_args["mode"] = mode + if partition_by: + self.save_args["partitionBy"] = partition_by def _exists(self) -> bool: - """Existence check using fsspec""" + """Check existence using Spark read attempt for better accuracy.""" + load_path = self._get_load_path() + spark_load_path = self._to_spark_path(str(load_path)) + try: - return self._fs.exists(self.path) - except Exception: - # Fallback to Spark check for special cases (e.g., Delta tables) - if self.file_format == "delta": - try: - spark = self._get_spark() - spark.read.format("delta").load(self._spark_path) - return True - except Exception: - return False - return False + spark = self._get_spark() + # Try to read the metadata without loading data + spark.read.format(self.file_format).load(spark_load_path).schema + return True + except Exception as e: + # Check for specific error messages indicating non-existence + error_msg = str(e).lower() + if any(msg in error_msg for msg in [ + "path does not exist", + "file not found", + "is not a delta table", + "no such file", + ]): + return False + # Re-raise for unexpected errors + logger.warning(f"Error checking existence of {spark_load_path}: {e}") + raise def _validate_delta_format(self): """Validate Delta-specific configurations""" @@ -306,4 +452,5 @@ def _describe(self) -> dict[str, Any]: "load_args": self.load_args, "save_args": self.save_args, "version": self._version, + "protocol": self.protocol, } From fa13847fc66d1f22898a7c393e989c7b5f43bd75 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Mon, 29 Sep 2025 14:32:35 +0100 Subject: [PATCH 12/17] lint Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset_v2.py | 55 ++++++++++--------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index 66afa63dd..cf90376bf 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -23,6 +23,14 @@ from pyspark.sql import DataFrame, SparkSession from pyspark.sql.types import StructType +from kedro_datasets._utils.databricks_utils import ( + dbfs_exists, + dbfs_glob, + deployed_on_databricks, + get_dbutils, + strip_dbfs_prefix, +) + logger = logging.getLogger(__name__) @@ -119,9 +127,7 @@ def __init__( 
# noqa: PLR0913 self._spark_path = self._to_spark_path(filepath) # Handle schema if provided - self._schema = self._load_schema_from_file( - self.load_args.pop("schema", None) - ) + self._schema = self._load_schema_from_file(self.load_args.pop("schema", None)) super().__init__( filepath=PurePosixPath(self.path), @@ -149,14 +155,17 @@ def _parse_filepath(self, filepath: str) -> tuple[str, str]: def _validate_databricks_path(self, filepath: str) -> None: """Warn about potential Databricks path issues.""" - from kedro_datasets._utils.databricks_utils import deployed_on_databricks - if ( - deployed_on_databricks() - and not (filepath.startswith("/dbfs") - or filepath.startswith("dbfs:/") - or filepath.startswith("/Volumes")) - and not any(filepath.startswith(f"{p}://") for p in ["s3", "s3a", "s3n", "gs", "abfs", "wasbs"]) + deployed_on_databricks() + and not ( + filepath.startswith("/dbfs") + or filepath.startswith("dbfs:/") + or filepath.startswith("/Volumes") + ) + and not any( + filepath.startswith(f"{p}://") + for p in ["s3", "s3a", "s3n", "gs", "abfs", "wasbs"] + ) ): logger.warning( "Using SparkDatasetV2 on Databricks without the `/dbfs/`, `dbfs:/`, or `/Volumes` prefix " @@ -166,13 +175,6 @@ def _validate_databricks_path(self, filepath: str) -> None: def _get_filesystem_ops(self) -> tuple: """Get filesystem operations with DBFS optimization.""" - from kedro_datasets._utils.databricks_utils import ( - dbfs_exists, - dbfs_glob, - deployed_on_databricks, - get_dbutils, - ) - # Special handling for DBFS to avoid performance issues # This addresses the critical performance issue raised by deepyaman if self.protocol == "dbfs" and deployed_on_databricks(): @@ -183,7 +185,7 @@ def _get_filesystem_ops(self) -> tuple: logger.debug("Using optimized DBFS operations via dbutils") return ( partial(dbfs_exists, dbutils=dbutils), - partial(dbfs_glob, dbutils=dbutils) + partial(dbfs_glob, dbutils=dbutils), ) except Exception as e: logger.warning(f"Failed to get dbutils, falling back to fsspec: {e}") @@ -241,8 +243,6 @@ def _get_filesystem(self): def _to_spark_path(self, filepath: str) -> str: """Convert to Spark-compatible path format.""" - from kedro_datasets._utils.databricks_utils import strip_dbfs_prefix - filepath = str(filepath) # Apply DBFS prefix stripping for consistency @@ -423,12 +423,15 @@ def _exists(self) -> bool: except Exception as e: # Check for specific error messages indicating non-existence error_msg = str(e).lower() - if any(msg in error_msg for msg in [ - "path does not exist", - "file not found", - "is not a delta table", - "no such file", - ]): + if any( + msg in error_msg + for msg in [ + "path does not exist", + "file not found", + "is not a delta table", + "no such file", + ] + ): return False # Re-raise for unexpected errors logger.warning(f"Error checking existence of {spark_load_path}: {e}") From fe52b196fb3cd1809e24521c27168cd59aa33317 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 15 Oct 2025 09:11:56 +0100 Subject: [PATCH 13/17] Update __init__.py Signed-off-by: Sajid Alam --- kedro-datasets/kedro_datasets/spark/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/kedro-datasets/kedro_datasets/spark/__init__.py b/kedro-datasets/kedro_datasets/spark/__init__.py index d09791ec9..c85d01406 100644 --- a/kedro-datasets/kedro_datasets/spark/__init__.py +++ b/kedro-datasets/kedro_datasets/spark/__init__.py @@ -19,6 +19,11 @@ except (ImportError, RuntimeError): SparkDataset: Any +try: + from .spark_dataset_v2 import SparkDatasetV2 +except (ImportError, 
RuntimeError): + SparkDatasetV2: Any + try: from .spark_hive_dataset import SparkHiveDataset except (ImportError, RuntimeError): From cc9a747634c3cf09b1daa8a339a4103eeb2b5224 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 15 Oct 2025 10:40:05 +0100 Subject: [PATCH 14/17] fix tests Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset_v2.py | 1 + .../tests/spark/test_spark_dataset_v2.py | 38 +++++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index cf90376bf..bce8083e6 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -258,6 +258,7 @@ def _to_spark_path(self, filepath: str) -> str: # Map to Spark protocols spark_protocols = { "s3": "s3a", # Spark prefers s3a:// + "s3n": "s3a", "gs": "gs", "abfs": "abfs", "wasbs": "wasbs", diff --git a/kedro-datasets/tests/spark/test_spark_dataset_v2.py b/kedro-datasets/tests/spark/test_spark_dataset_v2.py index dc1d6e4f4..1066fbaa2 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset_v2.py +++ b/kedro-datasets/tests/spark/test_spark_dataset_v2.py @@ -246,8 +246,8 @@ def test_local_path(self, tmp_path): filepath = str(tmp_path / "test.parquet") dataset = SparkDatasetV2(filepath=filepath) - assert dataset.protocol == "" - assert dataset._spark_path == filepath + assert dataset.protocol == "file" + assert dataset._spark_path == f"file://{filepath}" def test_s3_path_normalization(self): """Test S3 path normalization to s3a://.""" @@ -272,7 +272,7 @@ def test_dbfs_path_not_on_databricks(self, monkeypatch): # Ensure we're not on Databricks monkeypatch.delenv("DATABRICKS_RUNTIME_VERSION", raising=False) - filepath = "/dbfs/path/to/data.parquet" + filepath = "file:///dbfs/path/to/data.parquet" dataset = SparkDatasetV2(filepath=filepath) assert dataset._spark_path == filepath @@ -284,7 +284,7 @@ def test_other_protocols(self): } for filepath, expected_prefix in protocols.items(): - dataset = SparkDatasetV2(filepath=filepath) + dataset = SparkDatasetV2(filepath=filepath, credentials={"account_name": "dummy"}) assert dataset._spark_path.startswith(expected_prefix) @@ -314,11 +314,11 @@ def test_missing_pyspark_databricks(self, mocker, monkeypatch): monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") dataset = SparkDatasetV2(filepath="test.parquet") - mocker.patch.object( - dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") - ) + import sys + monkeypatch.setitem(sys.modules, 'pyspark', None) + monkeypatch.setitem(sys.modules, 'pyspark.sql', None) - with pytest.raises(ImportError, match="databricks-connect"): + with pytest.raises(DatasetError, match="databricks-connect"): dataset.load() def test_missing_pyspark_emr(self, mocker, monkeypatch): @@ -326,11 +326,11 @@ def test_missing_pyspark_emr(self, mocker, monkeypatch): monkeypatch.setenv("EMR_RELEASE_LABEL", "emr-7.0.0") dataset = SparkDatasetV2(filepath="test.parquet") - mocker.patch.object( - dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") - ) + import sys + monkeypatch.setitem(sys.modules, 'pyspark', None) + monkeypatch.setitem(sys.modules, 'pyspark.sql', None) - with pytest.raises(ImportError, match="should be pre-installed on EMR"): + with pytest.raises(DatasetError, match="pre-installed on EMR"): dataset.load() def test_missing_pyspark_local(self, mocker, monkeypatch): @@ -339,13 +339,11 @@ def 
test_missing_pyspark_local(self, mocker, monkeypatch): monkeypatch.delenv("EMR_RELEASE_LABEL", raising=False) dataset = SparkDatasetV2(filepath="test.parquet") - mocker.patch.object( - dataset, "_get_spark", side_effect=ImportError("No module named 'pyspark'") - ) + import sys + monkeypatch.setitem(sys.modules, 'pyspark', None) + monkeypatch.setitem(sys.modules, 'pyspark.sql', None) - with pytest.raises( - ImportError, match="pip install 'kedro-datasets\\[spark-local\\]'" - ): + with pytest.raises(DatasetError, match="kedro-datasets\\[spark-local\\]"): dataset.load() @@ -411,7 +409,7 @@ def test_version_str_representation(self, tmp_path, version): filepath = str(tmp_path / "test.parquet") dataset = SparkDatasetV2(filepath=filepath, version=version) - assert "version=" in str(dataset._describe()) + assert "version" in str(dataset._describe()) class TestSparkDatasetV2Integration: @@ -426,7 +424,7 @@ def test_parallel_runner_restriction(self, tmp_path, sample_spark_df): catalog = DataCatalog({"spark_data": dataset}) test_pipeline = pipeline([node(lambda x: x, "spark_data", "output")]) - with pytest.raises(AttributeError, match="cannot be used with multiprocessing"): + with pytest.raises(AttributeError, match="validate_catalog"): ParallelRunner().run(test_pipeline, catalog) def test_sequential_runner(self, tmp_path, sample_spark_df): From 9906556abbe28bf6dfdd7c8be992958342b3ce18 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 15 Oct 2025 11:07:39 +0100 Subject: [PATCH 15/17] fix docstring and lint Signed-off-by: Sajid Alam --- .../kedro_datasets/spark/spark_dataset_v2.py | 14 ++++++++------ .../tests/spark/test_spark_dataset_v2.py | 19 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index bce8083e6..c9c077382 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -72,9 +72,9 @@ class SparkDatasetV2(AbstractVersionedDataset): Using the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): - >>> from kedro_datasets.spark import SparkDatasetV2 + >>> import tempfile >>> from pyspark.sql import SparkSession - >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType + >>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType >>> >>> schema = StructType( ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] @@ -82,10 +82,12 @@ class SparkDatasetV2(AbstractVersionedDataset): >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] >>> spark_df = SparkSession.builder.getOrCreate().createDataFrame(data, schema) >>> - >>> dataset = SparkDatasetV2(filepath="tmp_path/test_data") - >>> dataset.save(spark_df) - >>> reloaded = dataset.load() - >>> assert Row(name="Bob", age=12) in reloaded.take(4) + >>> with tempfile.TemporaryDirectory() as tmp_dir: + ... filepath = f"{tmp_dir}/test_data" + ... dataset = SparkDatasetV2(filepath=filepath) + ... dataset.save(spark_df) + ... reloaded = dataset.load() + ... 
assert Row(name="Bob", age=12) in reloaded.take(4) """ diff --git a/kedro-datasets/tests/spark/test_spark_dataset_v2.py b/kedro-datasets/tests/spark/test_spark_dataset_v2.py index 1066fbaa2..76e98fcb7 100644 --- a/kedro-datasets/tests/spark/test_spark_dataset_v2.py +++ b/kedro-datasets/tests/spark/test_spark_dataset_v2.py @@ -284,7 +284,9 @@ def test_other_protocols(self): } for filepath, expected_prefix in protocols.items(): - dataset = SparkDatasetV2(filepath=filepath, credentials={"account_name": "dummy"}) + dataset = SparkDatasetV2( + filepath=filepath, credentials={"account_name": "dummy"} + ) assert dataset._spark_path.startswith(expected_prefix) @@ -314,9 +316,8 @@ def test_missing_pyspark_databricks(self, mocker, monkeypatch): monkeypatch.setenv("DATABRICKS_RUNTIME_VERSION", "14.3") dataset = SparkDatasetV2(filepath="test.parquet") - import sys - monkeypatch.setitem(sys.modules, 'pyspark', None) - monkeypatch.setitem(sys.modules, 'pyspark.sql', None) + monkeypatch.setitem(sys.modules, "pyspark", None) + monkeypatch.setitem(sys.modules, "pyspark.sql", None) with pytest.raises(DatasetError, match="databricks-connect"): dataset.load() @@ -326,9 +327,8 @@ def test_missing_pyspark_emr(self, mocker, monkeypatch): monkeypatch.setenv("EMR_RELEASE_LABEL", "emr-7.0.0") dataset = SparkDatasetV2(filepath="test.parquet") - import sys - monkeypatch.setitem(sys.modules, 'pyspark', None) - monkeypatch.setitem(sys.modules, 'pyspark.sql', None) + monkeypatch.setitem(sys.modules, "pyspark", None) + monkeypatch.setitem(sys.modules, "pyspark.sql", None) with pytest.raises(DatasetError, match="pre-installed on EMR"): dataset.load() @@ -339,9 +339,8 @@ def test_missing_pyspark_local(self, mocker, monkeypatch): monkeypatch.delenv("EMR_RELEASE_LABEL", raising=False) dataset = SparkDatasetV2(filepath="test.parquet") - import sys - monkeypatch.setitem(sys.modules, 'pyspark', None) - monkeypatch.setitem(sys.modules, 'pyspark.sql', None) + monkeypatch.setitem(sys.modules, "pyspark", None) + monkeypatch.setitem(sys.modules, "pyspark.sql", None) with pytest.raises(DatasetError, match="kedro-datasets\\[spark-local\\]"): dataset.load() From 9a7d73f11eec529674aa9abc1c3ccf6e351320c9 Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 15 Oct 2025 11:23:33 +0100 Subject: [PATCH 16/17] Update spark_dataset_v2.py Signed-off-by: Sajid Alam --- kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index c9c077382..4d7b4e3ff 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -73,7 +73,7 @@ class SparkDatasetV2(AbstractVersionedDataset): Using the [Python API](https://docs.kedro.org/en/stable/catalog-data/advanced_data_catalog_usage/): >>> import tempfile - >>> from pyspark.sql import SparkSession + >>> from pyspark.sql import Row, SparkSession >>> from pyspark.sql.types import IntegerType, StringType, StructField, StructType >>> >>> schema = StructType( @@ -88,7 +88,6 @@ class SparkDatasetV2(AbstractVersionedDataset): ... dataset.save(spark_df) ... reloaded = dataset.load() ... 
assert Row(name="Bob", age=12) in reloaded.take(4) - """ # this dataset cannot be used with ``ParallelRunner``, From ad18aa720c45404ef84600beb948d475471d00bc Mon Sep 17 00:00:00 2001 From: Sajid Alam Date: Wed, 15 Oct 2025 16:49:12 +0100 Subject: [PATCH 17/17] Update spark_dataset_v2.py Signed-off-by: Sajid Alam --- kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py index 4d7b4e3ff..713db0281 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py +++ b/kedro-datasets/kedro_datasets/spark/spark_dataset_v2.py @@ -176,8 +176,6 @@ def _validate_databricks_path(self, filepath: str) -> None: def _get_filesystem_ops(self) -> tuple: """Get filesystem operations with DBFS optimization.""" - # Special handling for DBFS to avoid performance issues - # This addresses the critical performance issue raised by deepyaman if self.protocol == "dbfs" and deployed_on_databricks(): try: spark = self._get_spark()
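For readers following the path handling exercised by the tests in this series, here is a minimal, illustrative sketch of the protocol normalisation that `_to_spark_path` performs (the mapping values come from the hunk above that adds `"s3n": "s3a"`; the standalone helper name and its exact signature are assumptions for this example, not the shipped implementation):

    # Sketch only: mirrors the protocol mapping in SparkDatasetV2._to_spark_path,
    # where fsspec-style schemes are rewritten to the schemes Spark expects.
    from urllib.parse import urlsplit

    _SPARK_PROTOCOLS = {
        "s3": "s3a",   # Spark prefers s3a:// over s3:// or s3n://
        "s3n": "s3a",
        "gs": "gs",
        "abfs": "abfs",
        "wasbs": "wasbs",
    }

    def to_spark_path(filepath: str) -> str:
        """Rewrite a URI so its scheme matches what Spark expects."""
        parts = urlsplit(filepath)
        if not parts.scheme:
            # bare local paths are handled elsewhere (file:// prefixing)
            return filepath
        scheme = _SPARK_PROTOCOLS.get(parts.scheme, parts.scheme)
        return f"{scheme}://{parts.netloc}{parts.path}"

    assert to_spark_path("s3://bucket/data.parquet") == "s3a://bucket/data.parquet"
    assert to_spark_path("gs://bucket/data.parquet") == "gs://bucket/data.parquet"

This matches the behaviour asserted in `test_s3_path_normalization` and the `"Spark prefers s3a://"` comment in the diff; running it locally still requires the `spark-local` extra referenced in the import-error tests above.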