composer.algorithms.ema.ema#

Core Exponential Moving Average (EMA) classes and functions.

Functions

compute_ema

Updates the weights of ema_model to be closer to the weights of model according to an exponential weighted average.

Classes

EMA

Maintains a shadow model with weights that follow the exponential moving average of the trained model weights.

class composer.algorithms.ema.ema.EMA(half_life, update_interval=None, train_with_ema_weights=False)[source]#

Bases: composer.core.algorithm.Algorithm

Maintains a shadow model with weights that follow the exponential moving average of the trained model weights.

Weights are updated according to

\[W_{ema_model}^{(t+1)} = smoothing\times W_{ema_model}^{(t)}+(1-smoothing)\times W_{model}^{(t)} \]

Where the smoothing is determined from half_life according to

\[smoothing = \exp\left[-\frac{\log(2)}{t_{1/2}}\right] \]
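As an illustrative sketch (plain Python, not part of the Composer API), the conversion from a half life, measured in update steps, to the smoothing coefficient above can be written as:

```python
import math

def smoothing_from_half_life(half_life: float) -> float:
    """Smoothing coefficient whose old-weight contribution halves every
    `half_life` EMA updates (hypothetical helper for illustration)."""
    return math.exp(-math.log(2) / half_life)

# A half life of 1 update gives smoothing = 0.5; a half life of 50
# updates gives a smoothing coefficient just above 0.986.
smoothing = smoothing_from_half_life(50)
```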

Model evaluation is done with the moving average weights, which can result in better generalization. Because of the shadow models, EMA triples the model's memory consumption. Note that this does not mean that the total memory required triples, since stored activations and the optimizer state are not duplicated. EMA also uses a small amount of extra compute to update the moving average weights.

See the Method Card for more details.

Parameters
  • half_life (str) – The time string specifying the half life for terms in the average. A longer half life means old information is remembered longer; a shorter half life means old information is discarded sooner. A half life of 0 means no averaging is done; an infinite half life means no update is done. Currently, only units of epoch ('ep') and batch ('ba') are supported. Value must be an integer.

  • update_interval (str, optional) – The time string specifying the period at which updates are done. For example, update_interval='1ep' means updates are done every epoch, while update_interval='10ba' means updates are done once every ten batches. Units must match the units used to specify half_life. If not specified, update_interval will default to 1 in the units of half_life. Value must be an integer. Default: None.

  • train_with_ema_weights (bool, optional) – An experimental feature that uses the EMA weights as the training weights. In most cases this should be left as False. Default: False.

Example

from composer.algorithms import EMA
from composer.trainer import Trainer
algorithm = EMA(half_life='50ba', update_interval='1ba')
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    algorithms=[algorithm],
    optimizers=[optimizer]
)
get_ema_model(model)[source]#

Copies the EMA model's parameters and buffers into the input model and returns it.

Parameters

model (Module) – the model to convert into the ema model.

Returns

model (torch.nn.Module) – the input model with parameters and buffers replaced with the averaged parameters and buffers.

class composer.algorithms.ema.ema.ShadowModel(model)[source]#

A shadow model that tracks parameters and buffers from an original source model.

Parameters

model (Module) – the source model containing the parameters and buffers to shadow.

composer.algorithms.ema.ema.compute_ema(model, ema_model, smoothing=0.99)[source]#

Updates the weights of ema_model to be closer to the weights of model according to an exponential weighted average. Weights are updated according to

\[W_{ema_model}^{(t+1)} = smoothing\times W_{ema_model}^{(t)}+(1-smoothing)\times W_{model}^{(t)} \]

The update to ema_model happens in place.
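For illustration only, the in-place update can be sketched in plain Python over parallel lists of scalar weights (the actual function operates on the matching parameters and buffers of the two torch modules; the function name here is hypothetical):

```python
def compute_ema_sketch(model_weights, ema_weights, smoothing=0.99):
    """Update ema_weights in place toward model_weights, mirroring
    W_ema <- smoothing * W_ema + (1 - smoothing) * W_model."""
    for i, w in enumerate(model_weights):
        ema_weights[i] = smoothing * ema_weights[i] + (1 - smoothing) * w

weights = [1.0, 2.0]
ema = [0.0, 0.0]
compute_ema_sketch(weights, ema, smoothing=0.99)
# ema is now approximately [0.01, 0.02]: a small step toward the
# latest weights, with most of the old average retained.
```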

The half life of the weights for terms in the average is given by

\[t_{1/2} = -\frac{\log(2)}{\log(smoothing)} \]

Therefore to set smoothing to obtain a target half life, set smoothing according to

\[smoothing = \exp\left[-\frac{\log(2)}{t_{1/2}}\right] \]
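As a numeric sanity check (an illustrative sketch, not library code), repeatedly applying the update for one half life's worth of steps decays an old weight's contribution to half its initial value:

```python
import math

# Choose smoothing so the half life is exactly 10 updates.
smoothing = 0.5 ** (1 / 10)
half_life = -math.log(2) / math.log(smoothing)  # recovers 10.0

# Average an old weight of 1.0 toward new weights of 0.0 for one half life.
w_ema = 1.0
for _ in range(10):
    w_ema = smoothing * w_ema + (1 - smoothing) * 0.0
# After 10 updates (one half life), w_ema has decayed to 0.5.
```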
Parameters
  • model (Module) – the model containing the latest weights to use to update the moving average weights.

  • ema_model (Module) – the model containing the moving average weights to be updated.

  • smoothing (float, optional) – the coefficient representing the degree to which older observations are kept. Must be in the interval \((0, 1)\). Default: 0.99.

Example

import composer.functional as cf
from torchvision import models
model = models.resnet50()
ema_model = models.resnet50()
cf.compute_ema(model, ema_model, smoothing=0.9)