composer.core.state#

The state of the trainer.

class composer.core.state.State(model, max_duration, rank_zero_seed, train_dataloader, evaluators=[], grad_accum=1, precision=Precision.FP32, precision_context=<function _default_precision_factory.<locals>.null>, optimizers=None, scaler=None, algorithms=(), callbacks=(), steps_per_epoch=None)[source]#

Bases: composer.core.serializable.Serializable

The state of the trainer.

Contains variables that the trainer tracks throughout the training loop. All the necessary parts of state (i.e., the serialized_attributes) are serialized when the trainer is checkpointed, so that the checkpoint can be used to restore the trainer and continue training. algorithms are able to modify an instance of this class in place.

Note

An instance of this class is automatically constructed by the Trainer constructor. A user need not instantiate this class.

Parameters
  • model (Model) – The model, typically as a subclass of ComposerModel.

  • rank_zero_seed (int) – The seed used on the rank zero process. It is assumed that each rank's seed is rank_zero_seed + dist.get_global_rank() (see the sketch after this parameter list).

  • grad_accum (int) – The number of gradient accumulation steps to use. With this argument, the micro batch size for each device becomes microbatch_size = train_batch_size / (num_devices * grad_accum), as illustrated in the sketch below.

  • train_dataloader (DataLoader, DataSpec, or dict) – The DataLoader, DataSpec, or dict of DataSpec kwargs to be used for training.

  • evaluators (Evaluators) – The Evaluators, which pair the evaluation datasets with the specific metrics used for evaluation.

  • max_duration (str or Time) – The maximum duration to train for.

  • precision (str | Precision) – The numerical precision to use for training. See Precision for the supported precisions.

  • precision_context (Callable[[Precision], ContextManager]) – Function to produce a context manager that mandates the given precision.

  • optimizers (Optimizers, optional) – The optimizers being used to train the model. Multiple optimizers are not currently supported.

  • schedulers (PyTorchScheduler | List[PyTorchScheduler] | Tuple[PyTorchScheduler, …], optional) – The learning rate scheduler (can also be a list or tuple of schedulers).

  • scaler (GradScaler, optional) – The gradient scaler in use for mixed precision training.

  • algorithms (Sequence[Algorithm]) – The algorithms used for training.

  • callbacks (Sequence[Callback]) – The callbacks used for training.

  • profiler (Optional[Profiler]) – The Composer profiler.
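The per-rank seed and per-device micro batch size described above follow directly from the constructor arguments. A minimal sketch of the arithmetic (plain Python, not Composer API; all concrete numbers are hypothetical):

    # Per-rank seed: the rank zero seed plus the rank's global index.
    rank_zero_seed = 42
    global_rank = 3  # what dist.get_global_rank() would return on this rank
    per_rank_seed = rank_zero_seed + global_rank  # -> 45

    # Micro batch size per device under gradient accumulation.
    train_batch_size = 2048  # total optimization batch size across all devices
    num_devices = 8
    grad_accum = 4
    microbatch_size = train_batch_size // (num_devices * grad_accum)  # -> 64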

batch#

The batch. This will be the entire batch during the Event.AFTER_DATALOADER event, or a microbatch between Event.BATCH_START and Event.BATCH_END.

Type

Batch
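For example, an algorithm can inspect state.batch when it runs. A minimal sketch (assuming the standard Algorithm match/apply interface; exact import paths can vary between Composer versions):

    from composer.core import Algorithm, Event

    class InspectBatch(Algorithm):
        """Toy algorithm, for illustration only."""

        def match(self, event, state):
            # The full batch is available at AFTER_DATALOADER; between
            # BATCH_START and BATCH_END, state.batch is a microbatch.
            return event == Event.AFTER_DATALOADER

        def apply(self, event, state, logger):
            print(type(state.batch))  # the entire batch at this event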

batch_num_samples#

The number of samples in the batch.

Type

int

batch_num_tokens#

The number of tokens in the batch.

Type

int

loss#

The most recently computed loss.

Type

Tensors

outputs#

The most recently computed output from the model's forward pass.

Type

Tensors

timer#

The timer that tracks training loop progress.

Type

Timer
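A sketch of reading progress from the timer (this assumes Timer exposes epoch and batch counters as Time values, each with a .value attribute):

    def log_progress(state):
        # Each counter is a Time object; .value holds the integer count.
        print(f"epoch={state.timer.epoch.value} batch={state.timer.batch.value}")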

serialized_attributes#

The names of the attributes that are serialized in a checkpoint.

By default, the following attributes are serialized:

Attribute        Description
model            The model under training.
optimizers       The optimizers being used to train the model.
schedulers       The learning rate schedulers.
algorithms       The algorithms used for training.
callbacks        The callbacks used for training.
scaler           The gradient scaler in use for mixed precision training.
timer            The timer that tracks training loop progress.
is_model_ddp     Whether the model is an instance of DistributedDataParallel.
rank_zero_seed   The seed of the rank zero process.

Type

List[str]

property batch_dict#

The current batch, represented as a BatchDict.

Raises

TypeError – If the current batch is not a BatchDict.

Type

BatchDict

property batch_pair#

The current batch, represented as a BatchPair.

Raises

TypeError – If the current batch is not a BatchPair.

Type

BatchPair
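A hedged sketch of handling both batch formats (the "inputs" key is hypothetical and depends on how your dataloader structures batches):

    def get_inputs(state):
        """Return the model inputs for either batch format (illustration only)."""
        try:
            inputs, _targets = state.batch_pair  # (inputs, targets)-style batch
            return inputs
        except TypeError:
            return state.batch_dict["inputs"]  # hypothetical dict key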

property deepspeed_model#

The model cast to a DeepSpeedEngine.

get_elapsed_duration()[source]#

Get the elapsed training duration.

Returns

Time – The elapsed duration, in TimeUnit.DURATION. Time(0.0, TimeUnit.DURATION) represents the beginning of training and Time(1.0, TimeUnit.DURATION) represents a completed training process.
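Because the returned value is a fraction of total training, it is convenient for schedules defined relative to training length. A sketch (assuming Time.value holds the numeric fraction):

    def warmup_scale(state, warmup_frac=0.1):
        # 0.0 at the start of training, 1.0 when training completes.
        elapsed = float(state.get_elapsed_duration().value)
        return min(1.0, elapsed / warmup_frac)  # hypothetical linear warmup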

property is_model_ddp#

Whether the model is an instance of DistributedDataParallel.

property is_model_deepspeed#

Whether the model is an instance of DeepSpeedEngine.
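A typical guard before using DeepSpeed-specific functionality (sketch):

    if state.is_model_deepspeed:
        engine = state.deepspeed_model  # the model cast to a DeepSpeedEngine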

load_model_state(state_dict, strict)[source]#

Loads the model's state from a state_dict.

Parameters
  • state_dict (types.StateDict) – The state dict, generated from a previous call to state_dict().

  • strict (bool) – Whether the keys (i.e., model parameter names) in the model state dict should perfectly match the keys in the model instance.

load_state_dict(state, strict=False)[source]#

Loads the state.

Parameters
  • state (types.StateDict) – The object returned from a call to state_dict().

  • strict (bool) – Whether the keys in state["model"] should perfectly match the keys in self.model. Defaults to False.
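A hedged sketch of round-tripping the state through a file (the Trainer's checkpointing normally handles this; the path is hypothetical):

    import torch

    def save_and_restore(state, path="state_checkpoint.pt"):
        torch.save(state.state_dict(), path)  # serializes serialized_attributes
        loaded = torch.load(path)
        state.load_state_dict(loaded, strict=False)  # lenient key matching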

property max_duration#

The maximum training duration.

property precision#

The numerical precision to use for training.

See Precision for the supported precisions.

property seed#

The seed for the current rank.

state_dict()[source]#

Returns the state as a dict.

property steps_per_epoch#

The maximum number of steps (batches) per epoch.

Type

int