composer.algorithms.swa.swa

Core code for Stochastic Weight Averaging.

Classes

SWA

Apply Stochastic Weight Averaging (Izmailov et al., 2018)

class composer.algorithms.swa.swa.SWA(swa_start='0.7dur', swa_end='0.97dur', schedule_swa_lr=False, anneal_strategy='linear', anneal_epochs=10, swa_lr=None)

Bases: composer.core.algorithm.Algorithm

Apply Stochastic Weight Averaging (Izmailov et al., 2018)

Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than using the final trained weights alone.
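The averaging itself can be kept as a running, equal-weight mean that is updated each time a new weight snapshot is taken, so earlier snapshots never need to be stored. The sketch below shows that update on plain Python lists; the helper name and the snapshot values are hypothetical, and to our understanding this is the same update rule PyTorch's AveragedModel applies by default.

```python
from typing import List


def update_running_average(
    avg: List[float], weights: List[float], n_averaged: int
) -> List[float]:
    """Fold one more weight snapshot into an equal-weight running average.

    `avg` holds the mean of the first `n_averaged` snapshots; the return
    value is the mean of `n_averaged + 1` snapshots.
    """
    if n_averaged == 0:
        # The first snapshot simply becomes the average.
        return list(weights)
    return [a + (w - a) / (n_averaged + 1) for a, w in zip(avg, weights)]


# Three hypothetical snapshots of a two-parameter model taken late in training.
snapshots = [[1.0, 4.0], [2.0, 5.0], [3.0, 9.0]]

avg: List[float] = []
for n, snapshot in enumerate(snapshots):
    avg = update_running_average(avg, snapshot, n)

print(avg)  # [2.0, 6.0]
```

The result equals the element-wise mean of all three snapshots, but only one extra copy of the weights is ever held in memory.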

Because this algorithm maintains both the current weights and the average of all sampled weights, it doubles the memory needed to store the model's weights. This does not mean the total memory required doubles, however, since stored activations and the optimizer state are not duplicated.
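A rough back-of-the-envelope estimate makes the distinction concrete. The helper below is hypothetical; it assumes fp32 values, an optimizer with two state tensors per parameter (as Adam has), and a made-up activation footprint, and it ignores buffers and framework overhead. Only the weight term is added a second time when SWA is enabled.

```python
def training_memory_bytes(
    num_params: int,
    activation_bytes: int,
    bytes_per_value: int = 4,             # fp32
    optimizer_states_per_param: int = 2,  # e.g. Adam's two moment estimates
    use_swa: bool = False,
) -> int:
    """Rough memory estimate for one training process (hypothetical helper)."""
    weights = num_params * bytes_per_value
    grads = num_params * bytes_per_value
    optimizer_state = num_params * bytes_per_value * optimizer_states_per_param
    # SWA keeps one extra copy of the weights; nothing else is duplicated.
    swa_copy = weights if use_swa else 0
    return weights + grads + optimizer_state + activation_bytes + swa_copy


num_params = 10_000_000         # hypothetical 10M-parameter model
activation_bytes = 200_000_000  # hypothetical 200 MB of stored activations

baseline = training_memory_bytes(num_params, activation_bytes)
with_swa = training_memory_bytes(num_params, activation_bytes, use_swa=True)
print(baseline, with_swa)  # 360000000 400000000
```

Under these assumptions the weights double (40 MB to 80 MB) while the total grows by roughly 11%.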

Uses PyTorch's torch.optim.swa_utils under the hood.
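For readers unfamiliar with those utilities, the sketch below uses them directly in a plain PyTorch loop: AveragedModel holds the averaged copy, SWALR anneals the learning rate toward a target, and update_bn recomputes batch norm statistics for the averaged weights. It is an illustration of the underlying primitives on a toy model and random data, not the Composer algorithm's actual implementation; the epoch counts and learning rates are arbitrary choices.

```python
import torch
from torch import nn
from torch.optim.swa_utils import SWALR, AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression data and a model with a BatchNorm layer, whose running
# statistics are the buffers that must be recomputed for the averaged model.
X = torch.randn(64, 4)
y = X.sum(dim=1, keepdim=True)
loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

num_epochs = 10
swa_start_epoch = 7  # roughly "0.7dur" for a 10-epoch run

swa_model = AveragedModel(model)  # the second copy of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=2, anneal_strategy="linear")

for epoch in range(num_epochs):
    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
    if epoch >= swa_start_epoch:
        swa_model.update_parameters(model)  # fold current weights into the average
        swa_scheduler.step()                # anneal the LR toward swa_lr

# Recompute BatchNorm running statistics for the averaged weights.
update_bn(loader, swa_model)

print(int(swa_model.n_averaged))  # 3 (epochs 7, 8 and 9)
```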

See the Method Card for more details.

Example

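from composer.algorithms import SWA
from composer.trainer import Trainer

# Assumes model, train_dataloader, eval_dataloader, and optimizer are already defined.
# With max_duration="10ep", swa_end="8ep" leaves two epochs of training after the
# baseline model is replaced, satisfying the requirement described under swa_end.
swa_algorithm = SWA(
    swa_start="6ep",
    swa_end="8ep",
)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer],
)
trainer.fit()

Parameters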
  • swa_start (str, optional) – The time string denoting the amount of training completed before stochastic weight averaging begins. Currently only units of duration ('dur') and epoch ('ep') are supported. Default = '0.7dur'.

  • swa_end (str, optional) – The time string denoting the amount of training completed before the baseline (non-averaged) model is replaced with the stochastic weight averaged model. At least one epoch of training must follow this replacement so that the SWA model can have its buffers (most importantly its batch norm statistics) updated. If swa_end occurs during the final epoch of training (e.g., swa_end = 0.9dur with max_duration = "5ep", or swa_end = 1.0dur), the SWA model's buffers are not updated, which can hurt accuracy. For swa_end given in 'dur', ensure swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\), where \(N_{epochs}\) is the total number of training epochs. Currently only units of duration ('dur') and epoch ('ep') are supported. Default = '0.97dur'.

  • schedule_swa_lr (bool, optional) โ€“ Flag to determine whether to apply an SWA-specific LR schedule during the period in which SWA is active. Default = False.

  • anneal_strategy (str, optional) – SWA learning rate annealing schedule strategy: "linear" for linear annealing, "cos" for cosine annealing. Default = "linear".

  • anneal_epochs (int, optional) โ€“ Number of epochs over which to anneal SWA learning rate. Default = 10.

  • swa_lr (float, optional) – The final learning rate to anneal toward with the SWA LR scheduler. Set to None for no annealing. Default = None.
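The swa_end constraint can be checked before launching a run. The sketch below uses two hypothetical helpers, not part of the Composer API: one converts a 'dur' or 'ep' time string into epochs, and the other tests whether swa_end falls before the final epoch, which is equivalent to swa_end < (N_epochs - 1) / N_epochs when swa_end is given in 'dur'.

```python
def to_epochs(time_str: str, max_epochs: int) -> float:
    """Convert a 'dur' or 'ep' time string into a number of epochs (hypothetical)."""
    if time_str.endswith("dur"):
        # A fraction of the total training duration.
        return float(time_str[: -len("dur")]) * max_epochs
    if time_str.endswith("ep"):
        return float(time_str[: -len("ep")])
    raise ValueError(f"unsupported time unit in {time_str!r}; use 'dur' or 'ep'")


def swa_end_leaves_full_epoch(swa_end: str, max_epochs: int) -> bool:
    """True if swa_end falls before the final epoch begins (hypothetical)."""
    return to_epochs(swa_end, max_epochs) < max_epochs - 1


# Cases taken from the text: the Example, and the two failing swa_end settings.
for swa_end, max_epochs in [("8ep", 10), ("0.9dur", 5), ("1.0dur", 5)]:
    print(swa_end, max_epochs, swa_end_leaves_full_epoch(swa_end, max_epochs))
# 8ep 10 True
# 0.9dur 5 False
# 1.0dur 5 False
```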
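To visualize how anneal_strategy, anneal_epochs, and swa_lr interact, the hypothetical function below computes a closed-form annealed learning rate, moving from the learning rate in effect when SWA starts toward swa_lr over anneal_epochs epochs and then holding it constant. It is an approximation for illustration: PyTorch's SWALR computes its schedule step by step from the previous learning rate, so exact values may differ. It assumes swa_lr is set (not None).

```python
import math


def annealed_lr(
    start_lr: float,
    swa_lr: float,
    epochs_into_swa: int,
    anneal_epochs: int = 10,
    strategy: str = "linear",
) -> float:
    """Learning rate after `epochs_into_swa` epochs of annealing (hypothetical)."""
    # Fraction of the annealing period completed, clamped to [0, 1].
    t = min(1.0, epochs_into_swa / max(1, anneal_epochs))
    if strategy == "linear":
        alpha = t
    elif strategy == "cos":
        # Half-cosine ramp: slow at the start and end, fastest in the middle.
        alpha = (1.0 - math.cos(math.pi * t)) / 2.0
    else:
        raise ValueError(f"unknown strategy {strategy!r}; use 'linear' or 'cos'")
    return start_lr + (swa_lr - start_lr) * alpha


for epoch in (0, 2, 5, 10, 12):
    lin = annealed_lr(0.1, 0.05, epoch, strategy="linear")
    cos = annealed_lr(0.1, 0.05, epoch, strategy="cos")
    print(f"epoch {epoch:2d}: linear={lin:.4f} cos={cos:.4f}")
```

Both strategies start at the original learning rate, meet at the midpoint of the annealing period, and reach swa_lr after anneal_epochs epochs; cosine annealing stays closer to the original rate early on.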