Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
*.pyc
*.egg-info
*.DS_Store
.*swp
__pycache__
.pytest_cache
26 changes: 24 additions & 2 deletions brainio/assemblies.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,14 +431,36 @@ def get_levels(assembly):


def gather_indexes(assembly):
"""This is only necessary as long as xarray cannot persist MultiIndex to netCDF. """
"""This is only necessary as long as xarray cannot persist MultiIndex to netCDF."""
coords_d = {}
for dim in assembly.dims:
coord_names = list(get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False))
coord_names = list(
get_metadata(
assembly,
dims=(dim,),
names_only=True,
include_indexes=False,
include_levels=False,
)
)
if coord_names:
coords_d[dim] = coord_names

# fix single-coord-single-dim
to_stack = {}
for dim, coords in coords_d.items():
if len(coords) == 1 and len(list(get_metadata(assembly, dims=(dim,)))) == 1:
to_stack[dim] = coords[0]

if coords_d:
assembly = assembly.set_index(append=True, **coords_d)

if to_stack:
# single-coord stacking trick
for dim, coord in to_stack.items():
assembly = assembly.rename({dim: coord})
assembly = assembly.stack({dim: [coord]})

return assembly


Expand Down
4 changes: 2 additions & 2 deletions tests/test_assemblies.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_get_levels():
},
dims=['a', 'b']
)
assert get_levels(assy) == ["up", "down"]
assert get_levels(assy) == ["up", "down", "sideways"]


class TestSubclassing:
Expand Down Expand Up @@ -115,7 +115,7 @@ def test_reset_index(self):
dims=['a', 'b']
)
da = DataArray(assy)
da = da.reset_index(["up", "down"])
da = da.reset_index(["up", "down", "sideways"])
assert get_levels(da) == []

def test_repr(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_reset_index_levels():
)
assert assy["a"].variable.level_names == ["up", "down"]
assy = assy.reset_index(["up", "down"])
assert get_levels(assy) == ["sideways"]
assy = assy.reset_index(["sideways"])
assert get_levels(assy) == []


Expand Down