composer.algorithms.swa.swa
Core code for Stochastic Weight Averaging.
Classes
SWA: Apply Stochastic Weight Averaging (Izmailov et al., 2018)
- class composer.algorithms.swa.swa.SWA(swa_start='0.7dur', swa_end='0.97dur', schedule_swa_lr=False, anneal_strategy='linear', anneal_epochs=10, swa_lr=None)
Bases: composer.core.algorithm.Algorithm
Apply Stochastic Weight Averaging (Izmailov et al., 2018).
Stochastic Weight Averaging (SWA) averages model weights sampled at different times near the end of training. This leads to better generalization than just using the final trained weights.
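Concretely, given weight samples \(\theta_1, \ldots, \theta_n\) collected during the averaging window, SWA uses their running mean \(\theta_{\mathrm{SWA}} = \frac{1}{n}\sum_{i=1}^{n}\theta_i\), which is updated incrementally as each new sample arrives: \(\theta_{\mathrm{SWA}} \leftarrow \frac{n\,\theta_{\mathrm{SWA}} + \theta_{n+1}}{n+1}\).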
Because this algorithm needs to maintain both the current value of the weights and the average of all of the sampled weights, it doubles the model's memory consumption. Note that this does not mean that the total memory required doubles, however, since stored activations and the optimizer state are not doubled.
Uses PyTorch's torch.optim.swa_utils under the hood.
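For intuition, here is a minimal, self-contained sketch of those PyTorch utilities used directly, outside of Composer; the toy model, data, and hyperparameter values are stand-ins, not recommendations:

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Toy model and synthetic data, standing in for a real training setup.
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

swa_model = AveragedModel(model)  # maintains the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.01, anneal_epochs=3, anneal_strategy="linear")

for epoch in range(10):
    for x, y in loader:
        optimizer.zero_grad()
        loss_fn(model(x), y).backward()
        optimizer.step()
    if epoch >= 7:  # SWA window: roughly the last 30% of training
        swa_model.update_parameters(model)  # fold the current weights into the average
        swa_scheduler.step()

update_bn(loader, swa_model)  # recompute batch norm statistics for the averaged model

Composer's SWA algorithm wraps this bookkeeping (averaging window, buffer updates, and optional LR annealing) so it runs automatically inside the Trainer.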
See the Method Card for more details.
Example
from composer.algorithms import SWA
from composer.trainer import Trainer

swa_algorithm = SWA(
    swa_start="6ep",
    swa_end="8ep",
)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="10ep",
    algorithms=[swa_algorithm],
    optimizers=[optimizer],
)
- Parameters
  - swa_start (str, optional) – The time string denoting the amount of training completed before stochastic weight averaging begins. Currently only units of duration ('dur') and epoch ('ep') are supported. Default = '0.7dur'.
  - swa_end (str, optional) – The time string denoting the amount of training completed before the baseline (non-averaged) model is replaced with the stochastic weight averaged model. It's important to have at least one epoch of training after the baseline model is replaced by the SWA model so that the SWA model can have its buffers (most importantly its batch norm statistics) updated. If swa_end occurs during the final epoch of training (e.g. swa_end = "0.9dur" and max_duration = "5ep", or swa_end = "1.0dur"), the SWA model will not have its buffers updated, which can negatively impact accuracy, so ensure swa_end < \(\frac{N_{epochs}-1}{N_{epochs}}\) (e.g. below 0.8dur when max_duration = "5ep"). Currently only units of duration ('dur') and epoch ('ep') are supported. Default = '0.97dur'.
  - schedule_swa_lr (bool, optional) – Flag to determine whether to apply an SWA-specific LR schedule during the period in which SWA is active (see the sketch after this list). Default = False.
  - anneal_strategy (str, optional) – SWA learning rate annealing schedule strategy: "linear" for linear annealing, "cos" for cosine annealing. Default = "linear".
  - anneal_epochs (int, optional) – Number of epochs over which to anneal the SWA learning rate. Default = 10.
  - swa_lr (float, optional) – The final learning rate to anneal towards with the SWA LR scheduler. Set to None for no annealing. Default = None.
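The annealing parameters configure the SWA-specific LR schedule, so they matter only when schedule_swa_lr=True. A sketch of a configuration that enables that schedule; the values here are illustrative, not recommendations:

from composer.algorithms import SWA

# Anneal the learning rate toward 0.001 with a cosine schedule over
# 5 epochs while SWA is active.
swa_algorithm = SWA(
    swa_start="0.7dur",
    swa_end="0.97dur",
    schedule_swa_lr=True,
    anneal_strategy="cos",
    anneal_epochs=5,
    swa_lr=0.001,
)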