Stochastic Weight Averaging
(Image above from an extensive PyTorch blog post about SWA: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/)
Tags: All, Increased Accuracy, Method
TL;DR
Stochastic Weight Averaging (SWA) maintains a running average of the weights towards the end of training. This leads to better generalization than conventional training.
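At its core, the running average is just the recurrence w̄ ← (w̄ · n + w) / (n + 1) applied each time a new set of weights is sampled. A minimal sketch in plain Python (weights shown as floats for illustration; in practice they are full model tensors):

```python
# Minimal sketch of the running weight average SWA maintains.
def swa_update(avg_weights, new_weights, n_averaged):
    """Fold one more sampled weight vector into the running mean."""
    return [(a * n_averaged + w) / (n_averaged + 1)
            for a, w in zip(avg_weights, new_weights)]

avg = [0.0, 0.0]
# Hypothetical weights sampled near the end of training:
samples = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
for n, w in enumerate(samples):
    avg = swa_update(avg, w, n)
print(avg)  # the mean of the samples: [3.0, 4.0]
```

The averaged weights are what is used at inference time; the raw (most recent) weights are still needed during training, which is why SWA stores both copies.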
Attribution
“Averaging Weights Leads to Wider Optima and Better Generalization” by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Presented at the 2018 Conference on Uncertainty in Artificial Intelligence.
Applicable Settings
Stochastic Weight Averaging is generally applicable across model architectures, tasks, and domains. It has been shown to improve performance in both vision tasks (e.g. ImageNet) as well as NLP tasks.
Hyperparameters
swa_start
: fraction of training completed before stochastic weight averaging is applied. The default value is 0.8.
swa_lr
: the final learning rate to anneal towards before averaging begins.
Example Effects
SWA leads to accuracy improvements of about 1-1.5% on ResNet models trained on ImageNet. From the original paper and subsequent work by the authors (see their repository):
| Model | Baseline SGD | SWA 5 epochs | SWA 10 epochs |
|---|---|---|---|
| ResNet-50 | 76.15 | 76.83±0.01 | 76.97±0.05 |
| ResNet-152 | 78.31 | 78.82±0.01 | 78.94±0.07 |
Note that the implementation in the original paper differs slightly from the current PyTorch implementation.
Implementation Details
The implementation is based on the PyTorch implementation, which treats SWA as an optimizer; SWALR is imported from torch.optim.swa_utils. The current implementation first applies a cosine decay to the learning rate until it reaches the fixed value swa_lr, and then begins maintaining a running average of the weights.
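The schedule described above can be sketched with PyTorch's own torch.optim.swa_utils; the model, optimizer, and epoch counts below are illustrative placeholders, and the exact Composer integration may differ:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR

model = nn.Linear(4, 2)  # stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

max_epochs = 100
swa_start = 80  # 0.8 * max_epochs, matching the default swa_start

# Cosine decay for the first ~80% of training...
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=swa_start)
# ...then anneal to a fixed swa_lr and begin averaging.
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.01)

for epoch in range(max_epochs):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # fold weights into running average
        swa_scheduler.step()
    else:
        scheduler.step()
```

After training, torch.optim.swa_utils.update_bn can be used to recompute BatchNorm statistics for the averaged model before evaluation.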
Considerations
As per the paper, the majority of training (e.g. 75%-80%) should be completed before the final SWA learning rate is applied.
Composability
Stochastic Weight Averaging composes well with other methods such as MixUp and Label Smoothing on ImageNet.
There are studies in progress to see the effect of SWA on other image tasks and NLP.
Code
- class composer.algorithms.swa.SWA(swa_start: float = 0.8, anneal_epochs: int = 10, swa_lr: Optional[float] = None)[source]
Apply Stochastic Weight Averaging (Izmailov et al.)
Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than just using the final trained weights.
Because this algorithm needs to maintain both the current value of the weights and the average of all of the sampled weights, it doubles the model’s memory consumption. Note that this does not mean that the total memory required doubles, however, since stored activations and the optimizer state are not doubled.
- Parameters
swa_start – fraction of training completed before stochastic weight averaging is applied
swa_lr – the final learning rate used for weight averaging
Note that ‘anneal_epochs’ is not used in the current implementation
- apply(event: composer.core.event.Event, state: composer.core.state.State, logger: composer.core.logging.logger.Logger) None [source]
Applies SWA to the model weights towards the end of training.
- match(event: composer.core.event.Event, state: composer.core.state.State) bool [source]
Runs on Event.TRAINING_START and Event.TRAINING_END, and on Event.EPOCH_END once the number of completed epochs is greater than or equal to swa_start * max_epochs.
- Parameters
event (Event) – The current event.
state (State) – The current state.
- Returns
bool – True if this algorithm should run now.
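The match() condition above can be sketched as a small predicate; the string event names and the should_run helper are hypothetical stand-ins for Composer's Event and State objects:

```python
# Sketch of the match() logic: when should the SWA algorithm run?
def should_run(event, epoch, max_epochs, swa_start=0.8):
    """True on training start/end, or at epoch end once past swa_start."""
    if event in ("TRAINING_START", "TRAINING_END"):
        return True
    return event == "EPOCH_END" and epoch >= swa_start * max_epochs

print(should_run("TRAINING_START", 0, 100))   # True
print(should_run("EPOCH_END", 50, 100))       # False: before 0.8 * 100
print(should_run("EPOCH_END", 80, 100))       # True: averaging has begun
```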