6
6
7
7
import numpy as np
8
8
9
+ from keras .src import tree
9
10
from keras .src .api_export import keras_export
10
11
from keras .src .utils import io_utils
11
12
from keras .src .utils .module_utils import tensorflow as tf
@@ -137,16 +138,7 @@ def _convert_dataset_to_list(
137
138
data_size_warning_flag ,
138
139
start_time ,
139
140
):
140
- if dataset_type_spec in [tuple , list ]:
141
- # The try-except here is for NumPy 1.24 compatibility, see:
142
- # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
143
- try :
144
- arr = np .array (sample )
145
- except ValueError :
146
- arr = np .array (sample , dtype = object )
147
- dataset_as_list .append (arr )
148
- else :
149
- dataset_as_list .append (sample )
141
+ dataset_as_list .append (sample )
150
142
151
143
return dataset_as_list
152
144
@@ -169,23 +161,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
169
161
"Please provide a non-empty list of arrays."
170
162
)
171
163
172
- if _get_type_spec ( dataset [ 0 ]) is np . ndarray :
173
- expected_shape = dataset [ 0 ]. shape
174
- for i , element in enumerate ( dataset ):
175
- if np . array ( element ). shape [ 0 ] != expected_shape [ 0 ]:
176
- raise ValueError (
177
- "Received a list of NumPy arrays with different "
178
- f"lengths. Mismatch found at index { i } , "
179
- f"Expected shape= { expected_shape } "
180
- f"Received shape= { np . array ( element ) .shape } ."
181
- "Please provide a list of NumPy arrays with "
182
- "the same length."
183
- )
184
- else :
185
- raise ValueError (
186
- "Expected a list of `numpy.ndarray` objects, "
187
- f"Received: { type ( dataset [ 0 ]) } "
188
- )
164
+ expected_shape = None
165
+ for i , element in enumerate ( dataset ):
166
+ if not isinstance ( element , np . ndarray ):
167
+ raise ValueError (
168
+ "Expected a list of `numpy.ndarray` objects,"
169
+ f "Received: { type ( element ) } at index { i } . "
170
+ )
171
+ if expected_shape is None :
172
+ expected_shape = element .shape
173
+ elif element . shape [ 0 ] != expected_shape [ 0 ]:
174
+ raise ValueError (
175
+ "Received a list of NumPy arrays with different lengths."
176
+ f"Mismatch found at index { i } , "
177
+ f"Expected shape= { expected_shape } "
178
+ f"Received shape= { np . array ( element ). shape } . "
179
+ "Please provide a list of NumPy arrays of the same length. "
180
+ )
189
181
190
182
return iter (zip (* dataset ))
191
183
elif dataset_type_spec == tuple :
@@ -195,23 +187,23 @@ def _get_data_iterator_from_dataset(dataset, dataset_type_spec):
195
187
"Please provide a non-empty tuple of arrays."
196
188
)
197
189
198
- if _get_type_spec ( dataset [ 0 ]) is np . ndarray :
199
- expected_shape = dataset [ 0 ]. shape
200
- for i , element in enumerate ( dataset ):
201
- if np . array ( element ). shape [ 0 ] != expected_shape [ 0 ]:
202
- raise ValueError (
203
- "Received a tuple of NumPy arrays with different "
204
- f"lengths. Mismatch found at index { i } , "
205
- f"Expected shape= { expected_shape } "
206
- f"Received shape= { np . array ( element ) .shape } ."
207
- "Please provide a tuple of NumPy arrays with "
208
- "the same length."
209
- )
210
- else :
211
- raise ValueError (
212
- "Expected a tuple of `numpy.ndarray` objects, "
213
- f"Received: { type ( dataset [ 0 ]) } "
214
- )
190
+ expected_shape = None
191
+ for i , element in enumerate ( dataset ):
192
+ if not isinstance ( element , np . ndarray ):
193
+ raise ValueError (
194
+ "Expected a tuple of `numpy.ndarray` objects,"
195
+ f "Received: { type ( element ) } at index { i } . "
196
+ )
197
+ if expected_shape is None :
198
+ expected_shape = element .shape
199
+ elif element . shape [ 0 ] != expected_shape [ 0 ]:
200
+ raise ValueError (
201
+ "Received a tuple of NumPy arrays with different lengths."
202
+ f"Mismatch found at index { i } , "
203
+ f"Expected shape= { expected_shape } "
204
+ f"Received shape= { np . array ( element ). shape } . "
205
+ "Please provide a tuple of NumPy arrays of the same length. "
206
+ )
215
207
216
208
return iter (zip (* dataset ))
217
209
elif dataset_type_spec == tf .data .Dataset :
def _restore_dataset_from_list(
    dataset_as_list, dataset_type_spec, original_dataset
):
    """Restore the dataset from the list of arrays.

    Args:
        dataset_as_list: List of samples. Each sample may be a nested
            structure; the first sample is used as the structural template
            for the result (presumably all samples share that structure --
            verify against the caller).
        dataset_type_spec: Type of the original dataset, e.g. `tuple`,
            `list` or `tf.data.Dataset`.
        original_dataset: The dataset the samples came from. Only inspected
            to detect torch datasets.

    Returns:
        For tuple, list, `tf.data.Dataset` and torch datasets: a structure
        shaped like one sample whose leaves are NumPy arrays combining that
        leaf across all samples, with every list turned into a tuple. An
        empty tuple is returned when there are no samples. For any other
        dataset type, `dataset_as_list` is returned unchanged.
    """
    needs_restructuring = dataset_type_spec in [
        tuple,
        list,
        tf.data.Dataset,
    ] or is_torch_dataset(original_dataset)
    if not needs_restructuring:
        return dataset_as_list

    # Without this guard, indexing the first sample below raises IndexError
    # on an empty list; return the empty tuple instead.
    if not dataset_as_list:
        return ()

    # Save structure by taking the first element.
    structure_template = dataset_as_list[0]
    # Flatten each element.
    flat_samples = [tree.flatten(sample) for sample in dataset_as_list]
    # Combine respective elements at all indices.
    stacked_leaves = [np.array(leaves) for leaves in zip(*flat_samples)]
    # Recreate the original structure of elements.
    restored = tree.pack_sequence_as(structure_template, stacked_leaves)
    # Turn lists to tuples as tf.data will fail on lists.
    return tree.traverse(
        lambda x: tuple(x) if isinstance(x, list) else x,
        restored,
        top_down=False,
    )
457
450
458
451
@@ -477,14 +470,12 @@ def _get_type_spec(dataset):
477
470
return list
478
471
elif isinstance (dataset , np .ndarray ):
479
472
return np .ndarray
480
- elif isinstance (dataset , dict ):
481
- return dict
482
473
elif isinstance (dataset , tf .data .Dataset ):
483
474
return tf .data .Dataset
484
475
elif is_torch_dataset (dataset ):
485
- from torch .utils .data import Dataset as torchDataset
476
+ from torch .utils .data import Dataset as TorchDataset
486
477
487
- return torchDataset
478
+ return TorchDataset
488
479
else :
489
480
return None
490
481
0 commit comments