Stochastic Weight Averaging

[Image: illustration of SWA, from the extensive PyTorch blog post about SWA: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/]

Tags: All, Increased Accuracy, Method

TL;DR

Stochastic Weight Averaging (SWA) maintains a running average of the weights towards the end of training. This leads to better generalization than conventional training.
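Concretely, each time a new set of weights w is sampled near the end of training, the running average is updated with the standard rule used in the paper and in PyTorch's averaging utilities (n_models is the number of weight snapshots averaged so far):

```latex
w_{\mathrm{SWA}} \leftarrow \frac{w_{\mathrm{SWA}} \cdot n_{\mathrm{models}} + w}{n_{\mathrm{models}} + 1}
```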

Attribution

“Averaging Weights Leads to Wider Optima and Better Generalization” by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Presented at the 2018 Conference on Uncertainty in Artificial Intelligence.

Applicable Settings

Stochastic Weight Averaging is generally applicable across model architectures, tasks, and domains. It has been shown to improve performance on both vision tasks (e.g., ImageNet classification) and NLP tasks.

Hyperparameters

  • swa_start: the fraction of training completed before stochastic weight averaging is applied. The default value is 0.8.

  • swa_lr: the final learning rate to anneal towards (see the construction sketch after this list).
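As a minimal sketch, these hyperparameters map onto the SWA constructor documented in the Code section below. The import path is assumed from the documented class path, and the swa_lr value is purely illustrative:

```python
# Minimal sketch: constructing the SWA algorithm with the hyperparameters above.
# Import path assumed from the documented class path; the swa_lr value is illustrative.
from composer.algorithms.swa import SWA

swa = SWA(
    swa_start=0.8,  # begin averaging after 80% of training (the documented default)
    swa_lr=0.05,    # final learning rate to anneal towards (illustrative value)
)
```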

Example Effects

SWA yields top-1 accuracy improvements of roughly 0.5-1% for ResNet models on ImageNet. The numbers below are from the original paper and subsequent work by the authors (see their repository):

| Model | Baseline SGD | SWA (5 epochs) | SWA (10 epochs) |
| --- | --- | --- | --- |
| ResNet-50 | 76.15 | 76.83±0.01 | 76.97±0.05 |
| ResNet-152 | 78.31 | 78.82±0.01 | 78.94±0.07 |

Note that the implementation in the original paper differs slightly from the current PyTorch implementation.

Implementation Details

The implementation is based on the PyTorch implementation, which treats SWA as an optimizer; SWALR is imported from torch.optim.swa_utils. The current implementation first applies a cosine decay that anneals the learning rate to the fixed value swa_lr, and then begins maintaining a running average of the weights.
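For intuition, here is a rough, self-contained sketch of the underlying torch.optim.swa_utils workflow (anneal the learning rate down to swa_lr, then maintain a running weight average). The toy model, synthetic data, and hyperparameter values are stand-ins; this is not the Composer integration itself:

```python
# Sketch of the PyTorch SWA utilities that this implementation builds on.
# The toy model, synthetic data, and hyperparameters are stand-ins for a real setup.
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))), batch_size=32
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

max_epochs = 50
swa_start_epoch = int(0.8 * max_epochs)  # corresponds to swa_start=0.8

swa_model = AveragedModel(model)  # maintains the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_strategy="cos")  # anneal down to swa_lr

for epoch in range(max_epochs):
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    if epoch >= swa_start_epoch:
        swa_model.update_parameters(model)  # fold current weights into the running average
        swa_scheduler.step()

update_bn(train_loader, swa_model)  # recompute batch-norm statistics for the averaged weights
```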

Considerations

As per the paper, the majority of training (e.g., 75%-80%) should be completed before the final SWA learning rate is applied.

Composability

Stochastic Weight Averaging composes well with other methods such as MixUp and label smoothing on ImageNet.

Studies are in progress to measure the effect of SWA on other vision tasks and on NLP.
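As a hedged sketch of what composing these methods could look like: the MixUp and LabelSmoothing names, their constructor arguments, and the Trainer keyword arguments are assumptions about the Composer API rather than verified usage, and model/train_dataloader are placeholders defined elsewhere.

```python
# Hedged sketch: composing SWA with other algorithms via the trainer's algorithms list.
# MixUp/LabelSmoothing names, their arguments, and the Trainer kwargs are assumptions;
# model and train_dataloader are placeholders defined elsewhere.
from composer import Trainer  # assumed import path
from composer.algorithms import LabelSmoothing, MixUp  # assumed import path
from composer.algorithms.swa import SWA

trainer = Trainer(
    model=model,                        # a Composer-compatible model (placeholder)
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration="90ep",
    algorithms=[
        SWA(swa_start=0.8, swa_lr=0.05),  # illustrative swa_lr
        MixUp(alpha=0.2),                 # assumed constructor argument
        LabelSmoothing(alpha=0.1),        # assumed constructor argument
    ],
)
trainer.fit()
```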


Code

class composer.algorithms.swa.SWA(swa_start: float = 0.8, anneal_epochs: int = 10, swa_lr: Optional[float] = None)[source]

Apply Stochastic Weight Averaging (Izmailov et al.)

Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than just using the final trained weights.

Because this algorithm needs to maintain both the current value of the weights and the average of all of the sampled weights, it doubles the model’s memory consumption. Note that this does not mean that the total memory required doubles, however, since stored activations and the optimizer state are not doubled.

Parameters
  • swa_start – fraction of training completed before stochastic weight averaging is applied (default: 0.8)

  • swa_lr – the final learning rate used for weight averaging

Note that anneal_epochs is not used in the current implementation.

apply(event: composer.core.event.Event, state: composer.core.state.State, logger: composer.core.logging.logger.Logger) -> None

Apply SWA to weights towards the end of training

Parameters
  • event (Event) – the current event

  • state (State) – the current trainer state

  • logger (Logger) – the training logger

match(event: composer.core.event.Event, state: composer.core.state.State) -> bool

Runs on Event.TRAINING_START and Event.TRAINING_END, and on Event.EPOCH_END when the current epoch is greater than or equal to swa_start * max_epochs.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.
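For clarity, the gating condition above can be restated as a small sketch. The Event names and module path come from the signatures documented here; the state.epoch and state.max_epochs attribute names are assumptions used only for illustration.

```python
# Illustrative restatement of the match() condition described above (not the library source).
# The state.epoch and state.max_epochs attribute names are assumptions for illustration.
from composer.core.event import Event


def swa_should_run(event, state, swa_start: float = 0.8) -> bool:
    # Always run at the very start and very end of training.
    if event in (Event.TRAINING_START, Event.TRAINING_END):
        return True
    # Otherwise, run at epoch boundaries once swa_start of training is complete.
    return event == Event.EPOCH_END and state.epoch >= swa_start * state.max_epochs
```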