Tip

This tutorial is available as a Jupyter notebook.

⚡ Migrating from PTL

PyTorch Lightning (PTL) is a popular and very well-designed framework for training deep learning models. If you want to try our efficient algorithms and the Composer trainer, below is a quick guide to adapting your PTL models.

Setup

Let's get started! We'll first install dependencies and define the data and model.

Install Dependencies

If you haven't already, let's install Composer and PyTorch Lightning:

%pip install mosaicml pytorch-lightning

The Model

In this section, we'll migrate the ResNet-18 on CIFAR10 model from PTL to Composer, following the PTL example.

First, we add the relevant imports and create the model as in the PTL tutorial.

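import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from pytorch_lightning import LightningModule
from torch.optim.lr_scheduler import OneCycleLR

def create_model():
    # Standard ResNet-18 with a 10-class head, trained from scratch.
    model = torchvision.models.resnet18(pretrained=False, num_classes=10)
    # CIFAR10 images are only 32x32: use a 3x3 stride-1 stem conv instead of the 7x7 stride-2 one...
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    # ...and drop the initial max pool so early feature maps keep their full resolution.
    model.maxpool = nn.Identity()
    return model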
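
The two changes in create_model matter because CIFAR10 images are small. Using the standard output-size formula for convolutions and pooling, out = floor((in + 2 * padding - kernel) / stride) + 1, the dependency-free sketch below compares the spatial size after the stem of the stock torchvision ResNet-18 (a 7x7 stride-2 conv with padding 3, then a 3x3 stride-2 max pool with padding 1) with the modified stem. The helpers conv_out and stem_output are our own illustration, not library functions, and the stock stem hyperparameters are taken from torchvision's ResNet definition as we understand it.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a conv or pooling layer along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output(size: int, stages: list[tuple[int, int, int]]) -> int:
    """Apply a sequence of (kernel, stride, padding) layers to an input side length."""
    for kernel, stride, padding in stages:
        size = conv_out(size, kernel, stride, padding)
    return size


# Stock ResNet-18 stem: 7x7 stride-2 conv, then 3x3 stride-2 max pool.
stock = stem_output(32, [(7, 2, 3), (3, 2, 1)])
# Modified stem from create_model(): 3x3 stride-1 conv, max pool replaced by nn.Identity().
modified = stem_output(32, [(3, 1, 1)])

print(f"stock stem: 32 -> {stock}, modified stem: 32 -> {modified}")  # stock stem: 32 -> 8, modified stem: 32 -> 32
```

The stock stem shrinks a 32x32 image to 8x8 before the first residual block, while the modified stem keeps all 32x32 positions.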

Training data

As is standard, we set up the CIFAR10 training and test data using torchvision datasets.

import torch
import torch.utils.data
import torchvision

transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False)
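
ToTensor() scales 8-bit pixel values into [0, 1], and Normalize(mean, std) then computes (x - mean) / std for each channel, so with mean and std both 0.5 every channel ends up in [-1, 1]. The pure-Python sketch below works through that arithmetic and also counts the batches each DataLoader yields per epoch (CIFAR10 has 50,000 training and 10,000 test images; with the default drop_last=False the final partial batch is kept). The normalize helper is our own illustration of the arithmetic, not a torchvision function.

```python
import math


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel arithmetic applied by torchvision's Normalize transform."""
    return (x - mean) / std


# ToTensor() maps pixel values 0..255 to floats in [0, 1]; Normalize then maps [0, 1] to [-1, 1].
for pixel in (0, 128, 255):
    print(pixel, "->", round(normalize(pixel / 255), 4))

BATCH_SIZE = 256
train_batches = math.ceil(50_000 / BATCH_SIZE)  # the last batch holds 50_000 - 195 * 256 = 80 images
test_batches = math.ceil(10_000 / BATCH_SIZE)
print(train_batches, test_batches)  # 196 40
```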

PTL LightningModule

Following the PTL tutorial, we use the LitResnet model:

from torchmetrics.functional import accuracy

class LitResnet(LightningModule):
    def __init__(self, lr=0.05):
        super().__init__()
        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=5e-4,
        )
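        # Mirrors the PTL example, which trains on a 45,000-image split. With the full
        # 50,000-image trainset loaded above, len(train_dataloader) is 196 rather than 175.
        steps_per_epoch = 45000 // 256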
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                0.1,
                epochs=30,
                steps_per_epoch=steps_per_epoch,
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

PTLModel = LitResnet(lr=0.05)
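
In LitResnet, forward returns log-probabilities from F.log_softmax, and F.nll_loss takes the negative log-probability of each example's true class and averages over the batch; together they compute the usual cross-entropy loss. The dependency-free sketch below reproduces that arithmetic on a made-up batch of two examples with three classes (the logits are illustrative only) and computes accuracy the way evaluate does, via argmax.

```python
import math


def log_softmax(row):
    # Subtract the row maximum before exponentiating, for numerical stability.
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def nll_loss(log_probs, targets):
    # Mean negative log-probability assigned to each example's target class.
    return -sum(lp[t] for lp, t in zip(log_probs, targets)) / len(targets)


def cross_entropy(logits, targets):
    # Textbook definition: -log(softmax(logits)[target]), averaged over the batch.
    total = 0.0
    for row, t in zip(logits, targets):
        probs = [math.exp(v) / sum(math.exp(u) for u in row) for v in row]
        total -= math.log(probs[t])
    return total / len(targets)


def accuracy(log_probs, targets):
    # log_softmax preserves the argmax, so this matches argmax over the raw logits.
    preds = [max(range(len(lp)), key=lp.__getitem__) for lp in log_probs]
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # made-up outputs: 2 examples, 3 classes
targets = [0, 1]

log_probs = [log_softmax(row) for row in logits]
loss = nll_loss(log_probs, targets)
print(f"nll(log_softmax): {loss:.4f}  cross-entropy: {cross_entropy(logits, targets):.4f}")
print(f"accuracy: {accuracy(log_probs, targets)}")  # accuracy: 0.5
```

Both loss values agree, which is why the log_softmax plus nll_loss pairing used here can be swapped for a single cross-entropy on raw logits if preferred.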

LitModel to Composer

Notice that up to this point we have only used PyTorch Lightning code. Next, we adapt the PTL module to be compatible with Composer. There are a few major differences:

* The training_step is split into two parts: the forward and loss methods. This is needed because our algorithms (such as label smoothing or selective backprop) sometimes need to intercept and modify the loss.
* Optimizers and schedulers are passed directly to the Trainer during initialization, instead of being returned from configure_optimizers.
* Our forward method accepts the entire batch as input and has to unpack the batch itself.
* Evaluation is expressed through the metrics and validate methods, which take the place of validation_step and test_step.

For more information about the ComposerModel format, see our guide.

from torchmetrics.classification.accuracy import Accuracy
from composer.models.base import ComposerModel

class MosaicResnet(ComposerModel):
    def __init__(self):
        super().__init__()
        self.model = create_model()
        self.acc = Accuracy()

    def loss(self, outputs, batch, *args, **kwargs):
        """
        Compute the loss from the outputs of forward() and the batch.
        """
        _, y = batch  # unpack the labels; the inputs are not needed here
        return F.nll_loss(outputs, y)

    def metrics(self, train):
        return self.acc

    def forward(self, batch):
        x, _ = batch
        y = self.model(x)
        return F.log_softmax(y, dim=1)

    def validate(self, batch):
        _, targets = batch
        outputs = self.forward(batch)
        return outputs, targets
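
To see why splitting training_step into forward and loss helps, here is a deliberately simplified, dependency-free sketch. ToyModel, LabelSmoothingHook, and train_step are hypothetical names for illustration only; they are not Composer's API, and Composer's real trainer and algorithms are more involved. The point is that once the trainer, rather than the model, calls loss, an algorithm can modify the targets between the two calls (here, label smoothing: (1 - alpha) * one_hot + alpha / K) without touching the model code.

```python
import math
from dataclasses import dataclass


class ToyModel:
    """Mimics the split above: forward() takes the whole batch, loss() takes outputs and batch."""

    def forward(self, batch):
        x, _ = batch
        # Hypothetical "network": the same log-probabilities for every input.
        return [[math.log(0.7), math.log(0.2), math.log(0.1)] for _ in x]

    def loss(self, outputs, batch):
        _, targets = batch  # targets are probability distributions over classes
        per_example = [-sum(t * lp for t, lp in zip(tgt, out)) for tgt, out in zip(targets, outputs)]
        return sum(per_example) / len(per_example)


@dataclass
class LabelSmoothingHook:
    """Hypothetical algorithm: rewrites one-hot targets before the loss is computed."""

    alpha: float
    num_classes: int

    def before_loss(self, batch):
        x, targets = batch
        smoothed = [[(1 - self.alpha) * t + self.alpha / self.num_classes for t in tgt] for tgt in targets]
        return x, smoothed


def train_step(model, batch, hooks=()):
    outputs = model.forward(batch)
    for hook in hooks:  # algorithms get a chance to intercept between forward and loss
        batch = hook.before_loss(batch)
    return model.loss(outputs, batch)


batch = (["img0", "img1"], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
plain = train_step(ToyModel(), batch)
smoothed_loss = train_step(ToyModel(), batch, hooks=[LabelSmoothingHook(alpha=0.1, num_classes=3)])
print(f"plain: {plain:.4f}  with label smoothing: {smoothed_loss:.4f}")
```

If the loss were computed inside a single training_step, the hook would have no place to run without editing the model.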

Training

We instantiate the Composer Trainer in a similar way, specifying the model, dataloaders, optimizer, scheduler, and max_duration (the training length, here in epochs). For more details on the trainer arguments, see our Using the Trainer guide.

Now you are ready to insert your algorithms! As an example, here we add the BlurPool algorithm.

from composer import Trainer
from composer.algorithms import BlurPool

model = MosaicResnet()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.05,
    momentum=0.9,
    weight_decay=5e-4,
)

steps_per_epoch = len(train_dataloader)  # 196 batches for the full 50,000-image trainset

scheduler = OneCycleLR(
    optimizer,
    0.1,
    epochs=30,
    steps_per_epoch=steps_per_epoch,
)

trainer = Trainer(
    model=model,
    algorithms=[
        BlurPool(
            replace_convs=True,
            replace_maxpools=True,
            blur_first=True
        ),
    ],
    train_dataloader=train_dataloader,
    device="gpu" if torch.cuda.is_available() else "cpu",
    eval_dataloader=test_dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    step_schedulers_every_batch=True,  # step the scheduler every batch, like PTL's "interval": "step"
    max_duration='2ep',
    eval_interval=1,
    train_subset_num_batches=1,  # train on a single batch per epoch to keep this demo fast
)
trainer.fit()
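
The OneCycleLR settings above fix the schedule length at epochs * steps_per_epoch steps, stepped once per batch. That length only matches a full run if steps_per_epoch equals the number of batches the loader actually yields: 45000 // 256 = 175 mirrors the PTL example's 45,000-image split, while the full 50,000-image trainset loaded here gives 196 batches per epoch, and PyTorch's OneCycleLR raises an error once it is stepped past its total. The dependency-free sketch below counts the steps and traces the learning-rate curve. one_cycle_lr is our own approximation, assuming PyTorch's defaults for the arguments not set above (pct_start=0.3, div_factor=25, final_div_factor=1e4, cosine annealing, two phases); use torch.optim.lr_scheduler.OneCycleLR itself for real training.

```python
import math


def one_cycle_lr(step, total_steps, max_lr, pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Approximate two-phase, cosine-annealed one-cycle learning rate at a 0-based step."""
    initial_lr = max_lr / div_factor          # learning rate at step 0
    min_lr = initial_lr / final_div_factor    # learning rate at the last step
    warmup_end = pct_start * total_steps - 1  # step at which max_lr is reached
    last_step = total_steps - 1

    def cos_anneal(start, end, pct):
        # Cosine interpolation from `start` (pct = 0) to `end` (pct = 1).
        return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

    if step <= warmup_end:
        return cos_anneal(initial_lr, max_lr, step / warmup_end)
    return cos_anneal(max_lr, min_lr, (step - warmup_end) / (last_step - warmup_end))


EPOCHS, BATCH_SIZE, MAX_LR = 30, 256, 0.1
steps_per_epoch = math.ceil(50_000 / BATCH_SIZE)  # len(train_dataloader) with drop_last=False
total_steps = EPOCHS * steps_per_epoch

print(steps_per_epoch, 45_000 // BATCH_SIZE)  # 196 175
peak_step = round(0.3 * total_steps) - 1
for step in (0, peak_step // 2, peak_step, total_steps // 2, total_steps - 1):
    print(f"step {step:5d}: lr = {one_cycle_lr(step, total_steps, MAX_LR):.6f}")
```

Stepping per batch rather than per epoch is what both the PTL "interval": "step" entry and Composer's step_schedulers_every_batch=True request, so the curve keeps the same shape in either framework.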