Training with TPUs
Composer provides beta support for single-core training on TPUs. We integrate with the torch_xla backend; for installation instructions and more details, see https://github.com/pytorch/xla.
In this tutorial, we train a ResNet-20 on CIFAR10 using a single TPU core. The setup is exactly the same as with any other device, except that the model must be moved to the TPU before it is passed to our Trainer. We specify device='tpu' to enable the trainer to use TPUs.
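If the same notebook should also run on machines without a TPU, one option is to choose the device string at runtime. The sketch below is a hypothetical helper (pick_composer_device is not part of Composer). It assumes the Trainer accepts the device strings 'tpu', 'gpu', and 'cpu', and it only checks which packages are installed, not whether a TPU is actually attached.

```python
import importlib.util


def pick_composer_device() -> str:
    """Return a device string for Composer's Trainer based on what is installed.

    Hypothetical helper: it checks whether packages are importable (and, for
    GPUs, whether CUDA is available). It does not verify that a TPU is attached.
    """
    # torch_xla is the backend Composer uses for TPUs.
    if importlib.util.find_spec("torch_xla") is not None:
        return "tpu"
    # Fall back to a GPU if PyTorch is installed and can see CUDA.
    if importlib.util.find_spec("torch") is not None:
        import torch

        if torch.cuda.is_available():
            return "gpu"
    # Otherwise train on the CPU.
    return "cpu"


print(pick_composer_device())
```

The returned string would be passed as the device argument of the Trainer. When it is 'tpu', the model still has to be moved to xm.xla_device() before the optimizer is created and before the model is handed to the Trainer.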
As prerequisites, first install torch_xla and the latest Composer version (published on PyPI as mosaicml).
%pip install cloud-tpu-client==0.10 torch==1.12.0 https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-1.12-cp37-cp37m-linux_x86_64.whl
%pip install mosaicml
from composer import Trainer
from composer import models
Next, we define the model and optimizer. TPUs require the model to be moved to the device before the optimizer is created, which we do here.
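import torch
import torch_xla.core.xla_model as xm

# Create the model, then move it to the TPU before building the optimizer
model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)
model = model.to(xm.xla_device())

# The optimizer now references the parameters that live on the TPU
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.02,
    momentum=0.9)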
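A quick way to catch the ordering mistake is to confirm that the optimizer still points at the model's current parameter objects. As we understand PyTorch's Module.to, moving parameters to an XLA device creates new parameter objects, so an optimizer built before the move would keep updating stale CPU copies. The helper below (optimizer_tracks_model) is hypothetical, not a Composer or torch_xla API; it relies only on the standard model.parameters() and optimizer.param_groups attributes.

```python
def optimizer_tracks_model(model, optimizer) -> bool:
    """Return True if every tensor the optimizer updates is a current model parameter.

    Hypothetical sanity check. It uses only model.parameters() and
    optimizer.param_groups (a list of dicts whose "params" entry holds
    the tensors the optimizer updates).
    """
    # Compare by object identity: a parameter that was replaced when the
    # model moved devices is a different object, even if its values match.
    model_param_ids = {id(p) for p in model.parameters()}
    return all(
        id(p) in model_param_ids
        for group in optimizer.param_groups
        for p in group["params"]
    )


# Usage, right after the optimizer is created in the cell above:
# assert optimizer_tracks_model(model, optimizer)
```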
Creating the CIFAR10 dataset and dataloaders is exactly the same as on non-TPU devices.
import torch
from torchvision import datasets, transforms

data_directory = "../data"

# Per-channel (RGB) normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

batch_size = 1024

cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])

train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
test_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation order does not affect the metrics, so the test set is not shuffled
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
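transforms.Normalize computes (x - mean[c]) / std[c] for each channel c of the tensor produced by ToTensor(), which scales pixel values to [0, 1]. The pure-Python sketch below applies that formula to a single RGB pixel using the constants above; normalize_pixel is a hypothetical illustration, not a torchvision function.

```python
def normalize_pixel(rgb, mean, std):
    """Normalize one RGB pixel with the per-channel formula used by transforms.Normalize.

    Illustrative only: torchvision applies (x - mean[c]) / std[c] to every
    pixel of channel c after ToTensor() has scaled values to [0, 1].
    """
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)

# A pixel equal to the mean maps to zero in every channel.
print(normalize_pixel(mean, mean, std))  # (0.0, 0.0, 0.0)

# A pure white pixel (all 1.0 after ToTensor) maps to roughly 1.85 to 2.03 per channel.
print(normalize_pixel((1.0, 1.0, 1.0), mean, std))
```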
Lastly, we train for 20 epochs on the TPU by simply adding device='tpu' as an argument to the Trainer.
Note: this beta release supports only single-core TPUs. Future releases will include multi-core TPU support.
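trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    device="tpu",  # run on a single TPU core via torch_xla
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    max_duration='20ep',  # train for 20 epochs
    eval_interval=1,  # evaluate on test_dataloader after every epoch
)

trainer.fit()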
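Because max_duration is given in epochs, it can help to know how many optimizer steps that corresponds to. The sketch below assumes CIFAR10's 50,000 training images and standard PyTorch DataLoader batching; steps_for_duration is a hypothetical helper, not a Composer API. With batch_size=1024, the last batch of each epoch holds only 848 images, and on XLA devices a new input shape can trigger an extra graph compilation, which is one reason some users pass drop_last=True to the training DataLoader.

```python
import math


def steps_for_duration(num_samples: int, batch_size: int, epochs: int,
                       drop_last: bool = False) -> int:
    """Number of optimizer steps for a given number of epochs.

    Mirrors PyTorch DataLoader batching: with drop_last=False the final,
    smaller batch still counts as a step; with drop_last=True it is skipped.
    """
    if drop_last:
        per_epoch = num_samples // batch_size
    else:
        per_epoch = math.ceil(num_samples / batch_size)
    return per_epoch * epochs


# CIFAR10 has 50,000 training images; this tutorial uses batch_size=1024
# and max_duration='20ep'.
print(steps_for_duration(50_000, 1024, 20))                  # 980
print(steps_for_duration(50_000, 1024, 20, drop_last=True))  # 960
```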