xarray/backends/api.py (161 additions, 0 deletions)
@@ -36,6 +36,7 @@
    T_PathFileOrDataStore,
    _find_absolute_paths,
    _normalize_path,
    datatree_from_dict_with_io_cleanup,
)
from xarray.backends.locks import get_dask_scheduler
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
@@ -538,6 +539,35 @@ def _datatree_from_backend_datatree(
    return tree


async def _maybe_create_default_indexes_async(ds):
    import asyncio

    # Determine which 1-D coordinates need default indexes: a coordinate
    # qualifies if its only dimension shares its name and it is not indexed yet
    to_index_names = [
        name
        for name, coord in ds.coords.items()
        if coord.dims == (name,) and name not in ds.xindexes
    ]
    if not to_index_names:
        return ds

    async def load_var(var):
        # Prefer the backend's native async loading; fall back to running the
        # synchronous load in a thread when it is not implemented
        try:
            return await var.load_async()
        except NotImplementedError:
            return await asyncio.to_thread(var.load)

    await asyncio.gather(
        *(load_var(ds.coords[name].variable) for name in to_index_names)
    )

    # Build the indexes; the data is now in memory, so no remote I/O is issued
    # per coordinate
    to_index = {name: ds.coords[name].variable for name in to_index_names}
    return ds.assign_coords(Coordinates(to_index))

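For illustration, a minimal sketch of what this helper does, assuming a recent xarray whose ``Variable`` exposes ``load_async`` and whose ``Coordinates`` constructor accepts ``indexes={}``; the dataset and values are hypothetical. A 1-D coordinate whose name matches its dimension but carries no index is loaded and given a default index:

import asyncio

import numpy as np
import xarray as xr

from xarray.backends.api import _maybe_create_default_indexes_async  # helper above

# Hypothetical dataset: "x" is a 1-D dimension coordinate created without an index
coords = xr.Coordinates({"x": xr.Variable("x", np.arange(4))}, indexes={})
ds = xr.Dataset({"temp": ("x", np.arange(4.0))}, coords=coords)
assert "x" not in ds.xindexes

ds = asyncio.run(_maybe_create_default_indexes_async(ds))
assert "x" in ds.xindexes  # a default (pandas) index now backs "x"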

def open_dataset(
    filename_or_obj: T_PathFileOrDataStore,
    *,
@@ -1253,6 +1283,137 @@ def open_datatree(
    return tree


async def open_datatree_async(
    filename_or_obj: T_PathFileOrDataStore,
    *,
    engine: T_Engine = None,
    chunks: T_Chunks = None,
    cache: bool | None = None,
    decode_cf: bool | None = None,
    mask_and_scale: bool | Mapping[str, bool] | None = None,
    decode_times: bool
    | CFDatetimeCoder
    | Mapping[str, bool | CFDatetimeCoder]
    | None = None,
    decode_timedelta: bool
    | CFTimedeltaCoder
    | Mapping[str, bool | CFTimedeltaCoder]
    | None = None,
    use_cftime: bool | Mapping[str, bool] | None = None,
    concat_characters: bool | Mapping[str, bool] | None = None,
    decode_coords: Literal["coordinates", "all"] | bool | None = None,
    drop_variables: str | Iterable[str] | None = None,
    create_default_indexes: bool = True,
    inline_array: bool = False,
    chunked_array_type: str | None = None,
    from_array_kwargs: dict[str, Any] | None = None,
    backend_kwargs: dict[str, Any] | None = None,
    **kwargs,
) -> DataTree:
    """Async version of ``open_datatree`` that opens groups and creates
    default indexes concurrently.

    Only the "zarr" engine (both Zarr v2 and v3) is supported; any other
    engine raises ValueError.
    """
    import asyncio

    if cache is None:
        cache = chunks is None

    if backend_kwargs is not None:
        kwargs.update(backend_kwargs)

    if engine is None:
        engine = plugins.guess_engine(filename_or_obj)

    if from_array_kwargs is None:
        from_array_kwargs = {}

    # Only zarr supports async lazy loading at present
    if engine != "zarr":
        raise ValueError(
            f"Engine {engine!r} does not support asynchronous operations"
        )

    backend = plugins.get_backend(engine)

    decoders = _resolve_decoders_kwargs(
        decode_cf,
        open_backend_dataset_parameters=backend.open_dataset_parameters,
        mask_and_scale=mask_and_scale,
        decode_times=decode_times,
        decode_timedelta=decode_timedelta,
        concat_characters=concat_characters,
        use_cftime=use_cftime,
        decode_coords=decode_coords,
    )

    overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

    # Prefer the backend's async group opening if available (currently zarr only)
    if hasattr(backend, "open_groups_as_dict_async"):
        groups_dict = await backend.open_groups_as_dict_async(
            filename_or_obj,
            drop_variables=drop_variables,
            **decoders,
            **kwargs,
        )
        backend_tree = datatree_from_dict_with_io_cleanup(groups_dict)
    else:
        backend_tree = backend.open_datatree(
            filename_or_obj,
            drop_variables=drop_variables,
            **decoders,
            **kwargs,
        )

    # Protect variables so caching behaves consistently with the sync path
    _protect_datatree_variables_inplace(backend_tree, cache)

    # For each dataset in the tree, concurrently create default indexes (if requested)
    results: dict[str, Dataset] = {}

    async def process_node(path: str, node_ds: Dataset) -> tuple[str, Dataset]:
        ds = node_ds
        if create_default_indexes:
            ds = await _maybe_create_default_indexes_async(ds)
        # Optional chunking (synchronous)
        if chunks is not None:
            ds = _chunk_ds(
                ds,
                filename_or_obj,
                engine,
                chunks,
                overwrite_encoded_chunks,
                inline_array,
                chunked_array_type,
                from_array_kwargs,
                node=path,
                **decoders,
                **kwargs,
            )
        return path, ds

    # Build one task per group node
    tasks = [
        process_node(path, node.dataset)
        for path, [node] in group_subtrees(backend_tree)
    ]

    # Execute concurrently and collect results as they complete
    for fut in asyncio.as_completed(tasks):
        path, ds = await fut
        results[path] = ds

    # Build the DataTree from the processed datasets
    tree = DataTree.from_dict(results)

    # Carry over close handlers from the backend tree when the datasets were
    # re-created above (mirrors the sync path)
    if create_default_indexes or chunks is not None:
        for _path, [node] in group_subtrees(backend_tree):
            tree[_path].set_close(node._close)

    return tree

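Because ``open_datatree_async`` is a coroutine, it must be driven by an event loop. A usage sketch (the store URL is a hypothetical placeholder; ``consolidated`` is forwarded to the zarr backend via ``**kwargs``):

import asyncio

from xarray.backends.api import open_datatree_async

async def main():
    # Groups are opened and default indexes are built concurrently
    tree = await open_datatree_async(
        "s3://my-bucket/store.zarr",  # hypothetical remote Zarr store
        engine="zarr",
        consolidated=True,
    )
    print(tree)

asyncio.run(main())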

def open_groups(
    filename_or_obj: T_PathFileOrDataStore,
    *,
xarray/backends/zarr.py (75 additions, 0 deletions)
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
@@ -1785,6 +1786,80 @@ def open_groups_as_dict(
            groups_dict[group_name] = group_ds
        return groups_dict

    async def open_groups_as_dict_async(
        self,
        filename_or_obj: T_PathFileOrDataStore,
        *,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables: str | Iterable[str] | None = None,
        use_cftime=None,
        decode_timedelta=None,
        group: str | None = None,
        mode="r",
        synchronizer=None,
        consolidated=None,
        chunk_store=None,
        storage_options=None,
        zarr_version=None,
        zarr_format=None,
    ) -> dict[str, Dataset]:
        """Asynchronously open each group into a Dataset, concurrently.

        This mirrors ``open_groups_as_dict`` but opens the per-group Datasets
        concurrently (each in a worker thread), which can significantly reduce
        latency on high-RTT object stores.
        """
        filename_or_obj = _normalize_path(filename_or_obj)

        # Determine the parent group path context
        if group:
            parent = str(NodePath("/") / NodePath(group))
        else:
            parent = str(NodePath("/"))

        # Discover the group stores (a synchronous metadata step)
        stores = ZarrStore.open_store(
            filename_or_obj,
            group=parent,
            mode=mode,
            synchronizer=synchronizer,
            consolidated=consolidated,
            consolidate_on_close=False,
            chunk_store=chunk_store,
            storage_options=storage_options,
            zarr_version=zarr_version,
            zarr_format=zarr_format,
        )

        async def open_one(path_group: str, store) -> tuple[str, Dataset]:
            store_entrypoint = StoreBackendEntrypoint()

            def _load_sync():
                # Decode one group's store into a Dataset; close it on failure
                with close_on_error(store):
                    return store_entrypoint.open_dataset(
                        store,
                        mask_and_scale=mask_and_scale,
                        decode_times=decode_times,
                        concat_characters=concat_characters,
                        decode_coords=decode_coords,
                        drop_variables=drop_variables,
                        use_cftime=use_cftime,
                        decode_timedelta=decode_timedelta,
                    )

            # Run the synchronous decode in a worker thread so groups overlap
            ds = await asyncio.to_thread(_load_sync)
            if group:
                group_name = str(NodePath(path_group).relative_to(parent))
            else:
                group_name = str(NodePath(path_group))
            return group_name, ds

        tasks = [open_one(path_group, store) for path_group, store in stores.items()]
        results = await asyncio.gather(*tasks)
        return dict(results)

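A sketch of exercising this method directly through the backend entrypoint, e.g. from a test (the store path is hypothetical; extra keywords follow the signature above):

import asyncio

from xarray.backends.zarr import ZarrBackendEntrypoint

async def main():
    groups = await ZarrBackendEntrypoint().open_groups_as_dict_async(
        "s3://my-bucket/store.zarr",  # hypothetical remote Zarr store
        consolidated=True,
    )
    for name, ds in groups.items():
        print(name, list(ds.data_vars))

asyncio.run(main())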

def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
    parent_nodepath = NodePath(parent)