composer.models

The models module contains the ComposerModel base class along with reference implementations of many common models. It also includes task-specific convenience ComposerModels that wrap existing PyTorch models with standard forward passes and logging, so they can be used with the Trainer right away.

See Composer Model for more details.

Functions

composer_deeplabv3

Helper function to create a ComposerClassifier with a DeepLabv3(+) model. Logs

composer_efficientnetb0

Helper function to create a ComposerClassifier with an EfficientNet-b0 architecture.

composer_resnet

Helper function to create a ComposerClassifier with a torchvision ResNet model.

composer_resnet_cifar

Helper function to create a ComposerClassifier with a CIFAR ResNet model.

composer_timm

A wrapper around timm.create_model() used to create ComposerClassifier.

create_bert_classification

BERT classification model based on 🤗 Transformers.

create_bert_mlm

BERT model based on 🤗 Transformers.

create_gpt2

Implements HuggingFaceModel to wrap Hugging Face GPT-2 transformers. Logs training and

mnist_model

Helper function to create a ComposerClassifier with a simple convolutional neural network.

vit_small_patch16

Helper function to create a ComposerClassifier using a ViT-S/16 model.

Classes

ComposerClassifier

A convenience class that creates a ComposerModel for classification tasks from a vanilla PyTorch model.

HuggingFaceModel

A wrapper class that converts 🤗 Transformers models to Composer models.

Initializer

Sets the initialization scheme for different layers of a PyTorch model.

SSD

Single-shot object detection model with a pretrained ResNet34 backbone, extending ComposerModel.

UNet

A U-Net model extending ComposerModel.
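
To make the wrapping idea behind ComposerClassifier more concrete, the following is a minimal, dependency-free sketch of what a classification convenience wrapper does: it unpacks an (inputs, targets) batch, calls the wrapped model in forward, and computes a loss from the outputs and the batch. The names SketchClassifier, toy_model, and cross_entropy are hypothetical stand-ins written for this illustration and are not part of Composer; the real ComposerClassifier wraps a torch.nn.Module and works on tensors, and its constructor arguments may differ between versions.

```python
import math
from typing import Callable, List, Sequence, Tuple

# A batch is a pair: (list of input vectors, list of integer class targets).
Batch = Tuple[List[List[float]], List[int]]


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of one example, using a numerically stable log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


class SketchClassifier:
    """Hypothetical stand-in for a classification wrapper (not Composer's class)."""

    def __init__(self, module: Callable[[List[float]], List[float]]):
        # The wrapped "vanilla" model: any callable mapping an input to logits.
        self.module = module

    def forward(self, batch: Batch) -> List[List[float]]:
        # Standard forward pass: unpack the batch and run the wrapped model.
        inputs, _ = batch
        return [self.module(x) for x in inputs]

    def loss(self, outputs: List[List[float]], batch: Batch) -> float:
        # Mean cross-entropy between the outputs and the batch targets.
        _, targets = batch
        losses = [cross_entropy(o, t) for o, t in zip(outputs, targets)]
        return sum(losses) / len(losses)


def toy_model(x: List[float]) -> List[float]:
    # Toy two-class "model": logits are the input sum and its negation.
    return [sum(x), -sum(x)]


batch = ([[1.0, 2.0], [-0.5, -0.5]], [0, 1])
model = SketchClassifier(toy_model)
outputs = model.forward(batch)
loss = model.loss(outputs, batch)
```

The split into forward(batch) and loss(outputs, batch) mirrors the two methods a ComposerModel is expected to provide, which is what lets a training loop treat every wrapped model the same way.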

Hparams

These classes are used with yahp for YAML-based configuration.

BERTForClassificationHparams

YAHP interface for BERTModel.

BERTHparams

YAHP interface for BERTModel.

DeepLabV3Hparams

YAHP interface for

EfficientNetB0Hparams

YAHP interface for

GPT2Hparams

YAHP interface for GPT2Model.

MnistClassifierHparams

YAHP interface for mnist_model().

ModelHparams

General YAHP interface for ComposerModels.

ResNetCIFARHparams

YAHP interface for composer_resnet_cifar().

ResNetHparams

YAHP interface for composer_resnet().

SSDHparams

YAHP interface for SSD.

TimmHparams

YAHP interface for composer_timm().

UnetHparams

YAHP interface for UNet.

ViTSmallPatch16Hparams

YAHP interface for vit_small_patch16().