composer.State
The `State` object is available for algorithms to modify during `Algorithm.apply()`, and captures the state of the trainer.
A summary of available attributes and properties is given below:
| Attribute | Type | Description |
|---|---|---|
| **Training arguments** | | |
| `model` | `Model` | Model to train, typically a subclass of `torch.nn.Module` |
| `train_batch_size` | `int` | Global batch size for training |
| `eval_batch_size` | `int` | Batch size for evaluation |
| `grad_accum` | `int` | Gradient accumulation steps. The size of each microbatch would be `train_batch_size / num_gpus / grad_accum` |
| `max_epochs` | `int` | Maximum number of epochs |
| `precision` | `Union[str, Precision]` | Precision, one of `[fp32, amp]` |
| `precision_context` | `Callable[[Union[str, Precision]], ContextManager]` | Called with the precision to return a contextmanager |
| **Timing Information** | | |
| `epoch` | `int` | The current epoch |
| `step` | `int` | The current step (in terms of optimization steps) |
| `batch_idx` | `int` | Index of the batch in the current epoch. Not mutable. |
| `steps_per_epoch` | `int` | Number of optimization steps per epoch. Not mutable. |
| **Training Loop Tensors** | | |
| `batch` | `Batch` | Batch returned by the dataloader |
| `batch_pair` | `BatchPair` | Helper to access the batch as a `BatchPair` |
| `batch_dict` | `BatchDict` | Helper to access the batch as a `BatchDict` |
| `loss` | `Tensors` | Last computed loss |
| `last_batch_size` | `int` | Batch size returned from the dataloader. This can be different from the current size of `batch` if algorithms have modified it |
| `outputs` | `Tensors` | Output of the model's forward pass |
| **Optimizers** | | |
| `optimizers` | `Optimizers` | Optimizers. Multiple optimizers are not currently supported |
| `schedulers` | `Schedulers` | LR schedulers, wrapped in `ComposableScheduler` |
| `scaler` | `Scaler` | Gradient scaler for mixed precision |
| **Dataloaders** | | |
| `train_dataloader` | `DataLoader` | Dataloader for training |
| `eval_dataloader` | `DataLoader` | Dataloader for evaluation |
| **Algorithms** | | |
| `algorithms` | `Sequence[Algorithm]` | List of algorithms |
| `callbacks` | `Sequence[Callback]` | List of callbacks, including loggers |
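To make the mutation model concrete, here is a minimal sketch of how an algorithm might modify state in-place, using a simplified stand-in dataclass rather than the real `composer.State` (the `MockState` class and `apply_scale_batch` function are illustrations invented for this example, not Composer APIs). It also shows the microbatch arithmetic from the `grad_accum` row above:

```python
from dataclasses import dataclass, field
from typing import List

# Simplified stand-in for composer.State -- illustration only, not the real class.
@dataclass
class MockState:
    train_batch_size: int
    grad_accum: int
    num_gpus: int = 1
    batch: List[float] = field(default_factory=list)

    @property
    def microbatch_size(self) -> int:
        # Size of each microbatch: train_batch_size / num_gpus / grad_accum
        return self.train_batch_size // self.num_gpus // self.grad_accum

# An algorithm's apply() mutates the state in place:
def apply_scale_batch(state: MockState, factor: float) -> None:
    state.batch = [x * factor for x in state.batch]

state = MockState(train_batch_size=512, grad_accum=4, num_gpus=8, batch=[1.0, 2.0])
apply_scale_batch(state, 2.0)
print(state.microbatch_size)  # 512 / 8 / 4 = 16
print(state.batch)            # [2.0, 4.0]
```

Because algorithms mutate the single shared `State` object rather than returning copies, later algorithms in the pass see the modifications made by earlier ones.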
Note

To support multi-GPU training, `state.model` may be wrapped in `DistributedDataParallel`, and the dataloaders may be wrapped in a device-specific dataloader that handles moving tensors to the device.
Note

Schedulers are wrapped in `ComposableScheduler`, which handles stepping either stepwise or epochwise, and also properly sets up learning rate warmups.
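The stepwise-vs-epochwise dispatch can be sketched as follows. This is a hypothetical illustration of the dispatch idea only (the class names and methods here are invented; the real `ComposableScheduler` has a different interface):

```python
# Hypothetical sketch: step an underlying scheduler either per batch ("step")
# or per epoch ("epoch"). A counting stand-in replaces a real torch scheduler.
class IntervalScheduler:
    def __init__(self, scheduler, interval: str = "epoch"):
        assert interval in ("epoch", "step")
        self.scheduler = scheduler
        self.interval = interval

    def on_batch_end(self) -> None:
        if self.interval == "step":
            self.scheduler.step()

    def on_epoch_end(self) -> None:
        if self.interval == "epoch":
            self.scheduler.step()

class CountingScheduler:
    """Stand-in for a torch LR scheduler; just counts step() calls."""
    def __init__(self):
        self.steps = 0
    def step(self) -> None:
        self.steps += 1

sched = CountingScheduler()
wrapper = IntervalScheduler(sched, interval="step")
for _ in range(3):          # three batches
    wrapper.on_batch_end()
wrapper.on_epoch_end()      # no-op for a stepwise scheduler
print(sched.steps)          # 3
```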
- class composer.State(model: types.Model, train_batch_size: int, eval_batch_size: int, grad_accum: int, max_epochs: int, precision: Union[str, types.Precision] = <property object>, precision_context: Callable[[Union[str, Precision]], ContextManager] = <factory>, epoch: int = 0, step: int = 0, loss: types.Tensors = <factory>, last_batch_size: int = 0, batch: types.Batch = <factory>, outputs: types.Tensors = <factory>, optimizers: Optional[types.Optimizers] = None, schedulers: Optional[types.Schedulers] = None, scaler: Optional[types.Scaler] = None, train_dataloader: Optional[types.DataLoader] = None, eval_dataloader: Optional[types.DataLoader] = None, algorithms: Sequence[Algorithm] = (), callbacks: Sequence[Callback] = ())[source]
The class used to store the state of the trainer.
Contains variables that the trainer tracks throughout the training loop. Note that the entire state is serialized when the trainer is checkpointed so that it can be used to restore the trainer and continue training from a checkpoint. Algorithms are able to modify this object in-place.
- grad_accum
The number of gradient accumulation steps to use. The size of each microbatch is `train_batch_size / num_gpus / grad_accum`.
- Type
int
- precision
The numerical precision to use for training. Should be one of `[fp32, amp]`.
- precision_context ((precision: Precision) -> ContextManager): Function to produce a context manager to mandate precision.
- last_batch_size
The size of the batch last returned from the dataloader. This can be different from the current size of `batch` if algorithms have modified the `batch`.
- Type
int
- optimizers
The optimizers being used to train the model. Multiple optimizers are not currently supported.
- Type
Optimizers, optional
- schedulers
The learning rate schedulers, typically wrapped in `ComposableScheduler`.
- Type
Schedulers, optional
- scaler
The gradient scaler in use for mixed precision training.
- Type
GradScaler, optional
- train_dataloader
The dataloader used for training.
- Type
DataLoader, optional
- eval_dataloader
The dataloader used for evaluation.
- Type
DataLoader, optional
- property batch_pair: Union[Tuple[Union[Tensor, Tuple[Tensor, ...], List[Tensor]], Union[Tensor, Tuple[Tensor, ...], List[Tensor]]], List[Tensor]]
The current batch, represented as a `BatchPair`.
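Per the annotation above, a `BatchPair` is a two-element (inputs, targets) structure. The helper below is a hypothetical sketch of that shape check, not Composer code, with plain lists standing in for tensors:

```python
# Sketch: validate that a batch has BatchPair shape, i.e. exactly two
# elements (inputs, targets). Lists stand in for torch Tensors here.
def as_batch_pair(batch):
    """Return the batch as an (inputs, targets) pair, or raise TypeError."""
    if isinstance(batch, (tuple, list)) and len(batch) == 2:
        inputs, targets = batch
        return inputs, targets
    raise TypeError("batch is not a BatchPair")

inputs, targets = as_batch_pair(([1, 2, 3], [0, 1, 1]))
print(inputs)   # [1, 2, 3]
print(targets)  # [0, 1, 1]
```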
- load_state_dict(state: Dict[str, Any])[source]
Loads the state.
- Parameters
state (StateDict) – the object returned from a call to `state_dict()`.
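The save/restore round-trip follows the usual `state_dict()` / `load_state_dict()` pattern. A minimal sketch of that pattern (the `TinyState` class is invented for illustration; the real `composer.State` serializes far more, including the model, optimizers, and timing information):

```python
from typing import Any, Dict

# Minimal sketch of the state_dict()/load_state_dict() round-trip pattern.
class TinyState:
    def __init__(self, epoch: int = 0, step: int = 0):
        self.epoch = epoch
        self.step = step

    def state_dict(self) -> Dict[str, Any]:
        # Serialize everything needed to resume training.
        return {"epoch": self.epoch, "step": self.step}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Restore from an object previously returned by state_dict().
        self.epoch = state["epoch"]
        self.step = state["step"]

saved = TinyState(epoch=3, step=1200).state_dict()
restored = TinyState()
restored.load_state_dict(saved)
print(restored.epoch, restored.step)  # 3 1200
```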