Batch Dataloaders in PyTorch
The way most PyTorch resources suggest working with DataLoaders and Datasets is fine for images, but for tabular data of the sort you could store in a NumPy array it is slow and inefficient. We are usually told to produce samples one at a time, the moral equivalent of:
for idx in indices_of_samples_for_next_minibatch:
    return data[idx, :]
If your data were already loaded into memory in something like a NumPy array, you would normally have simply done:
return data[indices_of_samples_for_next_minibatch, :]
Much simpler and also an order of magnitude faster!
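To make the contrast concrete, here is a small sketch with NumPy (the array shape and the indices are made up for illustration):

```python
import numpy as np

data = np.random.randn(10_000, 8)
indices = [3, 141, 592, 6535]  # indices for the next mini-batch

# One sample at a time (what the default DataLoader machinery effectively does):
batch_slow = np.stack([data[i, :] for i in indices])

# All at once with fancy indexing:
batch_fast = data[indices, :]

assert np.array_equal(batch_slow, batch_fast)
```

Both produce the same mini-batch, but the fancy-indexing version does it in a single vectorized call instead of a Python-level loop.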
Datasets, Dataloaders and __getitem__
If you've ever built your own Dataset you will have come across the __getitem__ method; its role is to produce a sample given an index into the dataset. The idea is that the DataLoader generates a bunch of indices that will be used to build the next mini-batch. Then it essentially runs a for loop over those indices, requesting samples from the dataset via __getitem__, and collates them into the mini-batch. You may have heard that for loops in Python should be avoided outside of trivial, non-performance-critical code, because they are slow. Well, producing samples to feed a neural network is neither trivial nor non-performance-critical. To ameliorate this, PyTorch makes it possible to use multiple workers to fetch data.
Also, very importantly, one of the main use cases for neural networks is image problems. There the dataset usually doesn't fit into memory, so we load it from disk as required, and we can't just do what I suggested in the introduction. But what about when we can? For that case no obvious solution is provided in the documentation, much less in all those copy-pasta tutorials out there.
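For reference, this is the conventional per-sample pattern the tutorials teach, as a minimal sketch (the class name and shapes here are illustrative):

```python
import numpy as np
from torch.utils.data import Dataset, DataLoader

class PerSampleDataset(Dataset):
    """The standard pattern: __getitem__ returns ONE sample per call."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Called once per sample: the DataLoader loops over the batch
        # indices in Python and collates the results into a mini-batch.
        return self.data[idx, :]

data = np.random.randn(256, 4).astype(np.float32)
loader = DataLoader(PerSampleDataset(data), batch_size=32, shuffle=True)
for batch in loader:
    assert batch.shape == (32, 4)  # 256 rows -> eight full batches of 32
```

Every mini-batch here costs 32 separate __getitem__ calls plus a collation step, which is exactly the overhead the rest of this post is about removing.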
The solution
What we would really want is for the dataloader to pass all the indices that will go into building the mini-batch to __getitem__ in one go.
The key insight for the code below comes from:
... but setting batch_size=None and using a sampler that yields a collection of indices at a time. In that way, your dataset.__getitem__ will receive a collection of indices, and the collate_fn will only convert np arrays to tensors (no collating anything into batches). — Ssnl @ Batched Dataloader #26957, 01/10/2019
If we follow this advice and set batch_size=None, we disable automatic batching, as per the documentation.
Then we just need to add our own sampler that returns batches of indices, i.e. the BatchSampler. We still need to tell the BatchSampler how to actually decide which indices to return. I went for the most common scenario here, which is to return random indices, but there are a few other options available.
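To see what a BatchSampler actually yields, here is a quick sketch (using SequentialSampler so the output is deterministic; RandomSampler is what you'd wrap for shuffled mini-batches):

```python
from torch.utils.data import BatchSampler, SequentialSampler

# Wrap a per-index sampler so it yields *lists* of indices instead.
sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
batches = list(sampler)
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

Each element is a whole mini-batch worth of indices, which is exactly the shape of argument we want __getitem__ to receive.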
Without further ado here is the code!
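A minimal, self-contained version of the idea might look like the following (class names, array shapes, and the batch size are illustrative choices, not prescribed by PyTorch):

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler

class NumpyDataset(Dataset):
    """Wraps an in-memory NumPy array; __getitem__ accepts a BATCH of indices."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        # `indices` is a list of ints produced by the BatchSampler, so this
        # single fancy-indexing call builds the whole mini-batch at once.
        return self.data[indices, :]

data = np.random.randn(1024, 16).astype(np.float32)
dataset = NumpyDataset(data)

loader = DataLoader(
    dataset,
    batch_size=None,  # disable automatic batching
    sampler=BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False),
)

for batch in loader:
    # With automatic batching off, the default collate_fn only converts
    # the NumPy array to a tensor; no per-sample collation happens.
    assert isinstance(batch, torch.Tensor) and batch.shape == (32, 16)
```

The DataLoader iterates the BatchSampler, hands each list of indices to __getitem__ in one go, and merely converts the resulting array to a tensor.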
The magic is in the few lines where the DataLoader is defined. The rest is just a NumPy array wrapped in a dataset, plus some tests to make sure we are doing something reasonable. The __getitem__ there would also be an ideal place to perform data augmentations and transformations in a vectorized way, giving an even greater performance boost over the one-at-a-time default approach.
See you soon and Godspeed!