composer.trainer

Trainer is used to train models with Algorithm instances. The Trainer is highly customizable and can support a wide variety of workloads.

Examples

# Set up dependencies
from composer.datasets import MNISTDatasetHparams
from composer.models.mnist import MnistClassifierHparams
model = MnistClassifierHparams(num_classes=10).initialize_object()
train_dataloader_spec = MNISTDatasetHparams(is_train=True,
                                            datadir="./mymnist",
                                            download=True).initialize_object()
eval_dataloader_spec = MNISTDatasetHparams(is_train=False,
                                           datadir="./mymnist",
                                           download=True).initialize_object()
# Create a trainer that will checkpoint every epoch
# and train the model
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_interval_unit="ep",
                  checkpoint_folder="checkpoints",
                  checkpoint_interval=1)
trainer.fit()
# Load a trainer from the saved checkpoint and resume training
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_filepath="checkpoints/first_checkpoint.pt")
trainer.fit()

from composer.trainer import TrainerHparams

# Create a trainer from hparams and train the model.
# hparams is a TrainerHparams object, e.g. loaded from a YAML config
# (see the Trainer Hparams section below).
trainer = Trainer.create_from_hparams(hparams=hparams)
trainer.fit()

Trainer Hparams

The Trainer can be constructed either via its __init__ (see below) or via TrainerHparams.

Our yahp-based system allows configuring the trainer and algorithms via either a YAML file (see composer/yamls/models/classify_mnist_cpu.yaml for an example) or command-line arguments. The tables below list all the keys that can be used.

For example, the yaml for algorithms can include:

algorithms:
    - blurpool
    - layer_freezing

You can also provide overrides at the command line:

python examples/run_mosaic_trainer.py -f composer/yamls/models/classify_mnist_cpu.yaml --algorithms blurpool layer_freezing --datadir ~/datasets
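
The same workflow is available from Python. The sketch below is a hedged example: it assumes TrainerHparams exposes yahp's create() classmethod with an f argument for the YAML path, and it reuses the example YAML referenced above; adjust the path to your checkout.

from composer.trainer import Trainer, TrainerHparams

# Load the trainer configuration from a YAML file
# (assumption: yahp's create() accepts the config path via `f`)
hparams = TrainerHparams.create(f="composer/yamls/models/classify_mnist_cpu.yaml")

trainer = Trainer.create_from_hparams(hparams=hparams)
trainer.fit()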

Algorithms

name                   algorithm
alibi                  AlibiHparams
augmix                 AugMixHparams
blurpool               BlurPoolHparams
channels_last          ChannelsLastHparams
colout                 ColOutHparams
curriculum_learning    CurriculumLearningHparams
cutout                 CutOutHparams
dummy                  DummyHparams
ghost_batchnorm        GhostBatchNormHparams
label_smoothing        LabelSmoothingHparams
layer_freezing         LayerFreezingHparams
mixup                  MixUpHparams
no_op_model            NoOpModelHparams
progressive_resizing   ProgressiveResizingHparams
randaugment            RandAugmentHparams
sam                    SAMHparams
scale_schedule         ScaleScheduleHparams
selective_backprop     SelectiveBackpropHparams
squeeze_excite         SqueezeExciteHparams
stochastic_depth       StochasticDepthHparams
swa                    SWAHparams

Callbacks

name               callback
benchmarker        BenchmarkerHparams
grad_monitor       GradMonitorHparams
lr_monitor         LRMonitorHparams
pytorch_profiler   TorchProfilerHparams
speed_monitor      SpeedMonitorHparams

Datasets

name        dataset
brats       BratsDatasetHparams
cifar10     CIFAR10DatasetHparams
imagenet    ImagenetDatasetHparams
lm          LMDatasetHparams
mnist       MNISTDatasetHparams
synthetic   SyntheticDatasetHparams

Devices

name   device
cpu    CPUDeviceHparams
gpu    GPUDeviceHparams

Loggers

name    logger
file    FileLoggerBackendHparams
tqdm    TQDMLoggerBackendHparams
wandb   WandBLoggerBackendHparams

Models

name               model
efficientnetb0     EfficientNetB0Hparams
gpt2               GPT2Hparams
mnist_classifier   MnistClassifierHparams
resnet18           ResNet18Hparams
resnet56_cifar10   CIFARResNetHparams
resnet50           ResNet50Hparams
resnet101          ResNet101Hparams
unet               UnetHparams

Optimizers

name              optimizer
adamw             AdamWHparams
decoupled_adamw   DecoupledAdamWHparams
decoupled_sgdw    DecoupledSGDWHparams
radam             RAdamHparams
rmsprop           RMSPropHparams
sgd               SGDHparams

Schedulers

name                 scheduler
constant             ConstantLRHparams
cosine_decay         CosineAnnealingLRHparams
cosine_warmrestart   CosineAnnealingWarmRestartsHparams
exponential          ExponentialLRHparams
multistep            MultiStepLRHparams
step                 StepLRHparams
warmup               WarmUpLRHparams

API Reference

class composer.trainer.Trainer(*, model: composer.models.base.BaseMosaicModel, train_dataloader_spec: composer.datasets.hparams.DataloaderSpec, eval_dataloader_spec: composer.datasets.hparams.DataloaderSpec, max_epochs: int, train_batch_size: int, eval_batch_size: int, algorithms: Optional[List[composer.core.algorithm.Algorithm]] = None, optimizer_hparams: Optional[composer.optim.optimizer_hparams.OptimizerHparams] = None, schedulers_hparams: Optional[Union[composer.optim.scheduler.SchedulerHparams, List[composer.optim.scheduler.SchedulerHparams]]] = None, device: Optional[composer.trainer.devices.device.Device] = None, grad_accum: int = 1, grad_clip_norm: Optional[float] = None, validate_every_n_batches: int = - 1, validate_every_n_epochs: int = 1, compute_training_metrics: bool = False, precision: composer.core.precision.Precision = Precision.FP32, num_workers: int = 0, prefetch_factor: int = 2, persistent_workers: bool = False, pin_memory: bool = False, timeout: int = 0, ddp_store_hparams: Optional[composer.trainer.ddp.StoreHparams] = None, fork_rank_0: bool = False, seed: Optional[int] = None, deterministic_mode: bool = False, log_destinations: Optional[List[composer.core.logging.base_backend.BaseLoggerBackend]] = None, callbacks: Sequence[composer.core.callback.Callback] = (), checkpoint_filepath: Optional[str] = None, checkpoint_interval_unit: Optional[str] = None, checkpoint_folder: Optional[str] = 'checkpoints', checkpoint_interval: Optional[int] = 1, config: Optional[Dict[str, Any]] = None)[source]

Trainer for training a model with algorithms.

Can be created either with __init__ or by providing a TrainerHparams object (see create_from_hparams()).

Parameters
  • model (BaseMosaicModel) – The model to train.

  • train_dataloader_spec (DataloaderSpec) – The dataloader spec for the training data.

  • eval_dataloader_spec (DataloaderSpec) – The dataloader spec for the evaluation data.

  • max_epochs (int) – The maximum number of epochs to train for.

  • train_batch_size (int) – Minibatch size for training data.

  • eval_batch_size (int) – Minibatch size for evaluation data.

  • algorithms (List[Algorithm], optional) – The algorithms to use during training. (default: [])

  • optimizer_hparams (OptimizerHparams, optional) – The OptimizerHparams for constructing the optimizer for training. Must pass OptimizerHparams instead of a torch.optim.Optimizer object because the optimizer has to be constructed after certain algorithms, which modify the model architecture, have run on the model; see the sketch after this parameter list. (default: MosaicMLSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4))

  • schedulers_hparams (Union[SchedulerHparams, List[SchedulerHparams]], optional) – The SchedulerHparams for constructing one or more learning rate schedulers used during training. Must pass SchedulerHparams instead of a torch.optim.lr_scheduler._LRScheduler object because the scheduler needs an optimizer, and the optimizer is constructed in __init__. (default: [CosineAnnealingLRHparams(T_max=f"{max_epochs}ep"), WarmUpLRHparams()])

  • device (Device, optional) – The device to use for training. Either DeviceCPU or DeviceGPU. (default: DeviceCPU(n_cpus=1))

  • grad_accum (int, optional) – The number of microbatches to split a per-device batch into. Gradients are summed over the microbatches per device. (default: 1)

  • grad_clip_norm (float, optional) – The norm to clip gradient magnitudes to. Set to None for no gradient clipping. (default: None)

  • validate_every_n_batches (int, optional) – Compute metrics on evaluation data every N batches. Set to -1 to never validate on a batchwise frequency. (default: -1)

  • validate_every_n_epochs (int, optional) – Compute metrics on evaluation data every N epochs. Set to -1 to never validate on an epochwise frequency. (default: 1)

  • compute_training_metrics (bool, optional) – True to compute metrics on training data and False to not. (default: False)

  • precision (Precision, optional) – Numerical precision to use for training. (default: Precision.FP32).

  • num_workers (int, optional) – The number of CPU workers to use per GPU. 0 results in loading data on the main process. (default: 0)

  • prefetch_factor (int, optional) – Number of samples loaded in advance by each worker. (default: 2)

  • persistent_workers (bool, optional) – If True, do not shut down the dataloader workers after the dataset has been consumed once, so they persist across epochs. (default: False)

  • pin_memory (bool, optional) – Whether or not to copy data tensors into CUDA pinned memory. (default: False)

  • timeout (int, optional) – Timeout value for collecting a batch from workers. 0 for no timeout. (default: 0)

  • ddp_store_hparams (StoreHparams, optional) – DistributedDataParallel configuration. (default: TCPStoreHparams("127.0.0.1", 43297))

  • fork_rank_0 (bool, optional) – True to fork the rank 0 process in distributed data parallel, False to not. (default: False)

  • seed (int, optional) – The seed used in randomization. When not provided a random seed will be created. (default: None)

  • deterministic_mode (bool, optional) – Run the model deterministically. Experimental. Performance degradations expected. Certain Torch modules may not have deterministic implementations, which will result in a crash. (default: False)

  • log_destinations (List[BaseLoggerBackend], optional) – The destinations to log training information to. (default: [TQDMLoggerBackend()])

  • callbacks (Sequence[Callback], optional) – The callbacks to run during training. (default: [])

  • checkpoint_filepath (str, optional) – The path to a trainer checkpoint file. If provided, the trainer will load the state (along with its associated attributes) during initialization. (default: None)

  • checkpoint_interval_unit (str, optional) – Unit for the checkpoint save interval: "ep" for epochs, "ba" for batches, or None to disable checkpointing. (default: None)

  • checkpoint_folder (str, optional) – The folder to save checkpoints to. (default: checkpoints)

  • checkpoint_interval (int, optional) – The number of checkpoint_interval_unit units (epochs or batches) between checkpoints. (default: 1)

  • config (Dict[str, Any], optional) – Extra user-provided trainer configuration. Will be persisted along with the trainer state during checkpointing. (default: None)
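
The optimizer_hparams and schedulers_hparams parameters take hparams objects rather than constructed optimizers and schedulers. Below is a minimal, hedged sketch of passing them explicitly, reusing model and the dataloader specs from the Examples section; the import paths and the SGDHparams field names (lr, momentum, weight_decay) are assumptions inferred from the defaults listed above, and the T_max time string follows the format shown in the schedulers_hparams default.

from composer.optim.optimizer_hparams import SGDHparams
from composer.optim.scheduler import CosineAnnealingLRHparams, WarmUpLRHparams

trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  # assumption: SGDHparams mirrors torch.optim.SGD's arguments
                  optimizer_hparams=SGDHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4),
                  # one or more schedulers; T_max uses the "<N>ep" time string format
                  schedulers_hparams=[CosineAnnealingLRHparams(T_max="50ep"),
                                      WarmUpLRHparams()],
                  # split each per-device batch of 128 into two microbatches of 64
                  grad_accum=2)
trainer.fit()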

Attributes
  • state (State) – The State object used to store training state.

  • logger (Logger) – The Logger used for logging.

  • engine (Engine) – The Engine used for running callbacks and algorithms.

classmethod create_from_hparams(hparams: composer.trainer.trainer_hparams.TrainerHparams) → composer.trainer.trainer.Trainer[source]

Instantiate a Trainer using a TrainerHparams object.

Parameters

hparams (TrainerHparams) – The TrainerHparams object used to instantiate the trainer.

Returns

A Trainer object initialized with the provided TrainerHparams.

eval(is_batch: bool)[source]

Evaluate the model on the provided evaluation data and log appropriate metrics.

Parameters

is_batch (bool) – True to log metrics with LogLevel.BATCH and False to log metrics with LogLevel.EPOCH.
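
During fit(), evaluation runs according to validate_every_n_batches and validate_every_n_epochs, but eval() can also be called directly. A minimal sketch, assuming a trainer constructed as in the Examples section:

# Run a pass over the evaluation data and log metrics at the epoch level
trainer.eval(is_batch=False)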

fit()[source]

Train and evaluate the model on the provided data.