Tip

This tutorial is available as a Jupyter notebook.


๐ŸŽ๏ธ FFCV DataLoaders#

This notebook will walk you through an example of using FFCV dataloaders with the Composer training loop. We'll first run the Composer training loop with PyTorch dataloaders, then repeat the same run with FFCV dataloaders to compare performance between the two. We'll use the CIFAR10 dataset for demonstration purposes, but you can use ImageNet-1K (and others) as well.

Note: This notebook may not work in Google Colab because FFCV requires Python >= 3.8 and Google Colab runs Python 3.7 as of May 15, 2022.

Install Composer#

We'll start by installing composer and ffcv:

[ ]:
# Install system dependencies needed by FFCV (OpenCV and libjpeg-turbo)
!apt update && apt install -y --no-install-recommends libopencv-dev libturbojpeg-dev
# Expose the OpenCV pkg-config file under the opencv4 name that FFCV's build looks for
!cp -f /usr/lib/x86_64-linux-gnu/pkgconfig/opencv.pc /usr/lib/x86_64-linux-gnu/pkgconfig/opencv4.pc
# Install Composer (mosaicml), FFCV, and their Python dependencies
!pip install mosaicml ffcv numba opencv-python
[ ]:
import torch
# Install the CuPy build that matches the local CUDA version (e.g., cupy-cuda113 for CUDA 11.3)
cuda_ver = torch.version.cuda.replace(".", "")
!pip install cupy-cuda{cuda_ver}

Set Up Our Workspace#

Imports#

In this section we'll set up our workspace. We'll import the necessary packages and set up our dataloaders and trainer. First, the imports:

[ ]:
import time

import composer
from torchvision import datasets, transforms

torch.manual_seed(42) # For replicability

Dataset & DataLoader#

Next, we instantiate our CIFAR10 datasets and dataloaders. We'll use the Torchvision CIFAR10 dataset and the PyTorch DataLoader for the sake of familiarity.

[ ]:
# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024
num_workers = 2
data_directory = "/tmp"

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=num_workers,
                                               batch_size=batch_size,
                                               pin_memory=True,
                                               drop_last=True,
                                               shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                              num_workers=num_workers,
                                              batch_size=batch_size,
                                              pin_memory=True,
                                              drop_last=False,
                                              shuffle=False)

Model#

Next, we create our model. We're using Composer's built-in CIFAR ResNet-20. To use your own custom model, please see the custom models tutorial.

[ ]:
from composer import models
model = models.ComposerResNetCIFAR(model_name='resnet_20', num_classes=10)

Optimizer and Scheduler#

Next, we create the optimizer and LR scheduler, which we'll pass to the trainer. We're using MosaicML's SGD with decoupled weight decay:

[ ]:
optimizer = composer.optim.DecoupledSGDW(
    model.parameters(), # Model parameters to update
    lr=0.05, # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3 # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)

To keep the runtime short, we'll train our baseline model for five epochs. The first epoch will be linear warmup, followed by four epochs of constant LR. We achieve this by instantiating the LinearWithWarmupScheduler class. Feel free to increase the number of epochs if you want to see the impact of a longer training run.

[ ]:
lr_scheduler = composer.optim.LinearWithWarmupScheduler(
    t_warmup="1ep", # Warm up over 1 epoch
    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0
)
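
To make the "flat after warmup" behavior concrete, here is a minimal sketch (plain Python, not Composer's API) of the multiplier applied to the peak LR, assuming a 5-epoch run with a 1-epoch warmup and alpha_i == alpha_f == 1.0:

[ ]:
# Illustration only: an approximation of the LR multiplier produced by this schedule,
# not the actual LinearWithWarmupScheduler implementation.
def lr_multiplier(epoch, t_warmup=1.0, alpha_i=1.0, alpha_f=1.0, t_max=5.0):
    if epoch < t_warmup:
        return alpha_i * epoch / t_warmup  # linear warmup from 0 up to alpha_i
    frac = (epoch - t_warmup) / (t_max - t_warmup)  # fraction of the post-warmup phase
    return alpha_i + frac * (alpha_f - alpha_i)  # constant whenever alpha_i == alpha_f

print([round(lr_multiplier(e), 2) for e in (0.0, 0.5, 1.0, 3.0, 5.0)])  # [0.0, 0.5, 1.0, 1.0, 1.0]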

Train a Baseline Model#

And now we create our trainer. Note: we set the device to gpu because FFCV works best on GPU-capable machines.

[ ]:
train_epochs = "5ep" # Train for 5 epochs
device = "gpu"

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device
)

We train and measure the training time below.

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

Depending on where you are running this notebook, the runtime may vary based on the machine's status. We found that the five epochs of training could take anywhere from 23-25 seconds to run, and the mean validation accuracy was typically close to ~62%.

Use FFCV Dataloaders to Speed Up Training#

Next, we convert the dataset to the format used by FFCV. FFCV uses its own data format, designed for faster dataloading. Once this cell executes successfully, you'll find cifar_train.ffcv and cifar_val.ffcv in the data_directory directory.

[ ]:
from composer.datasets.ffcv_utils import write_ffcv_dataset
from torchvision.datasets import CIFAR10


# Train dataset
ds = CIFAR10(root=data_directory, train=True, download=True)
write_ffcv_dataset(dataset=ds, write_path=data_directory + "/cifar_train.ffcv")

# validation dataset
ds = CIFAR10(root=data_directory, train=False, download=True)
write_ffcv_dataset(dataset=ds, write_path=data_directory + "/cifar_val.ffcv")
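
As an optional sanity check (not part of the original tutorial), you can confirm that the converted files were written and see how large they are:

[ ]:
import os

# Optional check: confirm the FFCV files exist and print their sizes
for name in ("cifar_train.ffcv", "cifar_val.ffcv"):
    path = os.path.join(data_directory, name)
    print(f"{path}: {os.path.getsize(path) / 1e6:.1f} MB")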

The current version of FFCV (0.0.3) has a bug where calling len(dataloader) shuffles the image indices to load, so calls to len are expensive. Composer calls len(dataloader) in its training loop for every batch, which makes this a performance hit. We fix it by patching the len function using ffcv_monkey_patches.

[ ]:
from composer.datasets.ffcv_utils import ffcv_monkey_patches
ffcv_monkey_patches()

Now let's construct the FFCV train and test dataloaders. We use transformations similar to those used for the Torchvision datasets.

[ ]:
import ffcv
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder

# Note: these mean/std values are on the 0-255 pixel scale, unlike the ones used above.
# ToTensor scales images to [0, 1] for the PyTorch dataloaders, whereas FFCV's decoder
# yields raw 0-255 pixel values.
cifar10_mean_ffcv = [125.307, 122.961, 113.8575]
cifar10_std_ffcv = [51.5865, 50.847, 51.255]
label_pipeline = [IntDecoder(), ffcv.transforms.ToTensor(), ffcv.transforms.Squeeze()]
image_pipeline = [
    SimpleRGBImageDecoder(),
    ffcv.transforms.ToTensor(),
    ffcv.transforms.ToTorchImage(channels_last=False, convert_back_int16=False),
    ffcv.transforms.Convert(torch.float32),
    transforms.Normalize(cifar10_mean_ffcv, cifar10_std_ffcv),
]

ffcv_train_dataloader = ffcv.Loader(
                data_directory + "/cifar_train.ffcv",
                batch_size=batch_size,
                num_workers=num_workers,
                order=ffcv.loader.OrderOption.RANDOM,
                pipelines={
                    'image': image_pipeline,
                    'label': label_pipeline
                },
                drop_last=True,
            )
ffcv_test_dataloader = ffcv.Loader(
                data_directory + "/cifar_val.ffcv",
                batch_size=batch_size,
                num_workers=num_workers,
                order=ffcv.loader.OrderOption.RANDOM,
                pipelines={
                    'image': image_pipeline,
                    'label': label_pipeline
                },
                drop_last=False,
            )
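
If you'd like to verify the pipelines before training (a quick check we're adding here, not part of the original tutorial), you can pull a single batch from the FFCV loader and confirm that the shapes and dtypes match what the model expects:

[ ]:
# Optional sanity check: fetch one batch from the FFCV loader and inspect it
images, labels = next(iter(ffcv_train_dataloader))
print(images.shape, images.dtype)  # expect something like torch.Size([1024, 3, 32, 32]) torch.float32
print(labels.shape, labels.dtype)  # expect something like torch.Size([1024]) torch.int64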

Now let's instantiate our model, optimizer, and trainer again, but with the FFCV dataloaders. There's no need to instantiate our scheduler again because it's stateless!

[ ]:
model = models.ComposerResNetCIFAR(model_name="resnet_20", num_classes=10)

optimizer = composer.optim.DecoupledSGDW(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=2.0e-3
)

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=ffcv_train_dataloader,
    eval_dataloader=ffcv_test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
)

And let's get training!

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
accelerated_time = end_time - start_time
print(f"It took {accelerated_time:0.4f} seconds to train with FFCV dataloaders")

Again, the runtime will vary based on the instance, but we found that this run with FFCV dataloaders took about 15-17 seconds, which is roughly ~1.3x faster, and it reaches the same ~62% accuracy. Please note that the speed-up from FFCV dataloaders depends on how dataloader-bottlenecked your training run is, i.e., you may not observe any speed-up if your training run wasn't dataloader bottlenecked.
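
If you'd rather compute the speed-up than eyeball the printed times, one option (our suggestion, not part of the original notebook) is to store the baseline duration in a variable, e.g. baseline_time = end_time - start_time in the earlier timing cell, and then:

[ ]:
# Assumes you saved the baseline run's duration as `baseline_time` in the earlier cell
# (a hypothetical variable; that cell only printed the time as written above).
print(f"FFCV dataloaders speed-up: {baseline_time / accelerated_time:.2f}x")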

You did it! Now come get involved with MosaicML!#

This is the end of this notebook, but we'd love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub#

Stay up-to-date and help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer#

Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!