dynamic padding via collate_fn #761

@Jomonsugi

Description

I would like to dynamically pad my tensors via the collate_fn argument that can be passed to petastorm.pytorch.DataLoader, but make_batch_reader seemingly thwarts this: it appears to prevent the user from adjusting tensor sizes through the DataLoader.

Or is this possible and I'm just missing how to do it? collate_fn could take care of the variable-length values on a batch-by-batch basis. Otherwise it seems I would need to pad all the data in my Spark DataFrame, which substantially increases data size, slows training, and, I assume, slows I/O through petastorm in general.

What I would like to do looks something like the snippet below, where the function passed to collate_fn would dynamically pad my variable-length values.

from petastorm import make_batch_reader
from petastorm.pytorch import DataLoader

reader = make_batch_reader(
    channel,
    workers_count=2,
    num_epochs=1,
    schema_fields=['input', 'labels'],
)

dl = DataLoader(reader,
                batch_size=8,
                shuffling_queue_capacity=100000,
                collate_fn=some_padding_function)
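For illustration, a padding collate_fn along these lines might look like the minimal sketch below. It is an assumption-laden example: it assumes each batch arrives as a list of row dicts keyed by the schema fields above ('input' holding a variable-length sequence, 'labels' a scalar), and it pads with torch.nn.utils.rnn.pad_sequence. some_padding_function is just the hypothetical name from the snippet, not a petastorm API.

import torch
from torch.nn.utils.rnn import pad_sequence

def some_padding_function(batch):
    # Hypothetical sketch: pad each batch's variable-length 'input'
    # sequences to the longest sequence in that batch, instead of
    # pre-padding the whole Spark DataFrame.
    inputs = [torch.as_tensor(row['input']) for row in batch]
    labels = torch.as_tensor([row['labels'] for row in batch])
    return {'input': pad_sequence(inputs, batch_first=True, padding_value=0),
            'labels': labels}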
