Tip

This tutorial is available as a Jupyter notebook.

Open in Colab

🏎️ FFCV DataLoaders#

Itching to use optimized data loading systems like FFCV? This tutorial will show you how to do so with Composer.

Tutorial Goals and Covered Concepts#

The goal of this tutorial is to walk you through an example of using FFCV dataloaders with the Composer training loop. For the sake of demonstration, we’ll show the training loop twice—first with PyTorch dataloaders (as a baseline) and then with FFCV dataloaders for comparison.

We’ll be using the CIFAR-10 dataset for demonstration purposes, but you can use ImageNet-1K (and others) as well.

Note: This notebook may not work in Google Colab due to FFCV’s requirement of Python >= 3.8, while Google Colab runs Python 3.7 as of May 15, 2022.

Another Note: To get the most out of FFCV with Composer, you’ll need to run ffcv_monkey_patches() once at the start of your training script. More detail below.

Let’s get started!

Install Composer#

We’ll start by installing Composer and FFCV:

[ ]:
!apt update && apt install -y --no-install-recommends libopencv-dev libturbojpeg-dev
!cp -f /usr/lib/x86_64-linux-gnu/pkgconfig/opencv.pc /usr/lib/x86_64-linux-gnu/pkgconfig/opencv4.pc
%pip install ffcv numba opencv-python

%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git
[ ]:
import torch
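# FFCV depends on CuPy, whose wheels are CUDA-version-specific (e.g., cupy-cuda117 for
# CUDA 11.7), so we derive the wheel name from the local CUDA version. This assumes a
# CUDA-enabled PyTorch install; torch.version.cuda is None on CPU-only builds.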
cuda_ver = torch.version.cuda.replace(".", "")
%pip install cupy-cuda{cuda_ver}

Establishing a Baseline#

The rest of this tutorial is roughly divided into two sections, one for each run—with and without FFCV.

In this first section, we’ll set up our environment for training with Composer on CIFAR-10 using standard PyTorch dataloaders. Our goal here is just to set up a baseline for comparison with the next section, where we bring FFCV into the picture.

Imports#

First, the imports:

[ ]:
import time

import composer
from torchvision import datasets, transforms

torch.manual_seed(42) # For replicability

Dataset & DataLoader#

Next, we instantiate our CIFAR-10 dataset and dataloaders. We’ll use the Torchvision CIFAR-10 dataset and the standard PyTorch dataloader for the sake of familiarity.

[ ]:
# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024
num_workers = 2
data_directory = "/tmp"

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=num_workers,
                                               batch_size=batch_size,
                                               pin_memory=True,
                                               drop_last=True,
                                               shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=num_workers,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              drop_last=False,
                                              shuffle=False)

Model#

Next, we create our model. We’ll use Composer’s built-in CIFAR ResNet-20. To use your own custom model, please see the custom models tutorial.

[ ]:
from composer import models
model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)

Optimizer and Scheduler#

We’ll use MosaicML’s SGD with decoupled weight decay (DecoupledSGDW) as the optimizer. With decoupled weight decay, the decay is applied directly to the weights instead of being folded into the gradient, i.e., roughly w = w - lr * g - wd * w rather than w = w - lr * (g + wd * w). We just need to create the optimizer and LR scheduler instances, and the trainer (below) will handle the rest:

[ ]:
optimizer = composer.optim.DecoupledSGDW(
    model.parameters(), # Model parameters to update
    lr=0.05, # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3 # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)

To keep the runtime short, we’ll train our baseline model for five epochs. The first epoch will be linear warmup, followed by four epochs at a constant LR. We achieve this by instantiating a LinearWithWarmupScheduler class. Feel free to increase the number of epochs if you want to see the impact of training for longer.

[ ]:
lr_scheduler = composer.optim.LinearWithWarmupScheduler(
    t_warmup="1ep", # Warm up over 1 epoch
    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0
)
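
To build intuition for what this schedule does, here is a minimal sketch (ours, not Composer’s implementation) of the resulting LR multiplier: it ramps linearly from 0 to alpha_i over the warmup period, then interpolates linearly from alpha_i to alpha_f over the remaining epochs, which is flat here since alpha_i == alpha_f.

[ ]:
# A hand-rolled sketch of the schedule above, for intuition only
def lr_multiplier(epoch, t_warmup=1, t_max=5, alpha_i=1.0, alpha_f=1.0):
    if epoch < t_warmup:  # Linear warmup from 0 to alpha_i
        return alpha_i * epoch / t_warmup
    # Then linear interpolation from alpha_i to alpha_f (flat when they're equal)
    return alpha_i + (alpha_f - alpha_i) * (epoch - t_warmup) / (t_max - t_warmup)

print([lr_multiplier(e) for e in range(6)])  # [0.0, 1.0, 1.0, 1.0, 1.0, 1.0]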

Train a Baseline Model#

And now we create our trainer! This pattern should look pretty familiar if you’ve been working through the tutorials.

Note: We want to use a GPU as our device because FFCV works best on GPU-capable machines.

[ ]:
train_epochs = "5ep" # Train for 5 epochs
device = "gpu"

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device
)

We train and measure the training time below.

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

Depending on where you are running this notebook, the runtime will vary with the machine. We found that the five epochs of training took anywhere from 23-25 seconds, with a validation accuracy typically around 62%.

Use FFCV Dataloaders to Speed Up Training#

Now we’re on to the second section of our tutorial. Here, we’ll see how to plug FFCV dataloaders into the Composer trainer!

The current version of FFCV (0.0.3) has a bug where every call to len(dataloader) reshuffles the indices of the images to load, making calls to len expensive. Since Composer calls len(dataloader) in the training loop for every batch, this becomes a real performance hit. We fix it by patching the len function using ffcv_monkey_patches.

Note: Please make sure to run this fix (i.e., add it to the start of your training script) whenever training with Composer and FFCV!

[ ]:
from composer.datasets.ffcv_utils import ffcv_monkey_patches
ffcv_monkey_patches()

To get started with FFCV, we’ll convert the dataset to FFCV’s custom data format, which offers faster data loading.

Once this cell executes successfully, you can find cifar_train.ffcv and cifar_val.ffcv in the directory specified by data_directory.

[ ]:
from composer.datasets.ffcv_utils import write_ffcv_dataset
from torchvision.datasets import CIFAR10

# train dataset
ds = CIFAR10(root=data_directory, train=True, download=True)
write_ffcv_dataset(dataset=ds, write_path=data_directory + "/cifar_train.ffcv")

# validation dataset
ds = CIFAR10(root=data_directory, train=False, download=True)
write_ffcv_dataset(dataset=ds, write_path=data_directory + "/cifar_val.ffcv")
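
As a quick sanity check, you can confirm that both files were written and see how large they are:

[ ]:
import os

# Verify the FFCV files exist in data_directory and print their sizes
for fname in ["cifar_train.ffcv", "cifar_val.ffcv"]:
    path = os.path.join(data_directory, fname)
    print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")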

Now let’s construct the FFCV train and test dataloaders. We’ll use transformations similar to those we used with the Torchvision dataset.

[ ]:
import ffcv
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder

# Note: this mean/std is on the 0-255 scale, unlike the 0-1 values used for the regular
# PyTorch dataloader, because the FFCV pipeline produces float images in [0, 255], whereas
# torchvision's ToTensor scales images to [0, 1] before Normalize is applied.
cifar10_mean_ffcv = [125.307, 122.961, 113.8575]
cifar10_std_ffcv = [51.5865, 50.847, 51.255]
label_pipeline = [IntDecoder(), ffcv.transforms.ToTensor(), ffcv.transforms.Squeeze()]
image_pipeline = [SimpleRGBImageDecoder(), ffcv.transforms.ToTensor(),
                ffcv.transforms.ToTorchImage(channels_last=False, convert_back_int16=False),
                ffcv.transforms.Convert(torch.float32),
                transforms.Normalize(cifar10_mean_ffcv, cifar10_std_ffcv),
            ]

ffcv_train_dataloader = ffcv.Loader(
                data_directory + "/cifar_train.ffcv",
                batch_size=batch_size,
                num_workers=num_workers,
                order=ffcv.loader.OrderOption.RANDOM,
                pipelines={
                    'image': image_pipeline,
                    'label': label_pipeline
                },
                drop_last=True,
            )
ffcv_test_dataloader = ffcv.Loader(
                data_directory + "/cifar_val.ffcv",
                batch_size=batch_size,
                num_workers=num_workers,
                order=ffcv.loader.OrderOption.RANDOM,
                pipelines={
                    'image': image_pipeline,
                    'label': label_pipeline
                },
                drop_last=False,
            )
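
Before training, it can be worth pulling a single batch to confirm the pipelines produce what the model expects. (The first iteration is slower than usual while FFCV compiles its pipeline.)

[ ]:
# Optional sanity check: fetch one batch and inspect shapes/dtypes
images, labels = next(iter(ffcv_train_dataloader))
print(images.shape, images.dtype)  # expecting [batch_size, 3, 32, 32] and torch.float32
print(labels.shape, labels.dtype)  # expecting [batch_size] and an integer dtype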

Now let’s instantiate our model, optimizer, and trainer again but with FFCV dataloaders. (No need to instantiate our scheduler again because it’s stateless!)

[ ]:
model = models.composer_resnet_cifar(model_name="resnet_20", num_classes=10)

optimizer = composer.optim.DecoupledSGDW(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=2.0e-3
)

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=ffcv_train_dataloader,
    eval_dataloader=ffcv_test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
)

And let’s get training!

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
accelerated_time = end_time - start_time
print(f"It took {accelerated_time:0.4f} seconds to train with FFCV dataloaders")

Again, the runtime will vary based on the instance, but we found that this run with FFCV dataloaders took about 15-17 seconds, roughly 1.5x faster than the baseline, while reaching the same ~62% validation accuracy. Please note that the speedup from FFCV depends on how much your training run is bottlenecked by data loading; if your run wasn’t dataloader-bound, you may not observe any speedup.
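
Since the baseline timing cell above saved its duration in baseline_time, you can compute the speedup on your own machine:

[ ]:
# Compare wall-clock training time against the PyTorch-dataloader baseline
speedup = baseline_time / accelerated_time
print(f"FFCV dataloaders were {speedup:0.2f}x faster than the baseline PyTorch dataloaders")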

What Next?#

Now you’re ready to integrate FFCV dataloaders to make Composer training even faster!

To make the most of this tutorial, you may want to dig deeper into FFCV itself, if you haven’t already.

In addition, please continue to explore our tutorials!

Come get involved with MosaicML!#

We’d love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub#

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer#

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!