diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c3f1f42 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +*.egg-info +*.DS_Store +.*swp +__pycache__ +.pytest_cache \ No newline at end of file diff --git a/brainio/assemblies.py b/brainio/assemblies.py index 13a19b4..a58d647 100644 --- a/brainio/assemblies.py +++ b/brainio/assemblies.py @@ -431,14 +431,36 @@ def get_levels(assembly): def gather_indexes(assembly): - """This is only necessary as long as xarray cannot persist MultiIndex to netCDF. """ + """This is only necessary as long as xarray cannot persist MultiIndex to netCDF.""" coords_d = {} for dim in assembly.dims: - coord_names = list(get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False)) + coord_names = list( + get_metadata( + assembly, + dims=(dim,), + names_only=True, + include_indexes=False, + include_levels=False, + ) + ) if coord_names: coords_d[dim] = coord_names + + # fix single-coord-single-dim + to_stack = {} + for dim, coords in coords_d.items(): + if len(coords) == 1 and len(list(get_metadata(assembly, dims=(dim,)))) == 1: + to_stack[dim] = coords[0] + if coords_d: assembly = assembly.set_index(append=True, **coords_d) + + if to_stack: + # single-coord stacking trick + for dim, coord in to_stack.items(): + assembly = assembly.rename({dim: coord}) + assembly = assembly.stack({dim: [coord]}) + return assembly diff --git a/tests/test_assemblies.py b/tests/test_assemblies.py index eaeb626..b5215eb 100644 --- a/tests/test_assemblies.py +++ b/tests/test_assemblies.py @@ -67,7 +67,7 @@ def test_get_levels(): }, dims=['a', 'b'] ) - assert get_levels(assy) == ["up", "down"] + assert get_levels(assy) == ["up", "down", "sideways"] class TestSubclassing: @@ -115,7 +115,7 @@ def test_reset_index(self): dims=['a', 'b'] ) da = DataArray(assy) - da = da.reset_index(["up", "down"]) + da = da.reset_index(["up", "down", "sideways"]) assert get_levels(da) == [] def test_repr(self): diff --git a/tests/test_packaging.py b/tests/test_packaging.py index cd78a24..0433d76 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -61,6 +61,8 @@ def test_reset_index_levels(): ) assert assy["a"].variable.level_names == ["up", "down"] assy = assy.reset_index(["up", "down"]) + assert get_levels(assy) == ["sideways"] + assy = assy.reset_index(["sideways"]) assert get_levels(assy) == []