Skip to content

Commit 208e70d

Browse files
committed
Merge branch 'master' of github.com:keras-team/keras
2 parents c03e7b0 + f5d3087 commit 208e70d

File tree

2 files changed: +116 −365 lines changed

Diff for: keras/src/utils/dataset_utils.py

+55-64
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88

9+
from keras.src import tree
910
from keras.src.api_export import keras_export
1011
from keras.src.utils import io_utils
1112
from keras.src.utils.module_utils import tensorflow as tf
@@ -137,16 +138,7 @@ def _convert_dataset_to_list(
137138
data_size_warning_flag,
138139
start_time,
139140
):
140-
if dataset_type_spec in [tuple, list]:
141-
# The try-except here is for NumPy 1.24 compatibility, see:
142-
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
143-
try:
144-
arr = np.array(sample)
145-
except ValueError:
146-
arr = np.array(sample, dtype=object)
147-
dataset_as_list.append(arr)
148-
else:
149-
dataset_as_list.append(sample)
141+
dataset_as_list.append(sample)
150142

151143
return dataset_as_list
152144

@@ -169,23 +161,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
169161
"Please provide a non-empty list of arrays."
170162
)
171163

172-
if _get_type_spec(dataset[0]) is np.ndarray:
173-
expected_shape = dataset[0].shape
174-
for i, element in enumerate(dataset):
175-
if np.array(element).shape[0] != expected_shape[0]:
176-
raise ValueError(
177-
"Received a list of NumPy arrays with different "
178-
f"lengths. Mismatch found at index {i}, "
179-
f"Expected shape={expected_shape} "
180-
f"Received shape={np.array(element).shape}."
181-
"Please provide a list of NumPy arrays with "
182-
"the same length."
183-
)
184-
else:
185-
raise ValueError(
186-
"Expected a list of `numpy.ndarray` objects,"
187-
f"Received: {type(dataset[0])}"
188-
)
164+
expected_shape = None
165+
for i, element in enumerate(dataset):
166+
if not isinstance(element, np.ndarray):
167+
raise ValueError(
168+
"Expected a list of `numpy.ndarray` objects,"
169+
f"Received: {type(element)} at index {i}."
170+
)
171+
if expected_shape is None:
172+
expected_shape = element.shape
173+
elif element.shape[0] != expected_shape[0]:
174+
raise ValueError(
175+
"Received a list of NumPy arrays with different lengths."
176+
f"Mismatch found at index {i}, "
177+
f"Expected shape={expected_shape} "
178+
f"Received shape={np.array(element).shape}."
179+
"Please provide a list of NumPy arrays of the same length."
180+
)
189181

190182
return iter(zip(*dataset))
191183
elif dataset_type_spec == tuple:
@@ -195,23 +187,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
195187
"Please provide a non-empty tuple of arrays."
196188
)
197189

198-
if _get_type_spec(dataset[0]) is np.ndarray:
199-
expected_shape = dataset[0].shape
200-
for i, element in enumerate(dataset):
201-
if np.array(element).shape[0] != expected_shape[0]:
202-
raise ValueError(
203-
"Received a tuple of NumPy arrays with different "
204-
f"lengths. Mismatch found at index {i}, "
205-
f"Expected shape={expected_shape} "
206-
f"Received shape={np.array(element).shape}."
207-
"Please provide a tuple of NumPy arrays with "
208-
"the same length."
209-
)
210-
else:
211-
raise ValueError(
212-
"Expected a tuple of `numpy.ndarray` objects, "
213-
f"Received: {type(dataset[0])}"
214-
)
190+
expected_shape = None
191+
for i, element in enumerate(dataset):
192+
if not isinstance(element, np.ndarray):
193+
raise ValueError(
194+
"Expected a tuple of `numpy.ndarray` objects,"
195+
f"Received: {type(element)} at index {i}."
196+
)
197+
if expected_shape is None:
198+
expected_shape = element.shape
199+
elif element.shape[0] != expected_shape[0]:
200+
raise ValueError(
201+
"Received a tuple of NumPy arrays with different lengths."
202+
f"Mismatch found at index {i}, "
203+
f"Expected shape={expected_shape} "
204+
f"Received shape={np.array(element).shape}."
205+
"Please provide a tuple of NumPy arrays of the same length."
206+
)
215207

216208
return iter(zip(*dataset))
217209
elif dataset_type_spec == tf.data.Dataset:
@@ -436,23 +428,24 @@ def _restore_dataset_from_list(
436428
dataset_as_list, dataset_type_spec, original_dataset
437429
):
438430
"""Restore the dataset from the list of arrays."""
439-
if dataset_type_spec in [tuple, list]:
440-
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
441-
elif dataset_type_spec == tf.data.Dataset:
442-
if isinstance(original_dataset.element_spec, dict):
443-
restored_dataset = {}
444-
for d in dataset_as_list:
445-
for k, v in d.items():
446-
if k not in restored_dataset:
447-
restored_dataset[k] = [v]
448-
else:
449-
restored_dataset[k].append(v)
450-
return restored_dataset
451-
else:
452-
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
431+
if dataset_type_spec in [tuple, list, tf.data.Dataset] or is_torch_dataset(
432+
original_dataset
433+
):
434+
# Save structure by taking the first element.
435+
element_spec = dataset_as_list[0]
436+
# Flatten each element.
437+
dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list]
438+
# Combine respective elements at all indices.
439+
dataset_as_list = [np.array(sample) for sample in zip(*dataset_as_list)]
440+
# Recreate the original structure of elements.
441+
dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list)
442+
# Turn lists to tuples as tf.data will fail on lists.
443+
return tree.traverse(
444+
lambda x: tuple(x) if isinstance(x, list) else x,
445+
dataset_as_list,
446+
top_down=False,
447+
)
453448

454-
elif is_torch_dataset(original_dataset):
455-
return tuple(np.array(sample) for sample in zip(*dataset_as_list))
456449
return dataset_as_list
457450

458451

@@ -477,14 +470,12 @@ def _get_type_spec(dataset):
477470
return list
478471
elif isinstance(dataset, np.ndarray):
479472
return np.ndarray
480-
elif isinstance(dataset, dict):
481-
return dict
482473
elif isinstance(dataset, tf.data.Dataset):
483474
return tf.data.Dataset
484475
elif is_torch_dataset(dataset):
485-
from torch.utils.data import Dataset as torchDataset
476+
from torch.utils.data import Dataset as TorchDataset
486477

487-
return torchDataset
478+
return TorchDataset
488479
else:
489480
return None
490481

0 commit comments

Comments (0)