An Essentials Guide to PyTorch Dataset and DataLoader Usage
A brief guide for basic usage of PyTorch's Dataset and DataLoader classes.
- Overview
- Setup
- Dataset Class and Instantiation
- Batching via the DataLoader class
- GPU Usage
- Afterword and Resources
Overview
In this short guide, we walk through a small, representative example using the Dataset and DataLoader classes available in PyTorch for easy batching of training examples. This is meant more as an onboarding exercise for me with fastpages, but hopefully the example will be useful to those beginning to use PyTorch for their own applications.
Setup
#collapse_hide
import torch
We'll then need a dataset to work with. For this small example, we'll use numpy to generate a random dataset for us. Specifically, we'll be working with a batch size of 32 later, so we'll create a dataset with exactly 50 batches, where each example has 5 features and a corresponding label between 0-9, inclusive. To do so, we use
- np.random.randn for generating the input examples
- np.random.randint for generating the labels
The exact code is shown below.
#collapse_show
import numpy as np
training_examples = np.random.randn(32 * 50, 5)
training_labels = np.random.randint(0, 10, size=(32*50,))
As a sanity check, let's look at the shapes. We'll want the size of the whole dataset to be (1600, 5), as we have $32*50$ examples, each with 5 features. Similarly, we'll want the size of the labels for the whole dataset to be (1600,), as we're essentially working with a list of 1600 labels.
#collapse_show
training_examples.shape, training_labels.shape
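If you'd rather check these shapes programmatically, a couple of asserts do the trick; this is just a sketch of the same sanity check, with the expected shapes following from the setup above.
#collapse_show
# These should pass given the (32 * 50, 5) examples and (32 * 50,) labels generated above.
assert training_examples.shape == (32 * 50, 5)
assert training_labels.shape == (32 * 50,)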
We can look at some of the labels, just for a sanity check that they look reasonable.
#collapse_show
training_labels[:10]
Dataset Class and Instantiation
Now, we'll create a simple PyTorch dataset class. All you need to implement within this class is the __getitem__ function and the __len__ function.
- __getitem__ takes in an index and returns dataset[index]
- __len__ returns the size of your dataset (in this case, that's 32*50)
When writing this class, you MUST subclass torch.utils.data.Dataset, as this is a requirement for using the DataLoader class (see below).
class ExampleDataset(torch.utils.data.Dataset):
    """ You can define the __init__ function any way you like"""
    def __init__(self, examples, labels):
        self.examples = examples
        self.labels = labels

    """ This function should always take in 1 argument, corresponding to the index you're going to access.
    In this case, we return a tuple, corresponding to the training example and its label.
    It's also useful to convert the returned values to torch.Tensors, so we can push the data onto the
    GPU later on. Note how the label is wrapped in a list, but the example isn't. This isn't just a convention:
    if we pass the bare `self.labels[index]` to `torch.Tensor`, the number is interpreted as a size, giving an
    uninitialized tensor with `self.labels[index]` elements rather than a 1-element tensor holding the label.
    """
    def __getitem__(self, index):
        return (torch.Tensor(self.examples[index]), torch.Tensor([self.labels[index]]))

    """ This function should always take in 0 arguments, and the output should be an `int`. """
    def __len__(self):
        return len(self.examples)
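To see why the label is wrapped in a list, here's a small sketch (using a made-up label value of 7) showing the difference between passing a bare number and a one-element list to torch.Tensor:
#collapse_show
label_value = 7  # hypothetical label value, just for illustration
torch.Tensor(label_value).shape, torch.Tensor([label_value]).shape
# -> (torch.Size([7]), torch.Size([1])): the bare number is treated as a size,
#    while the list gives a 1-element tensor holding the label.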
Now, we can create an instance of our ExampleDataset class, which subclasses torch.utils.data.Dataset. Note that we can specify how to initialize it via the __init__ function, which takes in a list of examples and a list of labels (i.e. the arrays we created in our setup).
training_dataset = ExampleDataset(training_examples, training_labels)
As a sanity check, note the correspondence between indexing into the dataset instance above and the examples/labels we passed in.
training_dataset[0]
training_examples[0], training_labels[0]
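If you prefer an explicit check over eyeballing the output, here's a small sketch of the same comparison, assuming the variables defined above:
#collapse_show
example, label = training_dataset[0]
assert torch.allclose(example, torch.Tensor(training_examples[0]))
assert label.item() == training_labels[0]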
We can iterate over this dataset using standard for loop syntax as well. How you write the for loop depends on how __getitem__ is set up. In our case, we return a tuple (example and label), so the for loop should unpack a tuple as well.
example, label = training_dataset[0]
print(type(example), example.shape, type(label), label.shape)
from tqdm import tqdm
for example, label in tqdm(training_dataset):
continue
Batching via the DataLoader class
To set up batching, we'll use the torch.utils.data.DataLoader class. All we have to do to create this DataLoader is to instantiate it with the dataset we created above (training_dataset). The arguments for torch.utils.data.DataLoader are worth looking at, but (generally) the most important are:
- dataset: the PyTorch dataset class instance we'll pass in (e.g. training_dataset; this is why we had to subclass above)
- batch_size (optional, default is 1): the batch size we want when iterating (we'll pass in 32)
- shuffle (optional, default is False): whether to shuffle the dataset when iterating over the dataloader (note that if this is set to True, it'll reshuffle every epoch; note also that we generally only want this set to True for training, not necessarily for validation)
- drop_last (optional, default is False): whether to drop the last incomplete batch (we don't have to worry about this here, because the number of training examples is exactly divisible by 32)
training_dataloader = torch.utils.data.DataLoader(training_dataset, batch_size=32, shuffle=True)
Again, we can iterate just like we did for training_dataset, but now we get batches back, as we can see by printing the shapes (shown after the loop below). The magic happens in the collate_fn optional argument of the DataLoader class, but the default behavior is sufficient here for batching the examples and labels separately.
We'll first ensure that there are exactly 50 batches in our dataloader to work with.
assert len(training_dataloader) == 50
Now, mimicking the iteration from above, with the ExampleDataset
instance:
for example, label in tqdm(training_dataloader):
continue
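To actually see the batch shapes mentioned above, one quick sketch is to grab the first batch with next(iter(...)):
#collapse_show
example_batch, label_batch = next(iter(training_dataloader))
example_batch.shape, label_batch.shape
# -> (torch.Size([32, 5]), torch.Size([32, 1])): 32 examples of 5 features, and 32 one-element labels.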
At some point, you may want to know information about a specific batch. Accessing specific batches from the DataLoader is not as easy; I don't know of a way to grab a specific batch, other than doing something like the following.
training_dataloader_batches = [(example, label) for example, label in training_dataloader]
some_example, some_label = training_dataloader_batches[15]
some_example.shape, some_label.shape
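One alternative sketch, which avoids materializing every batch in a list (though it still iterates through the preceding batches), uses itertools.islice:
#collapse_show
import itertools
some_example, some_label = next(itertools.islice(training_dataloader, 15, None))
some_example.shape, some_label.shape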
However, you can always access the underlying dataset through the .dataset attribute, as shown below.
training_dataloader.dataset
training_dataloader.dataset[15]
GPU Usage
Finally, if a GPU is available, we can push each batch onto it with .cuda() as we iterate over the DataLoader.
if torch.cuda.is_available():
print("Using GPU.")
for example, label in tqdm(training_dataloader):
if torch.cuda.is_available():
example, label = example.cuda(), label.cuda()
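An equivalent pattern (a sketch, not the only way to do this) uses torch.device so the availability check happens once, outside the loop:
#collapse_show
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for example, label in tqdm(training_dataloader):
    example, label = example.to(device), label.to(device)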
Afterword and Resources
As mentioned above, it's useful to look at the documentation for torch.utils.data.DataLoader. Another way to do so is to run the following within a cell of the notebook itself:
torch.utils.data.DataLoader?
There are many interesting things that you can do here with the arguments allowed in the DataLoader. For example:
- You may be working with an image dataset large enough that you don't want to open all the images (e.g. using PIL) before feeding them through your model. In that case, you can lazily open them by passing in a collate_fn that opens the images before collating the examples of a batch, since collate_fn is only called for each iteration when iterating over the DataLoader, and not when the DataLoader is instantiated (see the sketch after this list).
- You may not want to shuffle the dataset, as it might incur unnecessary computation. This is especially true if you have a separate DataLoader for your validation dataset, in which case there's no need to shuffle, as it won't affect the predictions.
- If you have access to multiple CPUs on whatever machine you're working on, you can use num_workers to load batches ahead of time in separate worker processes.
- If you're working with a GPU, one of the more expensive steps in the pipeline is moving data from CPU to GPU. This can be sped up with pin_memory, which places loaded batches in pinned (page-locked) CPU memory, making the transfer to the GPU faster.
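As an illustration of the first point, here is a minimal sketch of lazy image loading via collate_fn. Everything here is hypothetical (image_paths, image_labels, LazyImageDataset, and open_and_collate are made-up names, and the images are assumed to all share the same size); it's meant to show the pattern, not a drop-in implementation. The commented-out DataLoader call also shows where num_workers and pin_memory would go.
#collapse_show
from PIL import Image

class LazyImageDataset(torch.utils.data.Dataset):
    """ Stores only file paths; no image is opened at construction time. """
    def __init__(self, image_paths, image_labels):
        self.image_paths = image_paths
        self.image_labels = image_labels

    def __getitem__(self, index):
        # Return the path, not the opened image.
        return self.image_paths[index], self.image_labels[index]

    def __len__(self):
        return len(self.image_paths)

def open_and_collate(batch):
    """ Called once per batch while iterating over the DataLoader; this is where images are actually opened. """
    paths, labels = zip(*batch)
    images = [torch.Tensor(np.array(Image.open(path), dtype=np.float32)) for path in paths]
    # torch.stack assumes all images have the same height/width/channels.
    return torch.stack(images), torch.Tensor(labels)

# lazy_dataloader = torch.utils.data.DataLoader(
#     LazyImageDataset(image_paths, image_labels),  # image_paths/image_labels are hypothetical
#     batch_size=32,
#     collate_fn=open_and_collate,
#     num_workers=4,     # load batches ahead of time in 4 worker processes
#     pin_memory=True,   # page-locked CPU memory for faster CPU-to-GPU transfer
# )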