composer.models.base#

The ComposerModel base interface.

Classes

ComposerModel

The interface needed to make a PyTorch model compatible with composer.Trainer.

class composer.models.base.ComposerModel[source]#

Bases: torch.nn.modules.module.Module, abc.ABC

The interface needed to make a PyTorch model compatible with composer.Trainer.

To create a Trainer-compatible model, subclass ComposerModel and implement forward() and loss(). For full functionality (logging and validation), implement metrics() and validate().

See the Composer Model walkthrough for more details.

Minimal Example:

import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel

class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18() # define PyTorch model in __init__.

    def forward(self, batch): # batch is the output of the dataloader
        # specify how batches are passed through the model
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        # pass batches and `forward` outputs to the loss
        _, targets = batch
        return F.cross_entropy(outputs, targets)
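
Once defined, the model can be handed to the Trainer. A minimal sketch, assuming a typical Trainer configuration; the dataloader and the max_duration value are illustrative, and argument names may differ across Composer versions:

from composer import Trainer

# train_dataloader is any PyTorch DataLoader yielding (inputs, targets) batches.
trainer = Trainer(
    model=ResNet18(),
    train_dataloader=train_dataloader,
    max_duration="1ep",  # train for one epoch
)
trainer.fit()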
logger#

The training Logger. The trainer sets the Logger on the Event.INIT event.

Type

Optional[Logger]
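
Because the logger is only set at Event.INIT, guard against None before using it inside model methods. A hedged sketch only: the logging method name varies across Composer versions, so the metric_batch call below is an assumption, not an API documented on this page:

def loss(self, outputs, batch):
    _, targets = batch
    loss = F.cross_entropy(outputs, targets)
    if self.logger is not None:  # the trainer sets this on Event.INIT
        # hypothetical call -- check the Logger API for your Composer version
        self.logger.metric_batch({"loss/train": loss.item()})
    return loss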

abstract forward(batch)[source]#

Compute model output given a batch from the dataloader.

Parameters

batch (Batch) – The output batch from the dataloader.

Returns

Tensor | Sequence[Tensor] – The result that is passed to loss() as the parameter outputs.

Warning

This method is different from vanilla PyTorch model.forward(x) or model(x) as it takes a batch of data that has to be unpacked.

Example:

def forward(self, batch): # batch is the output of the dataloader
    inputs, _ = batch
    return self.model(inputs)
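
Batches need not be tuples; how forward() unpacks them depends on what the dataloader yields. A hedged sketch, assuming a dataloader that yields dict batches (the key name is illustrative, not from this page):

def forward(self, batch):  # here the dataloader yields dict batches
    return self.model(batch["inputs"])  # "inputs" is an assumed key name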

The outputs of forward() are passed to loss() by the trainer:

for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
abstract loss(outputs, batch, *args, **kwargs)[source]#

Compute the loss of the model given outputs from forward() and a Batch of data from the dataloader. The Trainer will call .backward() on the returned loss.

Parameters
  • outputs (Any) – The output of the forward pass.

  • batch (Batch) – The output batch from the dataloader.

Returns

Tensor | Sequence[Tensor] – The loss as a torch.Tensor.

Example:

import torch.nn.functional as F

def loss(self, outputs, batch):
    # pass batches and `forward` outputs to the loss
    _, targets = batch  # discard inputs from batch
    return F.cross_entropy(outputs, targets)
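
Any differentiable function of the outputs and batch works here. A hedged variant of the same loss with label smoothing (requires PyTorch 1.10+; the smoothing value is illustrative):

def loss(self, outputs, batch):
    _, targets = batch
    # label_smoothing softens the one-hot targets; 0.1 is a common choice
    return F.cross_entropy(outputs, targets, label_smoothing=0.1)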

The outputs of forward() are passed to loss() by the trainer:

for batch in train_dataloader:
    optimizer.zero_grad()
    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)
    loss.backward()
metrics(train=False)[source]#

Get metrics for evaluating the model. Metrics should be instances of torchmetrics.Metric defined in __init__(). This format enables accurate distributed logging. Metrics consume the outputs of validate(). To track multiple metrics, wrap them in a MetricCollection and return it.

Parameters

train (bool, optional) – True to return metrics that should be computed during training and False otherwise. This flag is set automatically by the Trainer. Default: False.

Returns

Metric or MetricCollection – An instance of Metric or MetricCollection.

Warning

Each metric keeps internal state that is updated with the data seen so far. As a result, different metric instances should be used for training and validation. See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for more details.

Example:

from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy

from composer.models.loss import CrossEntropyLoss

def __init__(self):
    super().__init__()
    self.train_acc = Accuracy() # torchmetric
    self.val_acc = Accuracy()
    self.val_loss = CrossEntropyLoss()

def metrics(self, train: bool = False):
    return self.train_acc if train else MetricCollection([self.val_acc, self.val_loss])
validate(batch)[source]#

Compute model outputs on provided data. The trainer will call this method with torch.no_grad() enabled.

The output of this function will be directly used as input to all metrics returned by metrics().

Parameters

batch (Batch) – The output batch from the dataloader.

Returns

Tuple[Any, Any] – A tuple of (outputs, targets) that is passed directly to the update() methods of the metrics returned by metrics().

Example:

def validate(self, batch): # batch is the output of the dataloader
    inputs, targets = batch
    outputs = self.model(inputs)
    return outputs, targets # return a tuple of (outputs, targets)

This pseudocode illustrates how validate() outputs are passed to metrics():

metrics = model.metrics(train=False) # get torchmetrics

for batch in val_dataloader:
    outputs, targets = model.validate(batch)
    metrics.update(outputs, targets)  # update metrics with outputs and targets for each batch

metrics.compute() # compute final metrics
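
Putting the pieces together, a sketch of a model implementing all four methods, assembled from the per-method examples above (the CrossEntropyLoss import path follows the metrics example on this page):

import torch.nn.functional as F
import torchvision
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy

from composer.models import ComposerModel
from composer.models.loss import CrossEntropyLoss

class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.train_acc = Accuracy()  # separate instances for train and val
        self.val_acc = Accuracy()
        self.val_loss = CrossEntropyLoss()

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)

    def metrics(self, train: bool = False):
        return self.train_acc if train else MetricCollection([self.val_acc, self.val_loss])

    def validate(self, batch):
        inputs, targets = batch
        outputs = self.model(inputs)
        return outputs, targets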