
Rechunk derived #6516

Open: wants to merge 7 commits into base: main
lib/iris/aux_factory.py: 143 changes (116 additions, 27 deletions)
@@ -11,7 +11,7 @@
import dask.array as da
import numpy as np

from iris._lazy_data import concatenate
from iris._lazy_data import _optimum_chunksize, concatenate, is_lazy_data
from iris.common import CFVariableMixin, CoordMetadata, metadata_manager_factory
import iris.coords
from iris.warnings import IrisIgnoringBoundsWarning
@@ -76,6 +76,93 @@ def dependencies(self):

"""

@abstractmethod
def _calculate_array(self, *dep_arrays, **other_args):
"""Make a coordinate array from a complete set of dependency arrays.

Parameters
----------
* dep_arrays : tuple of array-like
Arrays of data for each dependency.
Must match the number of declared dependencies, in the standard order.
All are aligned with the leading result dimensions, but may have fewer
than the full number of dimensions. They can be lazy or real data.

* other_args
Dict of keys providing class-specific additional arguments.

Returns
-------
array-like
The lazy result array.

This is the basic derived calculation, defined by each hybrid class, which
defines how the dependency values are combined to make the derived result.
"""
pass

def _derive_array(self, *dep_arrays, **other_args):
"""Build an array of coordinate values.

Call arguments as for :meth:`_calculate_array`.

This routine calls :meth:`_calculate_array` to construct a derived result array.

It then checks the chunk size of the result and, if this exceeds the current
Dask chunksize, it will then re-chunk some of the input arrays and re-calculate
the result to reduce the memory cost.

This routine is itself usually called once by :meth:`make_coord`, to make a
points array, and then again to make the bounds.
"""
# Make an initial result calculation.
# First make all dependencies lazy, to ensure a lazy calculation and avoid
# potentially spending a lot of time and memory.
lazy_deps = [
# Note: no attempt to make clever chunking choices here. If needed, that
# gets fixed later. Plus, single chunks keep the graph overhead small.
dep if is_lazy_data(dep) else da.from_array(dep, chunks=-1)
for dep in dep_arrays
]
result = self._calculate_array(*lazy_deps, **other_args)

# Now check if we need to improve on the chunking of the result.
adjusted_chunks = _optimum_chunksize(
chunks=result.chunksize,
shape=result.shape,
dtype=result.dtype,
)

# Does optimum_chunksize say we should have smaller chunks in some dimensions?
if not all(a >= b for a, b in zip(adjusted_chunks, result.chunksize)):
# Re-do the result calculation, but first re-chunk each dependency in the
# dimensions where a reduction is suggested.
new_deps = []
for dep, original_dep in zip(lazy_deps, dep_arrays):
# For each dependency, reduce chunksize in each dim to the new result
# chunksize, if smaller.
dep_chunks = dep.chunksize
new_chunks = tuple(
min(dep_chunk, adj_chunk)
for dep_chunk, adj_chunk in zip(dep_chunks, adjusted_chunks)
)
if new_chunks != dep_chunks:
# When dep chunksize needs to change, produce a rechunked version.
if is_lazy_data(original_dep):
dep = original_dep.rechunk(new_chunks)
else:
# Make new lazy array from real original, rather than re-chunk.
dep = da.from_array(original_dep, chunks=new_chunks)
new_deps.append(dep)

# Finally, re-do the calculation, which should now give a better overall
# chunksize for the result.
result = self._calculate_array(*new_deps, **other_args)

return result

@abstractmethod
def make_coord(self, coord_dims_func):
"""Return a new :class:`iris.coords.AuxCoord` as defined by this factory.
@@ -463,7 +550,7 @@ def dependencies(self):
return dependencies

@staticmethod
def _derive(pressure_at_top, sigma, surface_air_pressure):
def _calculate_array(pressure_at_top, sigma, surface_air_pressure):
"""Derive coordinate."""
return pressure_at_top + sigma * (surface_air_pressure - pressure_at_top)

@@ -485,7 +572,7 @@ def make_coord(self, coord_dims_func):

# Build the points array
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["pressure_at_top"],
nd_points_by_key["sigma"],
nd_points_by_key["surface_air_pressure"],
@@ -519,7 +606,7 @@ def make_coord(self, coord_dims_func):
surface_air_pressure_pts = nd_points_by_key["surface_air_pressure"]
bds_shape = list(surface_air_pressure_pts.shape) + [1]
surface_air_pressure = surface_air_pressure_pts.reshape(bds_shape)
bounds = self._derive(pressure_at_top, sigma, surface_air_pressure)
bounds = self._derive_array(pressure_at_top, sigma, surface_air_pressure)

# Create coordinate
return iris.coords.AuxCoord(
@@ -608,7 +695,7 @@ def dependencies(self):
"orography": self.orography,
}

def _derive(self, delta, sigma, orography):
def _calculate_array(self, delta, sigma, orography):
return delta + sigma * orography

def make_coord(self, coord_dims_func):
@@ -629,7 +716,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["delta"],
nd_points_by_key["sigma"],
nd_points_by_key["orography"],
@@ -657,7 +744,7 @@ def make_coord(self, coord_dims_func):
bds_shape = list(orography_pts.shape) + [1]
orography = orography_pts.reshape(bds_shape)

bounds = self._derive(delta, sigma, orography)
bounds = self._derive_array(delta, sigma, orography)

hybrid_height = iris.coords.AuxCoord(
points,
@@ -814,7 +901,7 @@ def dependencies(self):
"surface_air_pressure": self.surface_air_pressure,
}

def _derive(self, delta, sigma, surface_air_pressure):
def _calculate_array(self, delta, sigma, surface_air_pressure):
return delta + sigma * surface_air_pressure

def make_coord(self, coord_dims_func):
Expand All @@ -835,7 +922,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["delta"],
nd_points_by_key["sigma"],
nd_points_by_key["surface_air_pressure"],
@@ -863,7 +950,7 @@ def make_coord(self, coord_dims_func):
bds_shape = list(surface_air_pressure_pts.shape) + [1]
surface_air_pressure = surface_air_pressure_pts.reshape(bds_shape)

bounds = self._derive(delta, sigma, surface_air_pressure)
bounds = self._derive_array(delta, sigma, surface_air_pressure)

hybrid_pressure = iris.coords.AuxCoord(
points,
@@ -1022,7 +1109,9 @@ def dependencies(self):
zlev=self.zlev,
)

def _derive(self, sigma, eta, depth, depth_c, zlev, nsigma, coord_dims_func):
def _calculate_array(
self, sigma, eta, depth, depth_c, zlev, nsigma, coord_dims_func
):
# Calculate the index of the 'z' dimension in the input arrays.
# First find the cube 'z' dimension ...
[cube_z_dim] = coord_dims_func(self.dependencies["zlev"])
@@ -1097,14 +1186,14 @@ def make_coord(self, coord_dims_func):
nd_points_by_key = self._remap(dependency_dims, derived_dims)

[nsigma] = nd_points_by_key["nsigma"]
points = self._derive(
points = self._derive_array(
nd_points_by_key["sigma"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
nd_points_by_key["depth_c"],
nd_points_by_key["zlev"],
nsigma,
coord_dims_func,
coord_dims_func=coord_dims_func,
)

bounds = None
@@ -1131,14 +1220,14 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["sigma"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
nd_values_by_key["depth_c"],
nd_values_by_key["zlev"],
nsigma,
coord_dims_func,
coord_dims_func=coord_dims_func,
)

coord = iris.coords.AuxCoord(
@@ -1238,7 +1327,7 @@ def dependencies(self):
"""
return dict(sigma=self.sigma, eta=self.eta, depth=self.depth)

def _derive(self, sigma, eta, depth):
def _calculate_array(self, sigma, eta, depth):
return eta + sigma * (depth + eta)

def make_coord(self, coord_dims_func):
@@ -1257,7 +1346,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["sigma"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
@@ -1287,7 +1376,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["sigma"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
@@ -1419,7 +1508,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, c, eta, depth, depth_c):
def _calculate_array(self, s, c, eta, depth, depth_c):
S = depth_c * s + (depth - depth_c) * c
return S + eta * (1 + S / depth)

@@ -1439,7 +1528,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["c"],
nd_points_by_key["eta"],
@@ -1471,7 +1560,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["c"],
nd_values_by_key["eta"],
@@ -1608,7 +1697,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, eta, depth, a, b, depth_c):
def _calculate_array(self, s, eta, depth, a, b, depth_c):
c = (1 - b) * da.sinh(a * s) / da.sinh(a) + b * (
da.tanh(a * (s + 0.5)) / (2 * da.tanh(0.5 * a)) - 0.5
)
@@ -1630,7 +1719,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["eta"],
nd_points_by_key["depth"],
@@ -1663,7 +1752,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["eta"],
nd_values_by_key["depth"],
@@ -1799,7 +1888,7 @@ def dependencies(self):
depth_c=self.depth_c,
)

def _derive(self, s, c, eta, depth, depth_c):
def _calculate_array(self, s, c, eta, depth, depth_c):
S = (depth_c * s + depth * c) / (depth_c + depth)
return eta + (eta + depth) * S

@@ -1819,7 +1908,7 @@ def make_coord(self, coord_dims_func):

# Build the points array.
nd_points_by_key = self._remap(dependency_dims, derived_dims)
points = self._derive(
points = self._derive_array(
nd_points_by_key["s"],
nd_points_by_key["c"],
nd_points_by_key["eta"],
@@ -1851,7 +1940,7 @@ def make_coord(self, coord_dims_func):
bounds = nd_points_by_key[key].reshape(bds_shape)
nd_values_by_key[key] = bounds

bounds = self._derive(
bounds = self._derive_array(
nd_values_by_key["s"],
nd_values_by_key["c"],
nd_values_by_key["eta"],
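Note (illustration only, not part of the diff): the rechunking step added in _derive_array can be demonstrated with a small standalone sketch using just dask and numpy. Here, calculate stands in for a hybrid class's _calculate_array, and optimum_chunks is a hand-picked stand-in for the value that the private iris._lazy_data._optimum_chunksize helper would suggest.

import dask.array as da
import numpy as np


def calculate(delta, sigma, orography):
    # Same form of calculation as HybridHeightFactory._calculate_array.
    return delta + sigma * orography


# Dependencies aligned with the leading dimensions of a (70, 1000, 1000) result.
delta = da.from_array(np.linspace(0.0, 1000.0, 70).reshape(70, 1, 1), chunks=-1)
sigma = da.from_array(np.linspace(1.0, 0.0, 70).reshape(70, 1, 1), chunks=-1)
orography = da.from_array(np.random.rand(1, 1000, 1000), chunks=-1)

result = calculate(delta, sigma, orography)
print(result.chunksize)  # (70, 1000, 1000) -- a single ~560 MB chunk

# Suppose the optimum chunking asks for smaller slabs in the first dimension.
optimum_chunks = (10, 1000, 1000)

if not all(a >= b for a, b in zip(optimum_chunks, result.chunksize)):
    # Re-chunk each dependency down to the per-dimension minimum and redo the
    # calculation, mirroring the logic of _derive_array above.
    new_deps = []
    for dep in (delta, sigma, orography):
        new_chunks = tuple(
            min(d, a) for d, a in zip(dep.chunksize, optimum_chunks)
        )
        new_deps.append(dep if new_chunks == dep.chunksize else dep.rechunk(new_chunks))
    result = calculate(*new_deps)

print(result.chunksize)  # (10, 1000, 1000) -- ~80 MB chunks

Because the dependencies are re-chunked before the calculation is repeated, the derivation stays lazy throughout and the large result is never held as a single block in memory.
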
lib/iris/tests/unit/aux_factory/test_AtmosphereSigmaFactory.py: 18 changes (9 additions, 9 deletions)
@@ -135,21 +135,21 @@ def test_values(self, sample_kwargs):

class Test__derive:
def test_function_scalar(self):
assert AtmosphereSigmaFactory._derive(0, 0, 0) == 0
assert AtmosphereSigmaFactory._derive(3, 0, 0) == 3
assert AtmosphereSigmaFactory._derive(0, 5, 0) == 0
assert AtmosphereSigmaFactory._derive(0, 0, 7) == 0
assert AtmosphereSigmaFactory._derive(3, 5, 0) == -12
assert AtmosphereSigmaFactory._derive(3, 0, 7) == 3
assert AtmosphereSigmaFactory._derive(0, 5, 7) == 35
assert AtmosphereSigmaFactory._derive(3, 5, 7) == 23
assert AtmosphereSigmaFactory._calculate_array(0, 0, 0) == 0
assert AtmosphereSigmaFactory._calculate_array(3, 0, 0) == 3
assert AtmosphereSigmaFactory._calculate_array(0, 5, 0) == 0
assert AtmosphereSigmaFactory._calculate_array(0, 0, 7) == 0
assert AtmosphereSigmaFactory._calculate_array(3, 5, 0) == -12
assert AtmosphereSigmaFactory._calculate_array(3, 0, 7) == 3
assert AtmosphereSigmaFactory._calculate_array(0, 5, 7) == 35
assert AtmosphereSigmaFactory._calculate_array(3, 5, 7) == 23

def test_function_array(self):
ptop = 3
sigma = np.array([2, 4])
ps = np.arange(4).reshape(2, 2)
np.testing.assert_equal(
AtmosphereSigmaFactory._derive(ptop, sigma, ps),
AtmosphereSigmaFactory._calculate_array(ptop, sigma, ps),
[[-3, -5], [1, 3]],
)

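A possible follow-up test, not included in these commits, would exercise the same static method with lazy inputs, checking that the calculation stays lazy and computes to the same values as the real-data case above:

import dask.array as da
import numpy as np

from iris._lazy_data import is_lazy_data
from iris.aux_factory import AtmosphereSigmaFactory


def test_function_lazy_array():
    ptop = 3
    sigma = da.from_array(np.array([2, 4]), chunks=1)
    ps = da.from_array(np.arange(4).reshape(2, 2), chunks=(1, 2))
    result = AtmosphereSigmaFactory._calculate_array(ptop, sigma, ps)
    # Lazy inputs should give a lazy result ...
    assert is_lazy_data(result)
    # ... which computes to the same values as the real-data test above.
    np.testing.assert_equal(result.compute(), [[-3, -5], [1, 3]])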