composer.trainer

The Trainer is used to train models with Algorithm instances applied during training. It is highly customizable and supports a wide variety of workloads.

Examples

# Setup dependencies
from composer.datasets import MNISTDatasetHparams
from composer.models.mnist import MnistClassifierHparams
from composer.trainer import Trainer

model = MnistClassifierHparams(num_classes=10).initialize_object()
train_dataloader_spec = MNISTDatasetHparams(is_train=True,
                                            datadir="./mymnist",
                                            download=True).initialize_object()
eval_dataloader_spec = MNISTDatasetHparams(is_train=False,
                                           datadir="./mymnist",
                                           download=True).initialize_object()
# Create a trainer that will checkpoint every epoch
# and train the model
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_interval_unit="ep",
                  checkpoint_folder="checkpoints",
                  checkpoint_interval=1)
trainer.fit()
# Load a trainer from the saved checkpoint and resume training
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_filepath="checkpoints/first_checkpoint.pt")
trainer.fit()
from composer.trainer import TrainerHparams

# Create a trainer from hparams and train the model
# (hparams is a TrainerHparams instance, e.g. built from a yaml file)
trainer = Trainer.create_from_hparams(hparams=hparams)
trainer.fit()

Trainer Hparams

Trainer can be constructed via either its __init__ (see below) or TrainerHparams.

Our yahp-based system allows the trainer and algorithms to be configured via either a yaml file (see here for an example) or command-line arguments. Below is a table of all the keys that can be used.

For example, the yaml for algorithms can include:

algorithms:
    - blurpool
    - layer_freezing

You can also provide overrides at the command line:

python examples/run_mosaic_trainer.py -f composer/yamls/models/classify_mnist_cpu.yaml --algorithms blurpool layer_freezing --datadir ~/datasets
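The registry names from the tables below compose into a single trainer yaml. The fragment below is an illustrative sketch only; the exact top-level key names follow the TrainerHparams schema, so consult a shipped yaml such as composer/yamls/models/classify_mnist_cpu.yaml for the authoritative layout.

```yaml
# Illustrative sketch of a trainer yaml; verify key names against
# a shipped yaml before use.
max_epochs: 50
train_batch_size: 128
eval_batch_size: 128
algorithms:
    - blurpool
    - layer_freezing
```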

Algorithms

name                    algorithm
alibi                   AlibiHparams
augmix                  AugMixHparams
blurpool                BlurPoolHparams
channels_last           ChannelsLastHparams
colout                  ColOutHparams
curriculum_learning     CurriculumLearningHparams
cutout                  CutOutHparams
dummy                   DummyHparams
ghost_batchnorm         GhostBatchNormHparams
label_smoothing         LabelSmoothingHparams
layer_freezing          LayerFreezingHparams
mixup                   MixUpHparams
no_op_model             NoOpModelHparams
progressive_resizing    ProgressiveResizingHparams
randaugment             RandAugmentHparams
sam                     SAMHparams
scale_schedule          ScaleScheduleHparams
selective_backprop      SelectiveBackpropHparams
squeeze_excite          SqueezeExciteHparams
stochastic_depth        StochasticDepthHparams
swa                     SWAHparams
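Each name above maps to an Hparams class whose fields can also be set from yaml. The nesting below is an illustrative sketch (the field name alpha is an assumption, not taken from this page); check the corresponding Hparams class for the actual fields.

```yaml
# Illustrative: an algorithm's fields are nested under its registry
# name ("alpha" is an assumed field name here).
algorithms:
    - blurpool
    - label_smoothing:
        alpha: 0.1
```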

Callbacks

name            callback
benchmarker     BenchmarkerHparams
grad_monitor    GradMonitorHparams
lr_monitor      LRMonitorHparams
torch_profiler  TorchProfilerHparams
speed_monitor   SpeedMonitorHparams

Datasets

name       dataset
brats      BratsDatasetHparams
cifar10    CIFAR10DatasetHparams
imagenet   ImagenetDatasetHparams
lm         LMDatasetHparams
mnist      MNISTDatasetHparams
synthetic  SyntheticDatasetHparams

Devices

name  device
cpu   CPUDeviceHparams
gpu   GPUDeviceHparams

Loggers

name   logger
file   FileLoggerBackendHparams
tqdm   TQDMLoggerBackendHparams
wandb  WandBLoggerBackendHparams

Models

name              model
efficientnetb0    EfficientNetB0Hparams
gpt2              GPT2Hparams
mnist_classifier  MnistClassifierHparams
resnet18          ResNet18Hparams
resnet56_cifar10  CIFARResNetHparams
resnet50          ResNet50Hparams
resnet101         ResNet101Hparams
unet              UnetHparams

Optimizers

name             optimizer
adamw            AdamWHparams
decoupled_adamw  DecoupledAdamWHparams
decoupled_sgdw   DecoupledSGDWHparams
radam            RAdamHparams
rmsprop          RMSPropHparams
sgd              SGDHparams

Schedulers

name                scheduler
constant            ConstantLRHparams
cosine_decay        CosineAnnealingLRHparams
cosine_warmrestart  CosineAnnealingWarmRestartsHparams
exponential         ExponentialLRHparams
multistep           MultiStepLRHparams
step                StepLRHparams
warmup              WarmUpLRHparams
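As a sketch of how the optimizer and scheduler registries above compose in yaml: the fragment below assumes top-level optimizer and schedulers keys, and every field name and value shown (lr, momentum, warmup and decay settings) is illustrative rather than taken from this page, so check the Hparams classes for the actual fields.

```yaml
# Illustrative only; field names are assumptions, not the documented schema.
optimizer:
    decoupled_sgdw:
        lr: 0.1
        momentum: 0.9
schedulers:
    - warmup
    - cosine_decay
```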

API Reference