Tip

This tutorial is available as a Jupyter notebook.

🖼️ Getting Started

This notebook will walk you through how to accelerate model training with Composer. We’ll start by training a baseline ResNet56 on CIFAR10, then see how training efficiency improves as we add speed-up methods.

Install Composer

We’ll start by installing Composer:

[ ]:
%pip install mosaicml

Set Up Our Workspace

Imports

In this section we’ll set up our workspace. We’ll import the necessary packages and set up our dataset and trainer. First, the imports:

[ ]:
import time

import matplotlib.pyplot as plt
import torch
import torch.utils.data
from torchvision import datasets, transforms

import composer
from composer.loggers import InMemoryLogger, LogLevel

torch.manual_seed(42)  # For reproducibility

Dataset & DataLoader

Next, we instantiate our CIFAR10 dataset and dataloaders. Composer has its own CIFAR10 dataset and dataloaders, but this walkthrough focuses on how to use Composer’s algorithms, so we’ll stick with the torchvision CIFAR10 dataset and standard PyTorch DataLoaders for the sake of familiarity.

[ ]:
data_directory = "../data"

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# The evaluation set doesn't need shuffling: sample order doesn't affect accuracy
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
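The mean and std values above are per-channel (R, G, B) constants that transforms.Normalize uses to compute (x - mean) / std for each channel. If you want to derive such constants yourself, a minimal sketch is below. It uses a random tensor as a stand-in for real images, so the numbers it produces are not dataset statistics, and the helper names channel_stats and normalize are hypothetical.

```python
import torch


def channel_stats(images: torch.Tensor):
    """Per-channel mean and std over a batch of images shaped (N, C, H, W)."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Same per-channel formula as transforms.Normalize: (x - mean) / std
    return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)


images = torch.rand(64, 3, 32, 32)  # stand-in for ToTensor() outputs, which lie in [0, 1]
mean, std = channel_stats(images)
normalized = normalize(images, mean, std)
# After normalizing with its own statistics, each channel has mean ~0 and std ~1
print(normalized.mean(dim=(0, 2, 3)), normalized.std(dim=(0, 2, 3)))
```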

Logging

Next, we instantiate an InMemoryLogger, which keeps the data logged by the Composer Trainer in memory. We’ll use this logger to plot our results after training.

[ ]:
logger = InMemoryLogger(log_level=LogLevel.BATCH)

Model

Next, we create our model. We’re using Composer’s built-in ResNet56. To use your own custom model, please see the custom models tutorial.

[ ]:
from composer import models
model = models.ComposerResNetCIFAR(model_name='resnet_56', num_classes=10)

Optimizer and Scheduler

Next, we create the optimizer and learning rate (LR) scheduler, which we’ll pass to the trainer. We’re using MosaicML’s SGD with decoupled weight decay:

[ ]:
optimizer = composer.optim.DecoupledSGDW(
    model.parameters(), # Model parameters to update
    lr=0.05, # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3 # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)
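To make "not scaled by the LR" concrete, here is a minimal scalar sketch contrasting classic (coupled, L2-style) weight decay with decoupled weight decay. It assumes, following the SGDW formulation of Loshchilov and Hutter, that the decoupled decay is multiplied only by the schedule multiplier (lr / peak_lr) rather than by the LR itself. It omits momentum, the function names are hypothetical, and it is not Composer’s actual DecoupledSGDW code.

```python
def coupled_sgd_step(param: float, grad: float, lr: float, weight_decay: float) -> float:
    # Coupled (L2-style) decay: weight_decay * param is added to the gradient,
    # so the decay ends up multiplied by the learning rate.
    return param - lr * (grad + weight_decay * param)


def decoupled_sgd_step(param: float, grad: float, lr: float, weight_decay: float, peak_lr: float) -> float:
    # Decoupled decay: the gradient step and the decay are applied separately.
    # Assumption: the decay is scaled by the schedule multiplier lr / peak_lr, not by lr itself.
    param = param - lr * grad
    return param - (lr / peak_lr) * weight_decay * param


# With a zero gradient, only weight decay changes the parameter.
coupled = coupled_sgd_step(1.0, 0.0, lr=0.05, weight_decay=2.0e-3)
decoupled = decoupled_sgd_step(1.0, 0.0, lr=0.05, weight_decay=2.0e-3, peak_lr=0.05)
print(f"coupled:   {coupled:.6f}")    # shrinks by lr * weight_decay = 1e-4
print(f"decoupled: {decoupled:.6f}")  # shrinks by weight_decay = 2e-3 at the peak LR
```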

We’ll assume this is being run on Colab, which means training for hundreds of epochs would take a very long time. Instead, we’ll train our baseline model for three epochs: one epoch of linear warmup followed by two epochs at a constant LR. We achieve this by instantiating a LinearWithWarmupScheduler.

[ ]:
lr_scheduler = composer.optim.LinearWithWarmupScheduler(
    t_warmup="1ep", # Warm up over 1 epoch
    alpha_i=1.0, # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0
)
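For intuition, here is a minimal sketch of the LR multiplier that a linear-warmup-then-linear schedule produces, with time measured in (fractional) epochs. It assumes the warmup ramps linearly from 0 to alpha_i and that the post-warmup phase interpolates linearly from alpha_i to alpha_f; the function name is hypothetical and this is an illustration, not Composer’s implementation. With alpha_i == alpha_f == 1.0, as above, the LR rises to its peak over the first epoch and then stays flat.

```python
def linear_with_warmup_multiplier(t: float, t_warmup: float, t_max: float,
                                  alpha_i: float = 1.0, alpha_f: float = 1.0) -> float:
    """LR multiplier at time t (in epochs) for linear warmup followed by a linear schedule."""
    if t < t_warmup:
        # Warmup: ramp linearly from 0 up to alpha_i
        return alpha_i * t / t_warmup
    # After warmup: interpolate linearly from alpha_i to alpha_f over the remaining time
    frac = (t - t_warmup) / (t_max - t_warmup)
    return alpha_i + (alpha_f - alpha_i) * frac


peak_lr = 0.05  # matches the optimizer's lr above
for t in [0.0, 0.5, 1.0, 2.0, 3.0]:
    lr = peak_lr * linear_with_warmup_multiplier(t, t_warmup=1.0, t_max=3.0)
    print(f"epoch {t:.1f}: lr = {lr:.4f}")
```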

Train a Baseline Model

And now we create our trainer:

[ ]:
train_epochs = "3ep" # Train for 3 epochs because we're assuming Colab environment and hardware
device = "gpu" if torch.cuda.is_available() else "cpu" # select the device

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
    loggers=logger
)

We train and measure the training time below.

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

If you’re running this on Colab, the runtime will vary a lot based on the instance. We found that the three epochs of training could take anywhere from 120-550 seconds to run, and the mean validation accuracy was typically in the range of 25%-40%.

Extract and Plot Logged Data

Now we plot the validation accuracy recorded by the logger:

[ ]:
timeseries_raw = logger.get_timeseries("metrics/eval/Accuracy")
plt.plot(timeseries_raw['epoch'], timeseries_raw["metrics/eval/Accuracy"])
plt.xlabel("Epoch")
plt.ylabel("Validation Accuracy")
plt.title("Accuracy per epoch without Composer")
plt.show()

Use Algorithms to Speed Up Training

One of the things we’re most excited about at MosaicML is our speed-up algorithms. We used these algorithms to speed up training of ResNet50 on ImageNet by up to 3.4x. Let’s try applying a few algorithms to make our ResNet56 more efficient.

We’ll start with Label Smoothing, which serves as a form of regularization by interpolating between the target distribution and another distribution that usually has higher entropy.

[ ]:
label_smoothing = composer.algorithms.LabelSmoothing(smoothing=0.1) # Strength of the interpolation toward the higher-entropy distribution
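As a minimal sketch of what label smoothing does to the targets, assume the higher-entropy distribution is uniform over the classes (a common choice): each one-hot target y becomes (1 - smoothing) * y + smoothing / num_classes. The helper name smooth_targets is hypothetical and not part of Composer’s API.

```python
import torch
import torch.nn.functional as F


def smooth_targets(targets: torch.Tensor, num_classes: int, smoothing: float = 0.1) -> torch.Tensor:
    """Mix one-hot targets with a uniform distribution over the classes."""
    one_hot = F.one_hot(targets, num_classes).float()
    return one_hot * (1.0 - smoothing) + smoothing / num_classes


targets = torch.tensor([3, 0])  # integer class labels, as CIFAR10 provides
smoothed = smooth_targets(targets, num_classes=10, smoothing=0.1)
# The true class keeps most of the mass; every other class gets smoothing / num_classes
print(smoothed[0])
```

PyTorch’s own torch.nn.CrossEntropyLoss accepts a label_smoothing argument that applies the same uniform mixing inside the loss, which is handy if you want the effect outside Composer.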

Let’s also use BlurPool, which aims to increase accuracy by applying a spatial low-pass filter before subsampling, both in max-pooling layers and in strided convolutions.

[ ]:
blurpool = composer.algorithms.BlurPool(
    replace_convs=True, # Add blurring to strided convolutions
    replace_maxpools=True, # Add blurring to max-pooling layers
    blur_first=True # Blur before (rather than after) each strided convolution
)
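The core idea can be sketched in plain PyTorch: low-pass filter each channel with a small binomial kernel, then subsample. The sketch follows the max-blur-pool recipe from Zhang’s "Making Convolutional Networks Shift-Invariant Again" (a dense max with stride 1, then blur and downsample). The 3x3 [1, 2, 1] filter, the reflect padding, and the function names are assumptions for illustration; Composer’s BlurPool implementation may differ in its details.

```python
import torch
import torch.nn.functional as F


def blur_downsample(x: torch.Tensor, stride: int = 1) -> torch.Tensor:
    """Low-pass filter each channel with a 3x3 binomial kernel, then subsample by `stride`."""
    channels = x.shape[1]
    k1d = torch.tensor([1.0, 2.0, 1.0], dtype=x.dtype, device=x.device)
    k2d = torch.outer(k1d, k1d)
    k2d = k2d / k2d.sum()  # weights sum to 1, so a constant image is left unchanged
    # One copy of the kernel per channel; groups=channels makes this a depthwise conv
    weight = k2d.expand(channels, 1, 3, 3).contiguous()
    x = F.pad(x, (1, 1, 1, 1), mode="reflect")
    return F.conv2d(x, weight, stride=stride, groups=channels)


def blur_max_pool2d(x: torch.Tensor, kernel_size: int = 2, stride: int = 2) -> torch.Tensor:
    # Take the max densely (stride 1), then blur *before* subsampling to reduce aliasing
    x = F.max_pool2d(x, kernel_size=kernel_size, stride=1)
    return blur_downsample(x, stride=stride)


x = torch.randn(1, 3, 8, 8)
# Same output shape as a plain 2x2, stride-2 max pool
print(F.max_pool2d(x, 2, 2).shape, blur_max_pool2d(x).shape)
```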

Our final algorithm in our improved training recipe is Progressive Image Resizing. Progressive Image Resizing initially shrinks the size of training images and slowly scales them back to their full size over the course of training. It increases throughput during the early phase of training, when the network may learn coarse-grained features that do not require details lost by reducing image resolution.

[ ]:
prog_resize = composer.algorithms.ProgressiveResizing(
    initial_scale=.6, # Size of images at the beginning of training = .6 * default image size
    finetune_fraction=0.34, # Train on full-size images for the final 34% of training.
)
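Here is a minimal sketch of the schedule implied by these two parameters. It assumes the scale grows linearly from initial_scale to 1.0 over the first (1 - finetune_fraction) of training and stays at 1.0 afterwards, and that images are resized with bilinear interpolation. The helper names are hypothetical, and Composer’s ProgressiveResizing may differ in details such as how it rounds image sizes.

```python
import torch
import torch.nn.functional as F


def resize_scale(frac_elapsed: float, initial_scale: float = 0.6, finetune_fraction: float = 0.34) -> float:
    """Image scale factor once `frac_elapsed` (0.0 to 1.0) of training has passed."""
    # Progress through the resizing phase, capped at 1.0 once fine-tuning begins
    progress = min(frac_elapsed / (1.0 - finetune_fraction), 1.0)
    return initial_scale + (1.0 - initial_scale) * progress


def resize_batch(images: torch.Tensor, scale: float) -> torch.Tensor:
    """Downscale a (N, C, H, W) batch by `scale`; class labels are unaffected."""
    height, width = images.shape[-2:]
    new_size = (max(1, round(height * scale)), max(1, round(width * scale)))
    return F.interpolate(images, size=new_size, mode="bilinear", align_corners=False)


images = torch.randn(4, 3, 32, 32)  # a CIFAR10-sized stand-in batch
for frac in [0.0, 0.33, 0.66, 1.0]:
    s = resize_scale(frac)
    print(f"{frac:.2f} of training: scale={s:.3f}, size={tuple(resize_batch(images, s).shape[-2:])}")
```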

We’ll assemble all our algorithms into a list to pass to our trainer.

[ ]:
algorithms = [label_smoothing, blurpool, prog_resize]

Now let’s instantiate our model, optimizer, logger, and trainer again. No need to instantiate our scheduler again because it’s stateless!

[ ]:
model = models.ComposerResNetCIFAR(model_name="resnet_56", num_classes=10)

composer_logger = InMemoryLogger(log_level=LogLevel.BATCH)

optimizer = composer.optim.DecoupledSGDW(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=2.0e-3
)

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
    loggers=composer_logger,
    algorithms=algorithms # Adding algorithms this time
)

And let’s get training!

[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
three_epochs_accelerated_time = end_time - start_time
print(f"It took {three_epochs_accelerated_time:0.4f} seconds to train")

Again, the runtime will vary based on the instance, but we found that training took about 0.43x-0.75x as long as the baseline recipe without speed-up algorithms (a 1.3x-2.3x speedup), or roughly 90-400 seconds. We also found that validation accuracy was similar for the algorithm-enhanced and baseline recipes.
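The speedup figures follow from the time ratios: a run that takes r times as long as the baseline has a speedup of 1 / r. A small sketch is below. Note that the baseline cell above prints its training time but does not store it, so applying this to your own runs would require saving that value (for example in a hypothetical baseline_time variable) alongside three_epochs_accelerated_time.

```python
def speedup(baseline_seconds: float, accelerated_seconds: float) -> float:
    """How many times faster the accelerated run was than the baseline."""
    return baseline_seconds / accelerated_seconds


# Taking 0.75x or 0.43x as long as the baseline corresponds to roughly 1.3x or 2.3x speedups
for time_ratio in (0.75, 0.43):
    print(f"{time_ratio:.2f}x the baseline time -> {speedup(1.0, time_ratio):.2f}x speedup")
```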

Because Progressive Resizing increases data throughput (i.e., more samples per second), we can train for more iterations in the same amount of wall-clock time. Let’s train our model for one additional epoch!

[ ]:
train_epochs = "1ep"

Because we’re resuming training after the warmup has already finished, we don’t want to warm up again, so we use a flat LR schedule:

[ ]:
lr_scheduler = composer.optim.scheduler.ConstantScheduler(alpha=1.0, t_max='1dur')

We can also drop Progressive Resizing, because we want to train on full-size images for this additional epoch. BlurPool already modified the model’s layers during the previous run, so we don’t need to pass it either:

[ ]:
algorithms = [label_smoothing]
[ ]:
composer_logger_1ep = InMemoryLogger(log_level=LogLevel.BATCH)

trainer = composer.trainer.Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    device=device,
    loggers=composer_logger_1ep,
    algorithms=algorithms
)

start_time = time.perf_counter()
trainer.fit()

end_time = time.perf_counter()
final_epoch_accelerated_time = end_time - start_time
# Time for four epochs = time for three epochs + time for fourth epoch
four_epochs_accelerated_time = three_epochs_accelerated_time + final_epoch_accelerated_time
print(f"It took {four_epochs_accelerated_time:0.4f} seconds to train")

We found that using these speed-up algorithms for four epochs resulted in a runtime similar to or shorter than that of three epochs without them (120-550 seconds, depending on the instance), and that they usually improved validation accuracy by 5-15 percentage points, yielding validation accuracy in the range of 30%-50%.

Let’s plot the validation accuracy with and without the speed-up algorithms!

[ ]:
# Original data
original_timeseries = logger.get_timeseries("metrics/eval/Accuracy")
original_epochs = original_timeseries['epoch']
original_acc = original_timeseries["metrics/eval/Accuracy"]

# Composer data
composer_timeseries = composer_logger.get_timeseries("metrics/eval/Accuracy")
composer_epochs = list(composer_timeseries["epoch"])
composer_acc = list(composer_timeseries["metrics/eval/Accuracy"])

# Concatenate the 3 epochs trained with all three algorithms and the extra epoch trained without Progressive Resizing.
# The second trainer counts epochs from 0 again, so shift its epochs to follow the first run.
composer_timeseries_1ep = composer_logger_1ep.get_timeseries("metrics/eval/Accuracy")
shifted_epochs = [composer_epochs[-1] + epoch for epoch in composer_timeseries_1ep["epoch"]]
composer_epochs.extend(shifted_epochs)
composer_acc.extend(composer_timeseries_1ep["metrics/eval/Accuracy"])

# Print mean validation accuracies
print(f"Original Validation Mean: {sum(original_acc) / len(original_acc)}")
print(f"Composer Validation Mean: {sum(composer_acc) / len(composer_acc)}")

# Plot both sets of data
plt.plot(original_epochs, original_acc, label="Original")
plt.plot(composer_epochs, composer_acc, label="Composer")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("Validation Accuracy")
plt.title("Accuracy and speed improvements with equivalent wall-clock time")
plt.show()

You did it! Now come get involved with MosaicML!

Hopefully you’re now comfortable with the basics of training with Composer. We’d love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub

Stay up-to-date and help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!