Stochastic Weight Averaging
The header image is from an extensive PyTorch blog post about SWA: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
Tags: All, Increased Accuracy, Method
TL;DR
Stochastic Weight Averaging (SWA) maintains a running average of the weights towards the end of training. This leads to better generalization than conventional training.
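In the SWA algorithm of the original paper, this running average is an equal-weight mean of the weight snapshots collected so far. If n snapshots have already been averaged, each new snapshot w is folded into the SWA weights as:

```latex
w_{\mathrm{SWA}} \leftarrow \frac{n \, w_{\mathrm{SWA}} + w}{n + 1}
```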
Attribution
“Averaging Weights Leads to Wider Optima and Better Generalization” by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Presented at the 2018 Conference on Uncertainty in Artificial Intelligence.
Applicable Settings
Stochastic Weight Averaging is generally applicable across model architectures, tasks, and domains. It has been shown to improve performance on both vision tasks (e.g., ImageNet) and NLP tasks.
Hyperparameters
swa_start: fraction of training (between 0 and 1) completed before stochastic weight averaging is applied. The default value is 0.8.
swa_lr: the final learning rate to anneal towards.
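To make the role of swa_start concrete, the sketch below converts it into the epoch at which SWA would begin. It assumes swa_start is a fraction of total training (as its default of 0.8 suggests) and that the start point is rounded down to a whole epoch; the helper swa_start_epoch is hypothetical and not part of the library.

```python
import math


def swa_start_epoch(swa_start: float, max_epochs: int) -> int:
    """Hypothetical helper: first epoch at which SWA is active.

    Assumes swa_start is a fraction of total training in [0, 1] and that
    the start point is rounded down to a whole epoch.
    """
    if not 0.0 <= swa_start <= 1.0:
        raise ValueError("swa_start must be a fraction between 0 and 1")
    return math.floor(swa_start * max_epochs)


# With the default swa_start=0.8, a 90-epoch run would start SWA at epoch 72.
print(swa_start_epoch(0.8, 90))  # 72
```

The actual implementation may compute the start point differently (for example, in batches rather than epochs).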
Example Effects
SWA leads to accuracy improvements of about 1-1.5% for ResNet-50 on ImageNet. The results below are from the original paper and subsequent work by the authors (see their repository):
ImageNet top-1 accuracy (%):
Model | Baseline SGD | SWA 5 epochs | SWA 10 epochs
---|---|---|---
ResNet-50 | 76.15 | 76.83±0.01 | 76.97±0.05
ResNet-152 | 78.31 | 78.82±0.01 | 78.94±0.07
Note that the implementation in the original paper differs slightly from the current PyTorch implementation.
Implementation Details
The implementation is based on the PyTorch implementation, which treats SWA as an optimizer; SWALR is imported from torch.optim.swa_utils. The current implementation first applies a cosine decay that anneals the learning rate to a fixed value, swa_lr, and then begins maintaining a running average of the weights.
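The two phases described above can be sketched in plain Python. This is a simplified illustration, not the library code: annealed_lr uses the cosine interpolation factor (1 - cos(pi * t)) / 2, which is what SWALR's "cos" strategy uses, but it interpolates from a fixed base_lr rather than reproducing SWALR's exact bookkeeping. Weights are modeled as flat lists of floats, and both function names are hypothetical. In real PyTorch code, torch.optim.swa_utils.AveragedModel maintains the average and SWALR drives the schedule.

```python
import math
from typing import List


def annealed_lr(step: int, base_lr: float, swa_lr: float, anneal_epochs: int) -> float:
    """Cosine-interpolate from base_lr to swa_lr over anneal_epochs, then hold swa_lr."""
    # Progress through the annealing phase, clipped to [0, 1].
    t = min(1.0, max(0.0, step / max(1, anneal_epochs)))
    # alpha goes from 0 (start of annealing) to 1 (annealing finished).
    alpha = (1.0 - math.cos(math.pi * t)) / 2.0
    return swa_lr * alpha + base_lr * (1.0 - alpha)


def update_running_average(avg: List[float], weights: List[float], n_averaged: int) -> List[float]:
    """Fold one more weight snapshot into an equal-weight running mean."""
    if n_averaged == 0:
        return list(weights)
    return [(a * n_averaged + w) / (n_averaged + 1) for a, w in zip(avg, weights)]


# Learning rate over a short anneal: 0.1 at step 0, 0.01 from step 4 onwards.
for step in range(6):
    print(step, round(annealed_lr(step, base_lr=0.1, swa_lr=0.01, anneal_epochs=4), 4))

# Averaging three snapshots of a two-parameter "model".
avg: List[float] = []
for n, snapshot in enumerate(([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])):
    avg = update_running_average(avg, snapshot, n)
print(avg)  # [3.0, 4.0]
```

Keeping the average as a separate copy of the weights means the model being trained is never overwritten; the averaged weights are only swapped in at the end.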
Considerations
As per the paper, the majority of training should be completed (e.g. 75%-80%) before the final SWA learning rate is applied.
Composability
Stochastic Weight Averaging composes well with other methods, such as MixUp and Label Smoothing, on ImageNet.
Studies are in progress to evaluate the effect of SWA on other image tasks and on NLP.
Code
- class composer.algorithms.swa.SWA(swa_start=0.8, anneal_epochs=10, swa_lr=None)
Apply Stochastic Weight Averaging.
Stochastic Weight Averaging (SWA) averages model weights sampled towards the end of training. This leads to better generalization than conventional training.
See “Averaging Weights Leads to Wider Optima and Better Generalization” (https://arxiv.org/abs/1803.05407).
- Parameters
- apply(event, state, logger)
Applies the algorithm to make an in-place change to the State. Can optionally return an exit code to be stored in a Trace.
- Parameters
event (Event) – The current event.
state (State) – The current state.
logger (Logger) – A logger to use for logging algorithm-specific metrics.
- Returns
int or None – exit code that is stored in Trace and made accessible for debugging.
- Return type
int or None
- match(event, state)
Determines whether this algorithm should run, given the current Event and State.
Examples:
To only run on a specific event:
>>> return event == Event.BEFORE_LOSS
Switching based on state attributes:
>>> return state.epoch > 30 and state.world_size == 1
See State for accessible attributes.
- Parameters
event (Event) – The current event.
state (State) – The current state.
- Returns
bool – True if this algorithm should run now.
- Return type
bool
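The match/apply contract documented above can be illustrated with a self-contained toy. Event, State, and ToySWA below are hypothetical stand-ins written for this sketch, not Composer's classes, and the driver loop is an assumption about how a training engine might call match() before apply() at each event.

```python
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional


class Event(Enum):
    """Hypothetical stand-in for Composer's Event; only two members are modeled."""
    BEFORE_LOSS = "before_loss"
    EPOCH_END = "epoch_end"


@dataclass
class State:
    """Hypothetical stand-in for Composer's State with only the fields used here."""
    epoch: int
    max_epochs: int
    swa_active: bool = False


class ToySWA:
    """Illustrates the match/apply contract; not Composer's SWA implementation."""

    def __init__(self, swa_start: float = 0.8):
        self.swa_start = swa_start

    def match(self, event: Event, state: State) -> bool:
        # Run only at the end of an epoch, once the swa_start fraction has elapsed.
        return event == Event.EPOCH_END and state.epoch >= self.swa_start * state.max_epochs

    def apply(self, event: Event, state: State, logger: Optional[object] = None) -> Optional[int]:
        # Make an in-place change to the state; returning None means "no exit code".
        state.swa_active = True
        return None


algo = ToySWA(swa_start=0.8)
state = State(epoch=0, max_epochs=10)
applied_at: List[int] = []
for epoch in range(state.max_epochs):
    state.epoch = epoch
    for event in (Event.BEFORE_LOSS, Event.EPOCH_END):
        # The (assumed) engine asks match() first and calls apply() only if it returns True.
        if algo.match(event, state):
            algo.apply(event, state)
            applied_at.append(epoch)
print(applied_at)  # [8, 9]
```

Separating the cheap match() check from the state-mutating apply() lets an engine poll many algorithms at every event without each one re-implementing its own scheduling logic.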