From 7dd1bb02a50f0097d96b3619acccd2768b521551 Mon Sep 17 00:00:00 2001 From: Alexander Werning Date: Wed, 15 Nov 2023 18:51:14 +0000 Subject: [PATCH] Add TileDataset --- lazy_dataset/core.py | 101 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 12 deletions(-) diff --git a/lazy_dataset/core.py b/lazy_dataset/core.py index ed8a59f..0b9ab0d 100644 --- a/lazy_dataset/core.py +++ b/lazy_dataset/core.py @@ -1004,7 +1004,7 @@ def shuffle(self, reshuffle: bool = False, else: raise ValueError(reshuffle, self) - def tile(self, reps: int, shuffle: bool = False) -> 'Dataset': + def tile(self, reps: int, shuffle: bool = False) -> "Dataset": """ Constructs a new dataset by repeating the dataset the number of times given by `reps`. This is done by copying the dataset and @@ -1022,21 +1022,16 @@ def tile(self, reps: int, shuffle: bool = False) -> 'Dataset': >>> ds ListDataset(len=5) MapDataset(_pickle.loads) - ListDataset(len=5) - MapDataset(_pickle.loads) - ListDataset(len=5) - MapDataset(_pickle.loads) - ConcatenateDataset() + TileDataset(repetitions=3) >>> list(ds) [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5] """ - datasets = [self] * reps if shuffle: - datasets = [ - ds.shuffle() - for ds in datasets - ] - return self.__class__.concatenate(*datasets) + datasets = [self] * reps + datasets = [ds.shuffle() for ds in datasets] + return self.__class__.concatenate(*datasets) + else: + return TileDataset(self, reps) def cycle(self) -> 'CycleDataset': """ @@ -2763,6 +2758,88 @@ def __getitem__(self, item): return super().__getitem__(item) +class TileDataset(Dataset): + """ + Iterates over all elements of the input_dataset for `repetitions` times. + + """ + + def __init__(self, input_dataset, repetitions): + """ + + Args: + input_dataset: dataset + repetitions: int + + """ + self.input_dataset = input_dataset + self.repetitions = repetitions + + def copy(self, freeze=False): + return self.__class__(self.input_dataset.copy(freeze=freeze), self.repetitions) + + @property + def indexable(self): + return self.input_dataset.indexable + + @property + def ordered(self) -> bool: + return self.input_dataset.ordered + + def __str__(self): + return f"{self.__class__.__name__}(repetitions={self.repetitions})" + + def __iter__(self, with_key=False): + for _ in range(self.repetitions): + if with_key: + iterable = self.input_dataset.__iter__(with_key=True) + else: + iterable = self.input_dataset + for example in iterable: + yield example + + def __len__(self): + return self.repetitions * len(self.input_dataset) + + def __getitem__(self, item): + """ + >>> ds = DictDataset({'a': {}, 'b': {}}) + >>> ds = ds.items().map(lambda x: {'example_id': x[0], **x[1]}) + >>> ds = ds.tile(2) + >>> len(ds) + 4 + >>> ds['a'] + {'example_id': 'a'} + >>> ds['b'] + {'example_id': 'b'} + >>> ds[5] + Traceback (most recent call last): + ... + IndexError: 5 + >>> ds[-1] + {'example_id': 'b'} + >>> ds[-5] + Traceback (most recent call last): + ... + IndexError: -5 + + """ + if isinstance(item, str): + return self.input_dataset[item] + elif isinstance(item, numbers.Integral): + _item = item + if item < 0: + item = item + len(self) + if item < 0: + raise IndexError(_item) + if item > self.repetitions * len(self.input_dataset): + raise IndexError(_item) + item = item % len(self.input_dataset) + return self.input_dataset[item] + else: + return super().__getitem__(item) + + class IntersperseDataset(Dataset): """ See Dataset.intersperse