composer.trainer
Trainer is used to train models with Algorithm instances. The Trainer is highly customizable and supports a wide variety of workloads.
Examples
# Setup dependencies
from composer.datasets import MNISTDatasetHparams
from composer.models.mnist import MnistClassifierHparams
model = MnistClassifierHparams(num_classes=10).initialize_object()
train_dataloader_spec = MNISTDatasetHparams(is_train=True,
                                            datadir="./mymnist",
                                            download=True).initialize_object()
eval_dataloader_spec = MNISTDatasetHparams(is_train=False,
                                           datadir="./mymnist",
                                           download=True).initialize_object()
# Create a trainer that will checkpoint every epoch
# and train the model
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_interval_unit="ep",
                  checkpoint_folder="checkpoints",
                  checkpoint_interval=1)
trainer.fit()
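The checkpoint settings above save one checkpoint per epoch: with checkpoint_interval_unit="ep" and checkpoint_interval=1, a checkpoint fires at the end of every epoch. As a rough sketch of that interval logic (plain Python for illustration, not Composer's actual implementation):

```python
def checkpoint_epochs(max_epochs: int, interval: int) -> list:
    """Epochs (1-indexed) at which a checkpoint would be saved,
    assuming one fires at the end of every `interval` epochs."""
    return [epoch for epoch in range(1, max_epochs + 1) if epoch % interval == 0]

# With max_epochs=50 and checkpoint_interval=1, every epoch is checkpointed.
print(len(checkpoint_epochs(50, 1)))  # 50
# With interval=2, checkpoints land on even-numbered epochs.
print(checkpoint_epochs(6, 2))        # [2, 4, 6]
```

A larger interval trades recovery granularity for less checkpoint I/O.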
# Load a trainer from the saved checkpoint and resume training
trainer = Trainer(model=model,
                  train_dataloader_spec=train_dataloader_spec,
                  eval_dataloader_spec=eval_dataloader_spec,
                  max_epochs=50,
                  train_batch_size=128,
                  eval_batch_size=128,
                  checkpoint_filepath="checkpoints/first_checkpoint.pt")
trainer.fit()
from composer.trainer import TrainerHparams
# Create a trainer from hparams and train the model
# (assumes `hparams` is a TrainerHparams instance, e.g. loaded from a yaml file)
trainer = Trainer.create_from_hparams(hparams=hparams)
trainer.fit()
Trainer Hparams
Trainer can be constructed either via its __init__ (see below) or via TrainerHparams.
Our yahp-based system allows configuring the trainer and algorithms via either a yaml file or command-line arguments. Below is a table of all the keys that can be used.
For example, the yaml for algorithms
can include:
algorithms:
- blurpool
- layer_freezing
You can also provide overrides at the command line:
python examples/run_mosaic_trainer.py -f composer/yamls/models/classify_mnist_cpu.yaml --algorithms blurpool layer_freezing --datadir ~/datasets
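For reference, here is a hypothetical yaml fragment combining several of the keys shown in the Python example above with the algorithms list; the exact key names and nesting should be checked against TrainerHparams and the example yamls under composer/yamls/ before use:

```yaml
# Hypothetical trainer config sketch; verify key names against
# composer/yamls/models/classify_mnist_cpu.yaml before use.
max_epochs: 50
train_batch_size: 128
eval_batch_size: 128
algorithms:
  - blurpool
  - layer_freezing
```

Command-line overrides such as --algorithms then replace the corresponding yaml values at launch time.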
Algorithms
| name |
|---|
| alibi |
| augmix |
| blurpool |
| channels_last |
| colout |
| curriculum_learning |
| cutout |
| dummy |
| ghost_batchnorm |
| label_smoothing |
| layer_freezing |
| mixup |
| no_op_model |
| progressive_resizing |
| randaugment |
| sam |
| scale_schedule |
| selective_backprop |
| squeeze_excite |
| stochastic_depth |
| swa |
Callbacks
| name |
|---|
| benchmarker |
| grad_monitor |
| lr_monitor |
| torch_profiler |
| speed_monitor |
Datasets
| name |
|---|
| brats |
| cifar10 |
| imagenet |
| lm |
| mnist |
| synthetic |
Devices
| name |
|---|
| cpu |
| gpu |
Loggers
| name |
|---|
| file |
| tqdm |
| wandb |
Models
| name |
|---|
| efficientnetb0 |
| gpt2 |
| mnist_classifier |
| resnet18 |
| resnet56_cifar10 |
| resnet50 |
| resnet101 |
| unet |
Optimizers
| name |
|---|
| adamw |
| decoupled_adamw |
| decoupled_sgdw |
| radam |
| rmsprop |
| sgd |
Schedulers
| name |
|---|
| constant |
| cosine_decay |
| cosine_warmrestart |
| exponential |
| multistep |
| step |
| warmup |