# Source code for composer.models.classify_mnist.model

# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""A simple convolutional neural network extending :class:`.ComposerClassifier`."""

from typing import List, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from composer.models.initializers import Initializer
from composer.models.tasks import ComposerClassifier

__all__ = ['Model', 'mnist_model']


class Model(nn.Module):
    """Toy convolutional neural network architecture in pytorch for MNIST.

    The network is two unpadded 3x3 convolutions (1 -> 16 -> 32 channels) with
    batch normalization after the second one, adaptive average pooling down to
    a 4x4 grid, and a two-layer fully connected head.

    Args:
        initializers (Sequence[str | Initializer]): Initializers to apply to
            the model. Each entry is converted with ``Initializer(...)`` and
            the callable returned by ``get_initializer()`` is passed to
            :meth:`torch.nn.Module.apply`. An empty sequence leaves the
            PyTorch default initialization in place.
        num_classes (int, optional): Size of the final linear layer's output.
            Default: ``10``
    """

    def __init__(self, initializers: Sequence[Union[str, Initializer]], num_classes: int = 10):
        super().__init__()

        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
        self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
        self.bn = nn.BatchNorm2d(32)
        # 32 channels x 16 pooled positions (the 4x4 adaptive pool in forward()).
        self.fc1 = nn.Linear(32 * 16, 32)
        self.fc2 = nn.Linear(32, num_classes)

        # Initializers must run only after every layer has been registered:
        # Module.apply() visits the submodules that exist at call time, so
        # applying them before the layers are created would silently leave
        # all weights at their default initialization.
        for initializer in initializers:
            initializer = Initializer(initializer)
            self.apply(initializer.get_initializer())

    def forward(self, x):
        """Compute class scores for a batch of single-channel images.

        Args:
            x (torch.Tensor): Input batch with one channel, as required by
                ``conv1``. The adaptive pooling step fixes the feature size
                fed to ``fc1``, so the spatial size is not hard-coded here.

        Returns:
            torch.Tensor: Output of the final linear layer, with
            ``num_classes`` values per sample (no softmax is applied).
        """
        out = self.conv1(x)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn(out)
        out = F.relu(out)
        # Pool to a fixed 4x4 grid so the flattened size is always 32 * 16.
        out = F.adaptive_avg_pool2d(out, (4, 4))
        out = torch.flatten(out, 1, -1)
        out = self.fc1(out)
        out = F.relu(out)
        return self.fc2(out)


def mnist_model(num_classes: int = 10, initializers: Optional[List[Initializer]] = None):
    """Helper function to create a :class:`.ComposerClassifier` with a simple convolutional neural network.

    Args:
        num_classes (int, optional): The number of classes. Needed for classification tasks. Default: ``10``
        initializers (List[Initializer], optional): list of Initializers
            for the model. ``None`` for no initialization. Default: ``None``

    Returns:
        ComposerModel: instance of :class:`.ComposerClassifier` with a simple MNIST model.

    Example:

    .. testcode::

        from composer.models import mnist_model

        model = mnist_model()
    """

    # ``None`` means "no initializers"; a fresh list is created on each call
    # instead of using a mutable default argument.
    if initializers is None:
        initializers = []

    model = Model(initializers, num_classes)
    composer_model = ComposerClassifier(module=model)
    return composer_model