composer.core.state#
The state of the trainer.
Classes
State – The state of the trainer.
- class composer.core.state.State(model, max_duration, train_dataloader, evaluators=[], grad_accum=1, precision=Precision.FP32, precision_context=<function _default_precision_factory.<locals>.null>, optimizers=None, scaler=None, algorithms=(), callbacks=(), steps_per_epoch=None)[source]#
Bases: composer.core.serializable.Serializable
The state of the trainer.
Contains variables that the trainer tracks throughout the training loop. Note that the entire state is serialized when the trainer is checkpointed so that it can be used to restore the trainer and continue training from a checkpoint. Algorithms are able to modify this object in-place.
Note
To support multi-GPU training, State.model may be wrapped in DistributedDataParallel, and the dataloaders may be wrapped in a device-specific dataloader that handles moving tensors to device.
Note
Schedulers are wrapped in ComposableScheduler, which handles stepping either stepwise or epochwise, and also properly sets up learning rate warmups.
- Parameters
model (types.Model, often ComposerModel) – The model, typically a subclass of ComposerModel.
grad_accum (int) – The number of gradient accumulation steps to use. The size of each microbatch is train_batch_size / num_gpus / grad_accum; for example, a train batch size of 2048 on 8 GPUs with grad_accum=2 yields microbatches of 128 samples.
train_dataloader (DataLoader, types.DataSpec, or dict) – The types.DataLoader, types.DataSpec, or dict of types.DataSpec kwargs to use for training.
evaluators (Evaluators) – The types.Evaluators, which contain the evaluation datasets used for evaluation with specific metrics.
max_duration (str or Time) – The maximum duration to train for.
precision (str or Precision) – The numerical precision to use for training. Should be one of [fp32, amp].
precision_context ((precision: Precision) -> ContextManager) – Function to produce a context manager that mandates the given precision.
optimizers (types.Optimizers, optional) – The optimizers being used to train the model. Multiple optimizers are not currently supported.
schedulers (types.Schedulers, optional) – The learning rate schedulers, typically wrapped in ComposableScheduler.
scaler (GradScaler, optional) – The gradient scaler in use for mixed precision training.
algorithms (Sequence[Algorithm]) – The algorithms used for training.
callbacks (Sequence[Callback]) – The callbacks used for training.
profiler (Optional[Profiler]) – The Composer profiler.
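Because algorithms modify this object in-place during training, a common pattern is to read and overwrite State attributes from inside an Algorithm. The sketch below is hypothetical: the ScaleLoss algorithm and its scaling factor are purely illustrative, and it assumes Composer's standard Algorithm interface with match(event, state) and apply(event, state, logger).

```python
from composer.core import Algorithm, Event, State


class ScaleLoss(Algorithm):
    """Hypothetical algorithm that rescales the loss stored on the State."""

    def __init__(self, scale: float = 0.5):
        self.scale = scale

    def match(self, event: Event, state: State) -> bool:
        # Run right after the loss for the current microbatch has been computed.
        return event == Event.AFTER_LOSS

    def apply(self, event: Event, state: State, logger) -> None:
        # state.loss holds the most recently computed loss (types.Tensors);
        # algorithms are allowed to modify State attributes in-place.
        # Illustrative only: assumes the loss is a single tensor.
        state.loss = state.loss * self.scale
```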
- batch#
The batch. This will be the entire batch during the Event.AFTER_DATALOADER, or a microbatch between Event.BATCH_START and Event.BATCH_END.
- Type
types.Batch
- loss#
The most recently computed loss.
- Type
types.Tensors
- outputs#
The most recently computed output from the model's forward pass.
- Type
types.Tensors
- timer#
The timer that tracks training loop progress.
- Type
types.Timer
- serialized_attributes#
The list of attributes which will be serialized in a checkpoint.
- Type
List[str]
- property batch_dict#
The current batch, represented as a BatchDict.
- Raises
TypeError – If the current batch is not a BatchDict.
- Type
BatchDict
- property batch_pair#
The current batch, represented as a BatchPair.
- Raises
TypeError – If the current batch is not a BatchPair.
- Type
BatchPair
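A hedged sketch of how these two properties can be used together; the describe_batch helper is hypothetical and only illustrates the TypeError behavior documented above.

```python
from composer.core import State


def describe_batch(state: State) -> str:
    """Hypothetical helper that reports how the current batch is structured."""
    try:
        inputs, targets = state.batch_pair  # raises TypeError if the batch is not a BatchPair
        return f"pair batch: {type(inputs).__name__} inputs, {type(targets).__name__} targets"
    except TypeError:
        pass
    try:
        keys = sorted(state.batch_dict)  # raises TypeError if the batch is not a BatchDict
        return f"dict batch with keys {keys}"
    except TypeError:
        return "batch is neither a BatchPair nor a BatchDict"
```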
- property epoch#
The index of the current epoch.
- get_elapsed_duration()[source]#
Get the elapsed training duration.
- Returns
Time – The elapsed duration, in TimeUnit.DURATION.
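Because the result is expressed in TimeUnit.DURATION, it can be read as the fraction of max_duration already completed. A minimal, hypothetical logging helper (the .value access is assumed to expose the underlying number on the returned Time):

```python
from composer.core import State


def log_progress(state: State) -> None:
    """Hypothetical helper that prints how far training has progressed."""
    elapsed = state.get_elapsed_duration()  # Time in TimeUnit.DURATION, assumed to be 0.0-1.0
    print(f"training is {elapsed.value * 100:.1f}% complete "
          f"(epoch {state.epoch}, step {state.step})")
```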
- load_model_state(state_dict, strict)[source]#
Loads the model's state from a state_dict.
- Parameters
state_dict (types.StateDict) – object returned from call to state_dict().
strict (bool) – whether the keys in the state_dict should perfectly match the keys in the model.
- load_state_dict(state, strict=False)[source]#
Loads the state.
- Parameters
state (types.StateDict) – object returned from call to state_dict().
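Together with state_dict(), these methods support a simple checkpoint round trip. A hedged sketch: the helper names and file path are illustrative, and torch.save/torch.load are used only as an example serialization mechanism.

```python
import torch

from composer.core import State


def save_checkpoint(state: State, path: str = "checkpoint.pt") -> None:
    """Hypothetical helper: serialize the attributes listed in state.serialized_attributes."""
    torch.save(state.state_dict(), path)


def restore_checkpoint(state: State, path: str = "checkpoint.pt") -> None:
    """Hypothetical helper: restore a previously saved State."""
    saved = torch.load(path)
    state.load_state_dict(saved, strict=False)    # restore the full trainer state
    # state.load_model_state(saved, strict=True)  # or restore only the model weights
```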
- property max_epochs#
The maximum number of epochs to train for.
- property precision#
The numerical precision to use for training.
Should be one of [fp32, amp].
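This property is typically consulted together with the precision_context factory passed to the constructor. A hypothetical sketch, assuming that factory is stored on the State as precision_context:

```python
from composer.core import State


def forward_with_precision(state: State, batch):
    """Hypothetical helper that runs a forward pass under the configured precision."""
    with state.precision_context(state.precision):  # context manager mandating the precision
        return state.model(batch)
```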
- property step#
The index of the current step/batch (measured globally).