Callbacks#
Callbacks provide hooks that run at each Event in the training loop.
By convention, callbacks should not modify the training loop by changing the State, but rather read it and log various metrics. Typical callback use cases include logging, timing, or model introspection.
Using Callbacks#
Built-in callbacks can be accessed in composer.callbacks and registered with the callbacks argument to the Trainer.
from composer import Trainer
from composer.callbacks import SpeedMonitor, LRMonitor
from composer.loggers import WandBLogger

Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=None,
    max_duration='1ep',
    callbacks=[SpeedMonitor(window_size=100), LRMonitor()],
    loggers=[WandBLogger()],
)
This example includes callbacks that measure the model throughput and learning rate and log them to Weights & Biases. Callbacks control what is logged, whereas loggers specify where the information is saved. For more information on loggers, see Logging.
Available Callbacks#
Composer provides several callbacks to monitor and log various components of training.
CheckpointSaver | Callback to save checkpoints.
SpeedMonitor | Logs the training throughput.
LRMonitor | Logs the learning rate.
GradMonitor | Computes and logs the L2 norm of gradients on the AFTER_TRAIN_BATCH event.
MemoryMonitor | Logs the memory usage of the model.
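Any of these can be passed to the callbacks argument shown earlier. As a minimal sketch, the example below registers GradMonitor and MemoryMonitor; their default constructor arguments are assumed here and may differ across Composer versions, so check the API reference for the exact signatures.

from composer import Trainer
from composer.callbacks import GradMonitor, MemoryMonitor

# Sketch only: default constructor arguments are assumed for both
# callbacks.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=None,
    max_duration='1ep',
    callbacks=[GradMonitor(), MemoryMonitor()],
)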
Custom Callbacks#
Custom callbacks should inherit from Callback and override any of the event-related hooks. For example, below is a simple callback that runs on EPOCH_START and prints the epoch number.
from composer import Callback, State, Logger

class EpochMonitor(Callback):

    def epoch_start(self, state: State, logger: Logger):
        print(f'Epoch: {state.timer.epoch}')
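A callback may also override several hooks at once. Below is a hypothetical sketch that prints progress at both epoch and batch boundaries; it assumes state.timer exposes a batch counter alongside the epoch counter used above.

from composer import Callback, State, Logger

class ProgressMonitor(Callback):

    def epoch_start(self, state: State, logger: Logger):
        print(f'Starting epoch {state.timer.epoch}')

    def batch_end(self, state: State, logger: Logger):
        # Assumes state.timer.batch exists, mirroring state.timer.epoch.
        print(f'Finished batch {state.timer.batch}')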
Alternatively, one can override Callback.run_event() to run code at every event. Below is an equivalent implementation of EpochMonitor:
from composer import Callback, Event, Logger, State

class EpochMonitor(Callback):

    def run_event(self, event: Event, state: State, logger: Logger):
        if event == Event.EPOCH_START:
            print(f'Epoch: {state.timer.epoch}')
Warning
If Callback.run_event() is overridden, the individual methods corresponding to each event will be ignored.
The new callback can then be provided to the trainer.
from composer import Trainer

trainer = Trainer(
    ...,
    callbacks=[EpochMonitor()],
)
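To illustrate the warning above: in the sketch below, the epoch_start hook never runs because run_event has been overridden.

from composer import Callback, Event, Logger, State

class SilentMonitor(Callback):

    def run_event(self, event: Event, state: State, logger: Logger):
        pass  # handles (and swallows) every event

    def epoch_start(self, state: State, logger: Logger):
        print('This never prints: run_event takes precedence')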
Events#
Here is the list of supported Event values that callbacks can hook into.
- class composer.core.Event(value)[source]
Enum to represent events in the training loop.
The following pseudocode shows where each event fires in the training loop:
# <INIT>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
    # <EPOCH_START>
    for inputs, targets in dataloader:
        # <AFTER_DATALOADER>
        # <BATCH_START>
        # <BEFORE_FORWARD>
        outputs = model.forward(inputs)
        # <AFTER_FORWARD>
        # <BEFORE_LOSS>
        loss = model.loss(outputs, targets)
        # <AFTER_LOSS>
        # <BEFORE_BACKWARD>
        loss.backward()
        # <AFTER_BACKWARD>
        optimizer.step()
        # <BATCH_END>
        if should_eval(batch=True):
            # <EVAL_START>
            # <EVAL_BATCH_START>
            # <EVAL_BEFORE_FORWARD>
            # <EVAL_AFTER_FORWARD>
            # <EVAL_BATCH_END>
            # <EVAL_END>
        # <BATCH_CHECKPOINT>
    # <EPOCH_END>
    if should_eval(batch=False):
        # <EVAL_START>
        # <EVAL_BATCH_START>
        # <EVAL_BEFORE_FORWARD>
        # <EVAL_AFTER_FORWARD>
        # <EVAL_BATCH_END>
        # <EVAL_END>
    # <EPOCH_CHECKPOINT>
- INIT
Invoked in the constructor of Trainer. Model surgery (see module_surgery) typically occurs here.
- FIT_START
Invoked at the beginning of each call to Trainer.fit(). Dataset transformations typically occur here.
- EPOCH_START
Start of an epoch.
- BATCH_START
Start of a batch.
- AFTER_DATALOADER
Immediately after the dataloader is called. Typically used for on-GPU dataloader transforms.
- BEFORE_TRAIN_BATCH
Before the forward-loss-backward computation for a training batch. When using gradient accumulation, this is still called only once.
- BEFORE_FORWARD
Before the call to model.forward().
- AFTER_FORWARD
After the call to model.forward().
- BEFORE_LOSS
Before the call to model.loss().
- AFTER_LOSS
After the call to model.loss().
- BEFORE_BACKWARD
Before the call to loss.backward().
- AFTER_BACKWARD
After the call to loss.backward().
- AFTER_TRAIN_BATCH
After the forward-loss-backward computation for a training batch. When using gradient accumulation, this event still fires only once.
- BATCH_END
End of a batch, which occurs after the optimizer step and any gradient scaling.
- BATCH_CHECKPOINT
After Event.BATCH_END and any batch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any batch-wise evaluation to determine whether a checkpoint should be saved.
- EPOCH_END
End of an epoch.
- EPOCH_CHECKPOINT
After Event.EPOCH_END and any epoch-wise evaluation. Saving checkpoints at this event allows the checkpoint saver to use the results from any epoch-wise evaluation to determine whether a checkpoint should be saved.
- EVAL_START
Start of evaluation through the validation dataset.
- EVAL_BATCH_START
Before the call to model.validate(batch).
- EVAL_BEFORE_FORWARD
Before the call to model.validate(batch).
- EVAL_AFTER_FORWARD
After the call to model.validate(batch).
- EVAL_BATCH_END
After the call to model.validate(batch).
- EVAL_END
End of evaluation through the validation dataset.
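For example, the evaluation events can be hooked through Callback.run_event(), following the same pattern as EpochMonitor above; a minimal sketch:

from composer import Callback, Event, Logger, State

class EvalMonitor(Callback):

    def run_event(self, event: Event, state: State, logger: Logger):
        if event == Event.EVAL_START:
            print('Starting evaluation')
        elif event == Event.EVAL_END:
            print('Finished evaluation')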