composer.core.callback

Base module for callbacks.

Classes

Callback

Base class for callbacks.

class composer.core.callback.Callback

Bases: composer.core.serializable.Serializable, abc.ABC

Base class for callbacks.

Callbacks provide hooks that can run at each training loop Event. A callback is similar to an Algorithm in that both are run on specific events. Callbacks differ from an Algorithm in that they do not modify the training of the model. By convention, callbacks should not modify the State. They are typically used for non-essential recording functions such as logging or timing.
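
As a rough illustration of a "recording" callback, the sketch below times each epoch and only updates its own attributes, never the state. It is framework-free on purpose: EpochTimer is a hypothetical name, it does not subclass the real Callback so that it runs without Composer installed, and the loop at the bottom is a toy stand-in for the trainer calling the hooks.

```python
import time


class EpochTimer:
    """Hypothetical recording callback: times each epoch without writing to the state.

    In Composer this class would subclass Callback; it is a plain class here so
    the sketch runs without the library installed.
    """

    def __init__(self):
        self.durations = []  # seconds per completed epoch
        self._start = None

    def epoch_start(self, state, logger):
        self._start = time.perf_counter()

    def epoch_end(self, state, logger):
        # Only the callback's own attributes change; the state is never written.
        self.durations.append(time.perf_counter() - self._start)


# Toy stand-in for the trainer: invoke the hooks around two "epochs".
timer = EpochTimer()
for _ in range(2):
    timer.epoch_start(state=None, logger=None)
    sum(range(10_000))  # placeholder for the epoch's work
    timer.epoch_end(state=None, logger=None)

print(len(timer.durations))  # 2
```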

Callbacks can be implemented in two ways:

  1. Override the individual methods named for each Event.

    For example,

    >>> class MyCallback(Callback):
    ...     def epoch_start(self, state: State, logger: Logger):
    ...         print(f'Epoch: {int(state.timer.epoch)}')
    >>> # construct trainer object with your callback
    >>> trainer = Trainer(
    ...     model=model,
    ...     train_dataloader=train_dataloader,
    ...     eval_dataloader=eval_dataloader,
    ...     optimizers=optimizer,
    ...     max_duration="1ep",
    ...     callbacks=[MyCallback()],
    ... )
    >>> # the trainer will run MyCallback whenever the EPOCH_START
    >>> # event is triggered, like this:
    >>> _ = trainer.engine.run_event(Event.EPOCH_START)
    Epoch: 0
    
  2. Override run_event() if you want a single method to handle all events. If this method is overridden, the individual methods corresponding to each event name (such as epoch_start()) will no longer be invoked automatically. For example, if you override run_event(), then epoch_start() will not be called on the EPOCH_START event, batch_start() will not be called on the BATCH_START event, and so on. However, you can still invoke epoch_start(), batch_start(), etc. from your overriding implementation of run_event().

    For example,

    >>> class MyCallback(Callback):
    ...     def run_event(self, event: Event, state: State, logger: Logger):
    ...         if event == Event.EPOCH_START:
    ...             print(f'Epoch: {int(state.timer.epoch)}')
    >>> # construct trainer object with your callback
    >>> trainer = Trainer(
    ...     model=model,
    ...     train_dataloader=train_dataloader,
    ...     eval_dataloader=eval_dataloader,
    ...     optimizers=optimizer,
    ...     max_duration="1ep",
    ...     callbacks=[MyCallback()],
    ... )
    >>> # the trainer will run MyCallback whenever the EPOCH_START
    >>> # event is triggered, like this:
    >>> _ = trainer.engine.run_event(Event.EPOCH_START)
    Epoch: 0
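
The difference between the two approaches can be sketched without Composer. The Event enum and Callback base class below are hypothetical stand-ins, not the library's classes; the base run_event() assumes each event's value equals the name of its hook method (for example, "epoch_start"), which is an assumption of this sketch. It shows that the default run_event() forwards to the named hook, while an override receives every event and only reaches a named hook when it calls that hook explicitly.

```python
from enum import Enum


class Event(Enum):
    # Hypothetical stand-in; assumes each event's value is its hook's name.
    EPOCH_START = "epoch_start"
    BATCH_START = "batch_start"


class Callback:
    # Hypothetical minimal base class: by default, run_event() forwards
    # each event to the method named after it.
    def run_event(self, event, state, logger):
        getattr(self, event.value)(state, logger)

    def epoch_start(self, state, logger):
        pass

    def batch_start(self, state, logger):
        pass


class Recorder(Callback):
    def __init__(self):
        self.calls = []

    def epoch_start(self, state, logger):
        self.calls.append("epoch_start")

    def batch_start(self, state, logger):
        self.calls.append("batch_start")


class SelectiveRecorder(Recorder):
    # Overriding run_event() turns off the automatic dispatch...
    def run_event(self, event, state, logger):
        self.calls.append(f"run_event:{event.value}")
        # ...but the named hooks can still be called explicitly.
        if event == Event.EPOCH_START:
            self.epoch_start(state, logger)


default, selective = Recorder(), SelectiveRecorder()
for cb in (default, selective):
    for event in (Event.EPOCH_START, Event.BATCH_START):
        cb.run_event(event, state=None, logger=None)

print(default.calls)    # ['epoch_start', 'batch_start']
print(selective.calls)  # ['run_event:epoch_start', 'epoch_start', 'run_event:batch_start']
```

Note that batch_start() on SelectiveRecorder is never reached, because its run_event() override does not forward BATCH_START.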
    
after_backward(state, logger)

Called on the AFTER_BACKWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

after_dataloader(state, logger)

Called on the AFTER_DATALOADER event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

after_forward(state, logger)

Called on the AFTER_FORWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

after_loss(state, logger)

Called on the AFTER_LOSS event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

after_train_batch(state, logger)

Called on the AFTER_TRAIN_BATCH event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

batch_checkpoint(state, logger)

Called on the BATCH_CHECKPOINT event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

batch_end(state, logger)

Called on the BATCH_END event.

Note

The following Timer member variables are incremented immediately before the BATCH_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

batch_start(state, logger)

Called on the BATCH_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

before_backward(state, logger)

Called on the BEFORE_BACKWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

before_forward(state, logger)

Called on the BEFORE_FORWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

before_loss(state, logger)

Called on the BEFORE_LOSS event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

before_train_batch(state, logger)

Called on the BEFORE_TRAIN_BATCH event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

close(state, logger)

Called whenever the trainer finishes training, even when there is an exception.

It should be used for cleanup tasks such as flushing I/O streams and/or closing any files that may have been opened during the INIT event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.
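
A minimal sketch of the INIT/close() pairing described above, assuming a hypothetical FileLogger callback (a plain class, not part of Composer) that opens a file in init() and releases it in close(). The try/finally block is a toy stand-in for the trainer's guarantee that close() runs even when training raises; close() also guards against a file that was never opened, since a failure could occur before INIT completes.

```python
import os
import tempfile


class FileLogger:
    """Hypothetical callback that opens a file in init() and closes it in close()."""

    def __init__(self, path):
        self.path = path
        self._file = None

    def init(self, state, logger):
        self._file = open(self.path, "w")

    def epoch_end(self, state, logger):
        self._file.write("epoch finished\n")

    def close(self, state, logger):
        # Safe to call even if init() never ran or close() was already called.
        if self._file is not None:
            self._file.flush()
            self._file.close()
            self._file = None


path = os.path.join(tempfile.mkdtemp(), "log.txt")
cb = FileLogger(path)
cb.init(state=None, logger=None)
try:
    cb.epoch_end(state=None, logger=None)
    raise RuntimeError("simulated failure mid-training")
except RuntimeError:
    pass  # the sketch swallows the error so it can inspect the file afterwards
finally:
    cb.close(state=None, logger=None)  # runs whether or not training failed

with open(path) as f:
    contents = f.read()
print(repr(contents))  # 'epoch finished\n'
```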

epoch_checkpoint(state, logger)

Called on the EPOCH_CHECKPOINT event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

epoch_end(state, logger)

Called on the EPOCH_END event.

Note

The Timer member variable Timer.epoch is incremented immediately before the EPOCH_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.
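
Because Timer.epoch is incremented immediately before EPOCH_END, a hook reading the counter in epoch_end() sees a value one higher than epoch_start() saw for the same epoch. The toy loop below sketches only that ordering: toy_fit and EpochReporter are hypothetical, and a plain integer inside a SimpleNamespace stands in for the real state.timer.epoch.

```python
from types import SimpleNamespace


class EpochReporter:
    """Hypothetical callback that records the epoch counter seen by each hook."""

    def __init__(self):
        self.seen = []

    def epoch_start(self, state, logger):
        self.seen.append(("epoch_start", state.timer.epoch))

    def epoch_end(self, state, logger):
        self.seen.append(("epoch_end", state.timer.epoch))


def toy_fit(num_epochs, callback):
    # Hypothetical loop; only the increment-before-EPOCH_END ordering is modeled.
    state = SimpleNamespace(timer=SimpleNamespace(epoch=0))
    for _ in range(num_epochs):
        callback.epoch_start(state, None)
        # ... the epoch's batches would run here ...
        state.timer.epoch += 1  # incremented immediately before EPOCH_END
        callback.epoch_end(state, None)


reporter = EpochReporter()
toy_fit(2, reporter)
print(reporter.seen)
# [('epoch_start', 0), ('epoch_end', 1), ('epoch_start', 1), ('epoch_end', 2)]
```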

epoch_start(state, logger)

Called on the EPOCH_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_after_forward(state, logger)

Called on the EVAL_AFTER_FORWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_batch_end(state, logger)

Called on the EVAL_BATCH_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_batch_start(state, logger)

Called on the EVAL_BATCH_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_before_forward(state, logger)

Called on the EVAL_BEFORE_FORWARD event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_end(state, logger)

Called on the EVAL_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

eval_start(state, logger)

Called on the EVAL_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

fit_end(state, logger)

Called on the FIT_END event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

fit_start(state, logger)

Called on the FIT_START event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

init(state, logger)

Called on the INIT event.

Parameters
  • state (State) – The global state.

  • logger (Logger) – The logger.

post_close()

This hook is called after close() has been invoked for each callback. Very few callbacks should need to implement post_close().

This hook can be used to back up any data that may have been written by other callbacks during close().
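
The guarantee that every close() finishes before any post_close() starts is what makes the backup use case work, regardless of the order in which callbacks are registered. The sketch below models that with a hypothetical shutdown() helper and two hypothetical callbacks (ReportWriter and Backup) sharing a dict that stands in for storage; none of these names come from Composer.

```python
class ReportWriter:
    """Hypothetical callback that writes a final report when it is closed."""

    def __init__(self, storage):
        self.storage = storage

    def close(self, state, logger):
        self.storage["report.txt"] = "final metrics"

    def post_close(self):
        pass


class Backup:
    """Hypothetical callback that backs up whatever other callbacks wrote."""

    def __init__(self, storage):
        self.storage = storage
        self.backed_up = {}

    def close(self, state, logger):
        pass

    def post_close(self):
        # Every close() has already run, so the report exists by now.
        self.backed_up = dict(self.storage)


def shutdown(callbacks, state=None, logger=None):
    # Hypothetical order: close() on every callback, then post_close() on every callback.
    for cb in callbacks:
        cb.close(state, logger)
    for cb in callbacks:
        cb.post_close()


storage = {}
backup = Backup(storage)
# Backup is listed first, yet it still sees the report written by ReportWriter.close().
shutdown([backup, ReportWriter(storage)])
print(backup.backed_up)  # {'report.txt': 'final metrics'}
```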

run_event(event, state, logger)

This method is called by the engine on each event.

Parameters
  • event (Event) – The event.

  • state (State) – The state.

  • logger (Logger) – The logger.