🧩 Stochastic Weight Averaging

[Figure: illustration of SWA from an extensive PyTorch blog post about the method: https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/]

Tags: All, Increased Accuracy, Method

TL;DR

Stochastic Weight Averaging (SWA) maintains a running average of the model weights over the final portion of training. This leads to better generalization than conventional training.
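
Concretely, if \(w_1, \dots, w_n\) are the weights sampled during the averaging period, SWA uses their equal-weight mean, which can be maintained as a running average:

\(\bar{w}_n = \frac{1}{n}\sum_{i=1}^{n} w_i, \qquad \bar{w}_{n+1} = \frac{n\,\bar{w}_n + w_{n+1}}{n+1}\)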

Attribution

"Averaging Weights Leads to Wider Optima and Better Generalization" by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, and Andrew Gordon Wilson. Presented at the 2018 Conference on Uncertainty in Artificial Intelligence.

Applicable Settings

Stochastic Weight Averaging is generally applicable across model architectures, tasks, and domains. It has been shown to improve performance on both vision tasks (e.g., ImageNet) and NLP tasks.

Hyperparameters

  • swa_start: the fraction of training completed before stochastic weight averaging is applied. The default value is 0.8.

  • swa_lr: the final learning rate to anneal towards; set to None for no annealing. A short configuration sketch follows this list.
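
For example, a minimal sketch of configuring these hyperparameters via the SWA class documented below (the values, and the use of the '0.8dur' duration string for swa_start, are illustrative):

from composer.algorithms import SWA

# Illustrative configuration: begin averaging after 80% of training and
# anneal the learning rate to 1e-4 (annealing requires schedule_swa_lr=True).
swa = SWA(swa_start="0.8dur", schedule_swa_lr=True, swa_lr=1e-4)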

Example Effects

SWA leads to accuracy improvements of about 1-1.5% for ResNet-50 on ImageNet. The top-1 accuracy results below are from the original paper and subsequent work by the authors (see their repository):

Model         Baseline SGD    SWA 5 epochs    SWA 10 epochs
ResNet-50     76.15           76.83 ± 0.01    76.97 ± 0.05
ResNet-152    78.31           78.82 ± 0.01    78.94 ± 0.07

Note that the implementation in the original paper differs slightly from the current PyTorch implementation.

Implementation Details

Our implementation is based on the PyTorch implementation, which treats SWA as an optimizer; SWALR is imported from torch.optim.swa_utils. The current implementation first applies a cosine decay that anneals the learning rate to a fixed value, swa_lr, and then begins maintaining a running average of the weights.
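
For reference, here is a minimal sketch of the underlying torch.optim.swa_utils workflow in plain PyTorch. This is not Composer's exact implementation, and the model, dataloader, epoch counts, and learning rates are illustrative placeholders:

import torch
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

model = torch.nn.Linear(10, 2)                            # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

swa_model = AveragedModel(model)                          # holds the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05)             # anneals the LR toward a fixed SWA LR
swa_start = 80                                            # start averaging after 80% of the 100 epochs

for epoch in range(100):
    for inputs, targets in train_loader:                  # train_loader: an assumed DataLoader
        loss = F.cross_entropy(model(inputs), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)                # fold the current weights into the average
        swa_scheduler.step()
    else:
        scheduler.step()

update_bn(train_loader, swa_model)                        # recompute batch norm statistics for the averaged model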

Considerations

Per the paper, the majority of training (e.g., 75%-80%) should be completed before the final SWA learning rate is applied.

Composability

Stochastic Weight Averaging composes well with other methods, such as MixUp and Label Smoothing, on ImageNet; a sketch of composing these algorithms appears below.

There are studies in progress to measure the effect of SWA on other vision tasks and on NLP.
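
As a rough illustration, composing methods amounts to passing several algorithms to the Composer Trainer. The MixUp and LabelSmoothing constructor arguments below are assumptions for this sketch, not values taken from this card:

from composer.algorithms import SWA, MixUp, LabelSmoothing
from composer.trainer import Trainer

trainer = Trainer(
    model=model,                        # model, dataloaders, and optimizer as in the example below
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="90ep",
    algorithms=[
        SWA(swa_start="0.8dur"),        # begin weight averaging after 80% of training
        MixUp(alpha=0.2),               # assumed MixUp interpolation parameter
        LabelSmoothing(smoothing=0.1),  # assumed smoothing factor
    ],
    optimizers=[optimizer],
)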


Code

class composer.algorithms.swa.SWA(swa_start='0.7dur', swa_end='0.97dur', update_interval='1ep', schedule_swa_lr=False, anneal_strategy='linear', anneal_steps=10, swa_lr=None)

Applies Stochastic Weight Averaging (Izmailov et al., 2018).

Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than just using the final trained weights.

Because this algorithm needs to maintain both the current value of the weights and the average of all of the sampled weights, it doubles the model's memory consumption. This does not mean that the total memory required doubles, however, since stored activations and the optimizer state are not doubled.

Note

The AveragedModel is currently stored on the CPU device, which may cause slow training if the model weights are large.

Uses PyTorch's torch.optim.swa_utils under the hood.

See the Method Card for more details.

Example

from composer.algorithms import SWA
from composer.trainer import Trainer

swa_algorithm = SWA(
    swa_start="6ep",
    swa_end="8ep"
)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer]
)
Parameters
  • swa_start (str, optional) – The time string denoting the amount of training completed before stochastic weight averaging begins. Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.7dur'.

  • swa_end (str, optional) – The time string denoting the amount of training completed before the baseline (non-averaged) model is replaced with the stochastic weight averaged model. It's important to have at least one epoch of training after the baseline model is replaced by the SWA model so that the SWA model can have its buffers (most importantly its batch norm statistics) updated. If swa_end occurs during the final epoch of training (e.g. swa_end = 0.9dur and max_duration = "5ep", or swa_end = 1.0dur), the SWA model will not have its buffers updated, which can negatively impact accuracy, so ensure swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\). Currently only units of duration ('dur') and epoch ('ep') are supported. Default: '0.97dur'.

  • update_interval (str, optional) – Time string denoting how often the averaged model is updated. For example, '1ep' means the averaged model will be updated once per epoch and '5ba' means the averaged model will be updated every 5 batches. Note that for single-epoch training runs (e.g. many NLP training runs), update_interval must be specified in units of 'ba', otherwise SWA won't happen. Also note that very small update intervals (e.g. "1ba") can substantially slow down training. Default: '1ep'.

  • schedule_swa_lr (bool, optional) – Flag to determine whether to apply an SWA-specific LR schedule during the period in which SWA is active. Default: False.

  • anneal_strategy (str, optional) – SWA learning rate annealing schedule strategy. "linear" for linear annealing, "cos" for cosine annealing. Default: "linear".

  • anneal_steps (int, optional) – Number of SWA model updates over which to anneal SWA learning rate. Note that updates are determined by the update_interval argument. For example, if anneal_steps = 10 and update_interval = '1ep', then the SWA LR will be annealed once per epoch for 10 epochs; if anneal_steps = 20 and update_interval = '8ba', then the SWA LR will be annealed once every 8 batches over the course of 160 batches (20 steps * 8 batches/step). Default: 10.

  • swa_lr (float, optional) – The final learning rate to anneal towards with the SWA LR scheduler. Set to None for no annealing. Default: None.
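
For instance, a configuration sketch that enables the SWA-specific LR schedule (all values are illustrative, and the parameter names are those documented above):

from composer.algorithms import SWA

swa = SWA(
    swa_start="0.7dur",     # begin averaging after 70% of training
    swa_end="0.97dur",      # replace the baseline model with the averaged model at 97% of training
    update_interval="5ba",  # update the averaged model every 5 batches
    schedule_swa_lr=True,   # apply the SWA-specific LR schedule while SWA is active
    anneal_strategy="cos",  # cosine annealing of the learning rate
    anneal_steps=10,        # 10 SWA updates, i.e. 50 batches at this update_interval
    swa_lr=0.05,            # final learning rate to anneal towards (illustrative)
)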