composer.core.state#
The state of the trainer.
Classes

State – The state of the trainer.
- class composer.core.state.State(model, max_duration, rank_zero_seed, train_dataloader, evaluators=[], grad_accum=1, precision=Precision.FP32, precision_context=<function _default_precision_factory.<locals>.null>, optimizers=None, scaler=None, algorithms=(), callbacks=(), steps_per_epoch=None)[source]#
Bases:
composer.core.serializable.Serializable
The state of the trainer.
Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts of state (i.e., serialized_attributes) are serialized when the trainer is checkpointed, so that they can be used to restore the trainer and continue training from a checkpoint. Algorithms are able to modify an instance of this class in-place.

Note

An instance of this class is automatically constructed by the Trainer constructor. A user need not instantiate this class directly.

- Parameters
model (Model) – The model, typically a subclass of ComposerModel.
rank_zero_seed (int) – The seed used on the rank zero process. Each rank's seed is assumed to be rank_zero_seed + dist.get_global_rank().
grad_accum (int) – The number of gradient accumulation steps to use. With this argument, the microbatch size for each device becomes microbatch_size = train_batch_size / (num_devices * grad_accum).
train_dataloader (DataLoader, DataSpec, or dict) – The DataLoader, DataSpec, or dict of DataSpec kwargs to use for training.
evaluators (Evaluators) – The Evaluators, which contain the evaluation datasets used for evaluation with specific metrics.
max_duration (str or Time) – The maximum duration to train for.
precision (str | Precision) – The numerical precision to use for training. See Precision for the supported precisions.
precision_context (Callable[[Precision], ContextManager]) – Function that produces a context manager to enforce the precision.
optimizers (Optimizers, optional) – The optimizers being used to train the model. Multiple optimizers are not currently supported.
schedulers (PyTorchScheduler | List[PyTorchScheduler] | Tuple[PyTorchScheduler, ...], optional) – The learning rate scheduler (can also be a list or tuple of schedulers).
scaler (GradScaler, optional) – The gradient scaler in use for mixed precision training.
algorithms (Sequence[Algorithm]) – The algorithms used for training.
callbacks (Sequence[Callback]) – The callbacks used for training.
profiler (Optional[Profiler]) – The Composer profiler.
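The grad_accum formula above can be sketched in plain Python. This is an illustrative helper, not part of the Composer API, and it assumes the global batch size divides evenly across devices and accumulation steps:

```python
def microbatch_size(train_batch_size: int, num_devices: int, grad_accum: int) -> int:
    """Per-device microbatch size implied by grad_accum.

    Each device processes grad_accum microbatches and accumulates their
    gradients before a single optimizer step.
    """
    assert train_batch_size % (num_devices * grad_accum) == 0, "batch size must divide evenly"
    return train_batch_size // (num_devices * grad_accum)

# e.g., a global batch of 2048 on 8 devices with grad_accum=2
print(microbatch_size(2048, 8, 2))  # 128
```

Doubling grad_accum halves the per-device microbatch size while keeping the effective batch size (and hence the optimization trajectory) the same, which is why it is the usual lever for fitting a large batch into limited device memory.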
- batch#
The batch. This will be the entire batch during Event.AFTER_DATALOADER, or a microbatch between Event.BATCH_START and Event.BATCH_END.
- Type
- serialized_attributes#
The names of the attributes that are serialized in a checkpoint.
By default, the following attributes are serialized:
model – The model under training.
optimizers – The optimizers being used to train the model.
schedulers – The learning rate schedulers.
algorithms – The algorithms used for training.
callbacks – The callbacks used for training.
scaler – The gradient scaler in use for mixed precision training.
timer – The timer that tracks training loop progress.
is_model_ddp – Whether the model is an instance of DistributedDataParallel.
rank_zero_seed – The seed of the rank zero process.
- Type
List[str]
- property deepspeed_model#
Cast model to DeepSpeedEngine.
- get_elapsed_duration()[source]#
Get the elapsed training duration.
- Returns
Time – The elapsed duration, in TimeUnit.DURATION. Time(0.0, TimeUnit.DURATION) represents the beginning of training, and Time(1.0, TimeUnit.DURATION) represents a completed training process.
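The DURATION semantics reduce to a fraction of training completed. The following is an illustrative stand-in for the Time arithmetic, assuming max_duration is expressed in batches:

```python
def elapsed_duration(current_batch: int, max_duration_batches: int) -> float:
    """Fraction of training completed, in [0.0, 1.0]."""
    return current_batch / max_duration_batches

# 0.0 at the start of training, 1.0 when training has completed.
print(elapsed_duration(0, 1000))    # 0.0
print(elapsed_duration(500, 1000))  # 0.5
```

A unitless fraction like this is convenient for schedules and algorithms that should behave the same whether max_duration was specified in epochs, batches, or samples.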
- property is_model_ddp#
Whether model is an instance of DistributedDataParallel.
- property is_model_deepspeed#
Whether model is an instance of DeepSpeedEngine.
- load_model_state(state_dict, strict)[source]#
Loads the model's state from a state dict.
- Parameters
state_dict (types.StateDict) – The state dict, generated from a previous call to state_dict().
strict (bool) – Whether the keys (i.e., model parameter names) in the model state dict should perfectly match the keys in the model instance.
- load_state_dict(state, strict=False)[source]#
Loads the state.
- Parameters
state (types.StateDict) – The object returned from a call to state_dict().
strict (bool) – Whether the keys in state["model"] should perfectly match the keys in self.model. Defaults to False.
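What strict key matching means can be sketched as follows. This is a hypothetical helper mirroring the semantics of torch.nn.Module.load_state_dict, not Composer's implementation:

```python
def check_state_dict_keys(model_keys, loaded_keys, strict=False):
    """Compare parameter names; raise under strict mode on any mismatch."""
    missing = sorted(set(model_keys) - set(loaded_keys))      # in model, not in checkpoint
    unexpected = sorted(set(loaded_keys) - set(model_keys))   # in checkpoint, not in model
    if strict and (missing or unexpected):
        raise KeyError(f"missing={missing}, unexpected={unexpected}")
    return missing, unexpected

# Non-strict loading just reports mismatches instead of raising:
print(check_state_dict_keys({"w", "b"}, {"w"}))  # (['b'], [])
```

strict=False is a sensible default for resuming training across minor model changes, whereas strict=True catches accidental architecture/checkpoint mismatches early.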
- property max_duration#
The maximum training duration.
- property precision#
The numerical precision to use for training.
See Precision for the supported precisions.
- property seed#
The seed for the current rank.
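The per-rank seed derivation described under rank_zero_seed reduces to a one-liner. This sketch replaces dist.get_global_rank() with a plain argument for illustration:

```python
def seed_for_rank(rank_zero_seed: int, global_rank: int) -> int:
    """Seed for a given rank: rank_zero_seed + dist.get_global_rank()."""
    return rank_zero_seed + global_rank

# Every rank gets a distinct, reproducible seed derived from one base value:
print([seed_for_rank(17, r) for r in range(4)])  # [17, 18, 19, 20]
```

Deriving all seeds from a single rank-zero value means only one integer needs to be checkpointed to make the whole distributed run reproducible.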