Early Stopping#
In Composer, Callbacks modify trainer behavior and are called at the relevant Events in the training loop. This tutorial focuses on two callbacks, EarlyStopper and ThresholdStopper, each of which halts training early based on a different criterion.
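To make the callback-and-event idea concrete, here is a minimal pure-Python sketch of a training loop that fires an event and lets a callback request a stop. Every name in it (ToyEvent, ToyState, ToyCallback, StopAfterEpochs, toy_fit) is hypothetical and exists only for illustration; it is not Composer's implementation, just one way such a dispatch could be arranged.

```python
from enum import Enum, auto
from typing import List


class ToyEvent(Enum):
    """Hypothetical events; a real training loop exposes many more."""
    EPOCH_END = auto()


class ToyState:
    """Hypothetical stand-in for the shared training state."""

    def __init__(self) -> None:
        self.epoch = 0
        self.stop_requested = False


class ToyCallback:
    def run_event(self, event: ToyEvent, state: ToyState) -> None:
        # Dispatch to a method named after the event, if the subclass defines one.
        handler = getattr(self, event.name.lower(), None)
        if handler is not None:
            handler(state)


class StopAfterEpochs(ToyCallback):
    """Requests a stop once a fixed number of epochs has completed."""

    def __init__(self, limit: int) -> None:
        self.limit = limit

    def epoch_end(self, state: ToyState) -> None:
        if state.epoch >= self.limit:
            state.stop_requested = True


def toy_fit(callbacks: List[ToyCallback], max_epochs: int) -> ToyState:
    state = ToyState()
    while state.epoch < max_epochs and not state.stop_requested:
        state.epoch += 1  # stands in for one epoch of real training
        for callback in callbacks:
            callback.run_event(ToyEvent.EPOCH_END, state)
    return state
```

Calling toy_fit([StopAfterEpochs(3)], max_epochs=100) returns a state whose loop ended long before the 100-epoch budget, which is the same pattern the two stoppers in this tutorial rely on.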
Setup#
In this tutorial, we'll train a ComposerModel and halt training based on criteria that we set. We'll use the same model as in the Getting Started tutorial.
Install Composer#
First, install Composer if you haven't already:
[ ]:
%pip install mosaicml
Seed#
Next, we'll set the seed for reproducibility:
[ ]:
from composer.utils.reproducibility import seed_all
seed_all(42)
Dataloader Setup#
[ ]:
import torch.utils.data
from torchvision import datasets, transforms
data_directory = "./data"
# Normalization constants
mean = (0.507, 0.487, 0.441)
std = (0.267, 0.256, 0.276)
batch_size = 1024
cifar10_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
train_dataset = datasets.CIFAR10(data_directory, train=True, download=True, transform=cifar10_transforms)
eval_dataset = datasets.CIFAR10(data_directory, train=False, download=True, transform=cifar10_transforms)
# Setting shuffle=False to allow for easy overfitting in this example
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
Model, Optimizer, Scheduler, and Evaluator Setup#
[ ]:
from composer import models
from composer.optim import DecoupledSGDW, LinearWithWarmupScheduler
from composer.core import Evaluator
from torchmetrics.classification.accuracy import Accuracy
model = models.ComposerResNetCIFAR(model_name='resnet_56', num_classes=10)
optimizer = DecoupledSGDW(
    model.parameters(),  # Model parameters to update
    lr=0.05,  # Peak learning rate
    momentum=0.9,
    weight_decay=2.0e-3,  # If this looks large, it's because it's not scaled by the LR as in non-decoupled weight decay
)
lr_scheduler = LinearWithWarmupScheduler(
    t_warmup="1ep",  # Warm up over 1 epoch
    alpha_i=1.0,  # Flat LR schedule achieved by having alpha_i == alpha_f
    alpha_f=1.0,
)
evaluator = Evaluator(
    dataloader=eval_dataloader,
    label="eval",
    metrics=Accuracy(),
)
EarlyStopper#
The EarlyStopper callback tracks a particular training or evaluation metric and stops training if the metric does not improve within a given time interval.
The callback takes the following parameters:
- monitor: The name of the metric to track.
- dataloader_label: Identifies which dataloader the metric belongs to. By default, the train dataloader is labeled "train", and the evaluation dataloader is labeled "eval". (These names can be customized via the train_dataloader_label argument of the Trainer or the label argument of the Evaluator, respectively.)
- patience: How long the callback waits for the metric to improve before stopping training. You can use an integer to specify a number of epochs or provide a Time string, e.g., "50ba" or "2ep" for 50 batches and 2 epochs, respectively.
- min_delta: If non-zero, the change in the tracked metric over the patience window must be at least this large.
- comp: A comparison operator used to measure the change in the monitored metric. It is called as comp(current_value, previous_best).
See the API Reference for more information.
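The interplay of patience, min_delta, and comp can be sketched in plain Python. The PatienceTracker class below is hypothetical and is not Composer's implementation. It assumes patience is counted in epochs, that min_delta is a signed offset added to the previous best (positive when higher values are better), and that training stops once the number of epochs without improvement exceeds patience. The exact boundary condition in the real callback may differ.

```python
import operator
from typing import Callable, Optional


class PatienceTracker:
    """Hypothetical bookkeeping for patience-based early stopping."""

    def __init__(
        self,
        patience: int,
        min_delta: float = 0.0,
        comp: Callable[[float, float], bool] = operator.gt,
    ) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.comp = comp
        self.best: Optional[float] = None
        self.epochs_since_best = 0

    def should_stop(self, value: float) -> bool:
        """Record one epoch's metric value and report whether to stop."""
        if self.best is None:
            improved = True  # the first observation always becomes the best
        else:
            # Called like comp(current_value, previous_best), with the
            # assumed signed min_delta margin applied to the previous best.
            improved = self.comp(value, self.best + self.min_delta)

        if improved:
            self.best = value
            self.epochs_since_best = 0
        else:
            self.epochs_since_best += 1

        # Assumption: stop only once the wait strictly exceeds patience.
        return self.epochs_since_best > self.patience


# Usage: accuracy peaks in the second epoch and then declines.
tracker = PatienceTracker(patience=1)
decisions = [tracker.should_stop(acc) for acc in [0.30, 0.35, 0.34, 0.33]]
```

For a metric where lower is better, such as a loss, the same sketch works with comp=operator.lt and a negative min_delta.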
Here, we'll use it to track the Accuracy metric on the test dataset with a patience of one epoch:
[ ]:
from composer.callbacks import EarlyStopper
early_stopper = EarlyStopper(monitor="Accuracy", dataloader_label="eval", patience=1)
Now that we have our callback, we can instantiate the Trainer and train:
[ ]:
from composer.trainer import Trainer
# Early stopping should stop training before we reach 100 epochs!
train_epochs = "100ep"
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    callbacks=[early_stopper],
    train_subset_num_batches=10,  # only training on a subset of the data to trigger the callback sooner
)
# Train!
trainer.fit()
ThresholdStopper#
The ThresholdStopper callback is similar to the EarlyStopper, but it halts training when the metric crosses a threshold set in the ThresholdStopper callback.
This callback takes the following parameters:
- monitor, dataloader_label, and comp: Same as the EarlyStopper callback.
- threshold: The float threshold that dictates when to halt training.
- stop_on_batch: If True, training can halt partway through an epoch, after a batch, if the training metrics satisfy the threshold.
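The threshold check itself is simple enough to sketch in plain Python. The helper below, first_stopping_epoch, is a hypothetical name used only for illustration and is not part of Composer. It assumes the metric is evaluated once per epoch and that training halts at the first epoch where comp(value, threshold) is true.

```python
import operator
from typing import Callable, Optional, Sequence


def first_stopping_epoch(
    history: Sequence[float],
    threshold: float,
    comp: Callable[[float, float], bool] = operator.gt,
) -> Optional[int]:
    """Return the 1-based epoch at which the metric crosses the threshold.

    Returns None if the threshold is never crossed, in which case training
    would simply run for its full duration.
    """
    for epoch, value in enumerate(history, start=1):
        if comp(value, threshold):
            return epoch
    return None


# Usage: accuracy first exceeds 0.3 in the third epoch.
accuracy_history = [0.12, 0.25, 0.31, 0.40]
stop_epoch = first_stopping_epoch(accuracy_history, threshold=0.3)
```

Unlike the patience-based check, this one keeps no memory of earlier values: each measurement is compared against the fixed threshold on its own.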
We will reuse the same setup for the ThresholdStopper example.
[ ]:
from composer.callbacks import ThresholdStopper
threshold_stopper = ThresholdStopper("Accuracy", "eval", threshold=0.3)
# Threshold stopping should stop training before we reach 100 epochs!
train_epochs = "100ep"
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=evaluator,
    max_duration=train_epochs,
    optimizers=optimizer,
    schedulers=lr_scheduler,
    callbacks=[threshold_stopper],
    train_subset_num_batches=10,  # only training on a subset of the data to trigger the callback sooner
)
# Train!
trainer.fit()