Stochastic Weight Averaging

[Figure: illustration of SWA from an extensive PyTorch blog post about SWA: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/]

Tags: All, Increased Accuracy, Method

TL;DR

Stochastic Weight Averaging (SWA) maintains a running average of the weights towards the end of training. This leads to better generalization than conventional training.
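
Concretely, each time the weights are sampled near the end of training, they are folded into a running average. A minimal sketch of that update rule (illustrative only, not the library implementation; plain Python lists stand in for model parameters):

def update_running_average(w_swa, w, n_averaged):
    # w_swa: current averaged weights; w: newly sampled weights;
    # n_averaged: number of weight snapshots already averaged.
    new_avg = [(n_averaged * ws + wi) / (n_averaged + 1) for ws, wi in zip(w_swa, w)]
    return new_avg, n_averaged + 1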

Attribution

“Averaging Weights Leads to Wider Optima and Better Generalization” by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Presented at the 2018 Conference on Uncertainty in Artificial Intelligence.

Applicable Settings

Stochastic Weight Averaging is generally applicable across model architectures, tasks, and domains. It has been shown to improve performance in both vision tasks (e.g. ImageNet) and NLP tasks.

Hyperparameters

  • swa_start: the fraction of training completed before stochastic weight averaging is applied. The default value is 0.8.

  • swa_lr: the final learning rate to anneal towards (see the sketch below)
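
A sketch of constructing the Composer algorithm with these hyperparameters (the swa_lr value below is illustrative, not a recommended default):

from composer.algorithms.swa import SWA

swa = SWA(
    swa_start=0.8,     # begin averaging after 80% of training
    anneal_epochs=10,  # epochs over which the learning rate is annealed to swa_lr
    swa_lr=0.05,       # illustrative value; the default is None
)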

Example Effects

SWA leads to accuracy improvements of about 1-1.5% for ResNet-50 on ImageNet. The top-1 accuracies (%) below are from the original paper and subsequent work by the authors (see their repository):

Model      | Baseline SGD | SWA (5 epochs) | SWA (10 epochs)
ResNet-50  | 76.15        | 76.83±0.01     | 76.97±0.05
ResNet-152 | 78.31        | 78.82±0.01     | 78.94±0.07

Note that the implementation in the original paper differs slightly from the current PyTorch implementation.

Implementation Details

The implementation is based on the PyTorch implementation, which treats SWA as an optimizer. SWALR is imported from torch.optim.swa_utils. The current implementation first anneals the learning rate with a cosine schedule until it reaches the fixed value swa_lr, and then begins maintaining a running average of the weights.
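
A sketch of that pattern using the PyTorch utilities (model, optimizer, scheduler, train_loader, train_one_epoch, and the epoch counts are placeholders; note that swa_start here is an epoch index rather than a fraction):

from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

swa_model = AveragedModel(model)               # holds the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05,  # anneals the learning rate to the fixed value swa_lr
                      anneal_epochs=10, anneal_strategy='cos')
swa_start = 72                                 # e.g. 80% of a 90-epoch run

for epoch in range(90):
    train_one_epoch(model, train_loader, optimizer)  # placeholder for the usual training loop
    if epoch >= swa_start:
        swa_model.update_parameters(model)     # fold the current weights into the average
        swa_scheduler.step()
    else:
        scheduler.step()                       # the regular learning rate schedule

update_bn(train_loader, swa_model)             # recompute batch norm statistics for the averaged weights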

Considerations

As per the paper, the majority of training (e.g. 75%-80%) should be completed before the final SWA learning rate is applied; for example, with swa_start=0.8 on a 90-epoch run, averaging begins around epoch 72.

Composability

Stochastic Weight Averaging composes well with other methods such as MixUp and Label Smoothing on ImageNet.
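
A sketch of composing these methods in Composer (assuming MixUp and LabelSmoothing are exposed under composer.algorithms; their constructor arguments and the Trainer arguments below are assumptions, so consult the respective docs for exact signatures):

from composer.algorithms import SWA, MixUp, LabelSmoothing
from composer.trainer import Trainer

trainer = Trainer(
    model=model,                        # assumed: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # assumed: an existing training dataloader
    max_duration='90ep',                # assumed argument name
    algorithms=[
        SWA(swa_start=0.8),
        MixUp(alpha=0.2),               # assumed argument name/value
        LabelSmoothing(alpha=0.1),      # assumed argument name/value
    ],
)
trainer.fit()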

Studies are in progress to evaluate the effect of SWA on other vision tasks and on NLP.


Code

class composer.algorithms.swa.SWA(swa_start=0.8, anneal_epochs=10, swa_lr=None)[source]

Apply Stochastic Weight Averaging

Stochastic Weight Averaging (SWA) averages model weights sampled towards the end of training. This leads to better generalization than conventional training.

See "Averaging Weights Leads to Wider Optima and Better Generalization" (https://arxiv.org/abs/1803.05407).

Parameters
  • swa_start (float) – fraction of training completed before stochastic weight averaging is applied

  • anneal_epochs (int) – number of epochs over which to anneal the learning rate to the final value swa_lr

  • swa_lr (float) – the final learning rate to anneal towards

apply(event, state, logger)[source]

Applies the algorithm to make an in-place change to the State

Can optionally return an exit code to be stored in a Trace.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

  • logger (Logger) – A logger to use for logging algorithm-specific metrics.

Returns

int or None – exit code that is stored in Trace and made accessible for debugging.

Return type

Optional[int]

match(event, state)[source]

Determines whether this algorithm should run, given the current Event and State.

Examples:

To only run on a specific event:

>>> return event == Event.BEFORE_LOSS

Switching based on state attributes:

>>> return state.epoch > 30 and state.world_size == 1

See State for accessible attributes.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.

Return type

bool