Tip: This tutorial is available as a Jupyter notebook.

🔌 Training with TPUs

Composer provides beta support for single-core training on TPUs. We integrate with the torch_xla backend; for installation instructions and more details, see https://github.com/pytorch/xla.

In this tutorial, we train a ResNet-20 on CIFAR10 using a single TPU core. The setup is exactly the same as with any other device, except that the model must be moved to the device before it is passed to our Trainer. We then specify device='tpu' so the trainer uses the TPU.

As a prerequisite, first install torch_xla and the latest version of Composer.

[ ]:
%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
%pip install mosaicml

from composer import Trainer
from composer import models

Next, we define the model and optimizer. TPUs require the model to be moved to the device before the optimizer is created, which we do here.

[ ]:
import torch
import torch_xla.core.xla_model as xm

model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)

# Move the model to the XLA (TPU) device *before* constructing the optimizer,
# so the optimizer holds references to the on-device parameters.
model = model.to(xm.xla_device())

optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,
    momentum=0.9)
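
Optionally, as a quick sanity check (not part of the original tutorial), you can confirm that the parameters now live on the XLA device before building the Trainer; the reported device should be something like xla:0 or xla:1 rather than cpu. This reuses the xm module and model defined above.

[ ]:
# Optional sanity check: parameters should report an XLA device (e.g. "xla:0"),
# not "cpu", confirming the model was moved before the optimizer was created.
print(xm.xla_device())
print(next(model.parameters()).device)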

Creating the CIFAR10 dataset and dataloaders is exactly the same as on non-TPU devices.

[ ]:
from torchvision import datasets, transforms

data_directory = "../data"

# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

Lastly, we train for 20 epochs on the TPU by simply adding device='tpu' as an argument to the Trainer.

Note: we currently support only single-core TPUs in this beta release. Future releases will include multi-core TPU support.

[ ]:
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    device="tpu",
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    max_duration='20ep',
    eval_interval=1,
)

trainer.fit()
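
After training, if you want to keep the weights, one option (a minimal sketch, not covered by this tutorial; Composer's own checkpointing also works) is torch_xla's xm.save, which copies XLA tensors to the CPU before serializing so the file can be loaded on any device. The file name below is only an example.

[ ]:
# Optional: save the trained weights. xm.save() moves XLA tensors to CPU before
# writing, so the checkpoint loads like a regular PyTorch state dict.
# "resnet20_cifar10.pt" is an arbitrary example path.
xm.save(model.state_dict(), "resnet20_cifar10.pt")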