composer.core.state#

The state of the trainer.

Classes

State

The state of the trainer.

class composer.core.state.State(model, max_duration, train_dataloader, evaluators=[], grad_accum=1, precision=Precision.FP32, precision_context=<function _default_precision_factory.<locals>.null>, optimizers=None, scaler=None, algorithms=(), callbacks=(), steps_per_epoch=None)[source]#

Bases: composer.core.serializable.Serializable

The state of the trainer.

Contains variables that the trainer tracks throughout the training loop. Note that the entire state is serialized when the trainer is checkpointed so that it can be used to restore the trainer and continue training from a checkpoint. Algorithms are able to modify this object in-place.

Note

To support multi-GPU training, State.model may be wrapped in DistributedDataParallel, and the dataloaders may be wrapped in a device-specific dataloader that handles moving tensors to device.

Note

Schedulers are wrapped in ComposableScheduler, which handles stepping either stepwise or epochwise, and also properly sets up learning rate warmups.

Parameters
  • model (types.Model) – The model, typically a subclass of ComposerModel.

  • grad_accum (int) – The number of gradient accumulation steps to use. The size of each microbatch is train_batch_size / num_gpus / grad_accum.

  • train_dataloader (DataLoader, types.DataSpec, or dict) – The types.DataLoader, types.DataSpec, or dict of types.DataSpec kwargs to use for training.

  • evaluators (types.Evaluators) – The evaluators, each pairing an evaluation dataset with the specific metrics to compute on it.

  • max_duration (str or Time) – The maximum duration to train for.

  • precision (str | Precision) – The numerical precision to use for training. Should be one of [fp32, amp].

  • precision_context ((Precision) -> ContextManager) – Function that produces a context manager to mandate the given precision.

  • optimizers (types.Optimizers, optional) – The optimizers being used to train the model. Multiple optimizers are not currently supported.

  • schedulers (types.Schedulers, optional) – The learning rate schedulers, typically wrapped in ComposableScheduler.

  • scaler (GradScaler, optional) – The gradient scaler in use for mixed precision training.

  • algorithms (Sequence[Algorithm]) – The algorithms used for training.

  • callbacks (Sequence[Callback]) – The callbacks used for training.

  • profiler (Optional[Profiler]) – The Composer profiler.
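To illustrate the grad_accum relationship stated above, the per-device microbatch size can be computed as follows. This is a plain-Python sketch; the function name is illustrative and not part of the composer API.

```python
def microbatch_size(train_batch_size: int, num_gpus: int, grad_accum: int) -> int:
    """Size of each microbatch: train_batch_size / num_gpus / grad_accum.

    Illustrative sketch only. Assumes the global batch size divides evenly
    across devices and gradient accumulation steps.
    """
    per_device = train_batch_size // num_gpus
    return per_device // grad_accum

# e.g. a global batch of 2048 samples on 8 GPUs with grad_accum=4
print(microbatch_size(2048, 8, 4))  # 64
```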

batch#

The batch. This will be the entire batch during Event.AFTER_DATALOADER, or a microbatch between Event.BATCH_START and Event.BATCH_END.

Type

types.Batch

batch_num_samples#

The number of samples in the batch.

Type

int

batch_num_tokens#

The number of tokens in the batch.

Type

int

loss#

The most recently computed loss.

Type

types.Tensors

outputs#

The most recently computed output from the model's forward pass.

Type

types.Tensors

timer#

The timer that tracks training loop progress.

Type

types.Timer

serialized_attributes#

The list of attributes which will be serialized in a checkpoint.

Type

List[str]

property batch_dict#

The current batch, represented as a BatchDict.

Raises

TypeError – If the current batch is not a BatchDict.

Type

BatchDict

property batch_idx#

The index of the batch in the current epoch.

Type

int

property batch_pair#

The current batch, represented as a BatchPair.

Raises

TypeError – If the current batch is not a BatchPair.

Type

BatchPair
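The batch_dict and batch_pair properties above both narrow the current batch to a specific shape and raise TypeError otherwise. A minimal sketch of that pattern, assuming a BatchDict is a mapping and a BatchPair is an (inputs, targets) two-tuple; this is illustrative only, not the composer implementation:

```python
from typing import Any, Dict, Tuple

def as_batch_dict(batch: Any) -> Dict[str, Any]:
    # Mirrors the batch_dict property: require a mapping, else raise TypeError.
    if not isinstance(batch, dict):
        raise TypeError(f"batch has type {type(batch).__name__}, not BatchDict")
    return batch

def as_batch_pair(batch: Any) -> Tuple[Any, Any]:
    # Mirrors the batch_pair property: require an (inputs, targets) pair.
    if not (isinstance(batch, (tuple, list)) and len(batch) == 2):
        raise TypeError(f"batch has type {type(batch).__name__}, not BatchPair")
    inputs, targets = batch
    return inputs, targets
```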

property epoch#

The index of the current epoch.

get_elapsed_duration()[source]#

Get the elapsed training duration.

Returns

Time – The elapsed duration, in TimeUnit.DURATION.
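TimeUnit.DURATION expresses progress as a fraction of max_duration. A sketch of that computation, assuming both quantities are measured in batches; the real implementation divides the timer's progress by max_duration, and these names are illustrative:

```python
def elapsed_duration(current_batch: int, max_duration_batches: int) -> float:
    """Fraction of training complete, in [0, 1] -- analogous to a Time value
    in TimeUnit.DURATION. Illustrative sketch, not the composer API."""
    return current_batch / max_duration_batches

print(elapsed_duration(250, 1000))  # 0.25
```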

load_model_state(state_dict, strict)[source]#

Loads the model's state from a state_dict.

Parameters
  • state_dict (types.StateDict) – object returned from a call to state_dict().

  • strict (bool) – whether the keys in the state_dict should perfectly match the keys in the model.
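The strict flag follows the usual PyTorch convention: with strict=True, the checkpoint's parameter keys must exactly match the model's. A minimal sketch of that check over key sets (illustrative only; torch.nn.Module.load_state_dict performs the real work):

```python
def check_strict_keys(model_keys, state_dict_keys, strict: bool):
    """Sketch of strict loading: report keys missing from the checkpoint and
    keys the model does not expect; fail on any mismatch when strict=True."""
    missing = set(model_keys) - set(state_dict_keys)
    unexpected = set(state_dict_keys) - set(model_keys)
    if strict and (missing or unexpected):
        raise RuntimeError(
            f"missing keys: {sorted(missing)}, unexpected keys: {sorted(unexpected)}"
        )
    return missing, unexpected
```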

load_state_dict(state, strict=False)[source]#

Loads the state.

Parameters
  • state (types.StateDict) – object returned from a call to state_dict().

  • strict (bool, optional) – whether the keys in the state should perfectly match the keys in the model when loading the model's state. Default: False.

property max_epochs#

The maximum number of epochs to train for.

property precision#

The numerical precision to use for training.

Should be one of [fp32, amp].
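The precision_context parameter documented above is a factory from a precision value to a context manager. A sketch of such a factory, assuming fp32 needs no special handling (a no-op context); in real use amp would map to something like torch.cuda.amp.autocast(), which is omitted here so the sketch stays self-contained:

```python
import contextlib

def precision_context_factory(precision: str):
    """Sketch of a (Precision) -> ContextManager factory. Illustrative only;
    the default composer factory is _default_precision_factory."""
    if precision == "fp32":
        # Full precision: nothing to change, so a no-op context suffices.
        return contextlib.nullcontext()
    raise NotImplementedError(f"no context for precision {precision!r} in this sketch")

with precision_context_factory("fp32"):
    loss = 3.14  # the forward pass would run here under the chosen precision
```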

state_dict()[source]#

Returns the state as a dict.
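Together with the serialized_attributes attribute above, state_dict() captures only the listed attributes for checkpointing. A toy sketch of that pattern (illustrative; not the composer implementation, which also recurses into serializable members):

```python
class ToyState:
    """Sketch of the serialized_attributes / state_dict pattern: only the
    attributes listed in serialized_attributes end up in a checkpoint."""
    serialized_attributes = ["epoch", "step", "loss"]

    def __init__(self):
        self.epoch = 2
        self.step = 1200
        self.loss = 0.37
        self.scratch = "not checkpointed"  # absent from serialized_attributes

    def state_dict(self):
        return {name: getattr(self, name) for name in self.serialized_attributes}

    def load_state_dict(self, state):
        for name, value in state.items():
            setattr(self, name, value)

print(ToyState().state_dict())  # {'epoch': 2, 'step': 1200, 'loss': 0.37}
```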

property step#

The index of the current step/batch (measured globally).

property steps_per_epoch#

The maximum number of steps (batches) per epoch.

Type

int
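The counters above are related: step counts batches globally, while epoch and batch_idx count within the training run and within the current epoch. A plausible arithmetic relationship, assuming a constant steps_per_epoch (sketch only; the exact bookkeeping lives in the Timer):

```python
def epoch_and_batch_idx(step: int, steps_per_epoch: int):
    """Derive (epoch, batch_idx) from the global step, assuming every epoch
    has exactly steps_per_epoch batches. Illustrative, not the composer API."""
    return step // steps_per_epoch, step % steps_per_epoch

print(epoch_and_batch_idx(2503, 1000))  # (2, 503)
```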