composer.optim.decoupled_weight_decay#

Optimizers with weight decay decoupled from the learning rate.

These optimizers are based on Decoupled Weight Decay Regularization, which proposes this decoupling. In general, these optimizers are recommended over their native PyTorch equivalents.
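To make the difference concrete, below is a minimal sketch for a single parameter tensor (an illustration of the idea, not the library's exact implementation; the lr / initial_lr scaling is an assumption, suggested by the initial_lr argument of the functional APIs documented below):

    import torch

    def coupled_decay(p: torch.Tensor, lr: float, weight_decay: float) -> None:
        # torch.optim.AdamW applies weight decay scaled by the current learning rate,
        # so changing the LR schedule also changes the effective decay strength.
        p.data.mul_(1.0 - lr * weight_decay)

    def decoupled_decay(p: torch.Tensor, lr: float, initial_lr: float, weight_decay: float) -> None:
        # Decoupled variant (sketch): scale the decay by lr / initial_lr instead of lr,
        # so weight_decay can be tuned independently of the learning rate.
        decay_factor = lr / initial_lr if initial_lr else 1.0
        p.data.mul_(1.0 - decay_factor * weight_decay)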

Classes

DecoupledAdamW

Adam optimizer with the weight decay term decoupled from the learning rate.

DecoupledSGDW

SGD optimizer with the weight decay term decoupled from the learning rate.

class composer.optim.decoupled_weight_decay.DecoupledAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01, amsgrad=False)[source]#

Bases: torch.optim.adamw.AdamW

Adam optimizer with the weight decay term decoupled from the learning rate.

Argument defaults are copied from torch.optim.AdamW.

The standard AdamW optimizer explicitly couples the weight decay term with the learning rate. This ties the optimal value of weight_decay to lr and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization. A short usage sketch follows the parameter list below.

Parameters
  • params (list) – List of parameters to update.

  • lr (float, optional) – Learning rate. Default: 1e-3.

  • betas (tuple, optional) – Coefficients used for computing running averages of the gradient and its square. Default: (0.9, 0.999).

  • eps (float, optional) – Term added to the denominator to improve numerical stability. Default: 1e-8.

  • weight_decay (float, optional) – Decoupled weight decay factor. Default: 1e-2.

  • amsgrad (bool, optional) – Enables the AMSGrad variant of Adam. Default: False.
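A minimal usage sketch (the model and data are placeholders); because DecoupledAdamW subclasses torch.optim.AdamW, it works in a standard PyTorch training loop:

    import torch
    import torch.nn.functional as F
    from composer.optim.decoupled_weight_decay import DecoupledAdamW

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = DecoupledAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    for _ in range(3):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()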

static adamw(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, *, amsgrad, beta1, beta2, lr, initial_lr, weight_decay, eps)[source]#

Functional API that performs the AdamW algorithm computation with decoupled weight decay.

Parameters
  • params (List[Tensor]) – List of parameters to update.

  • grads (List[Tensor]) – List of parameter gradients.

  • exp_avgs (List[Tensor]) – List of running averages of the gradients.

  • exp_avg_sqs (List[Tensor]) – List of running averages of the squared gradients.

  • max_exp_avg_sqs (List[Tensor]) – List of maxima of the running averages of the squared gradients, used for AMSGrad updates.

  • state_steps (Iterable[int]) – List of steps taken for all parameters.

  • amsgrad (bool) – Enables the AMSGrad variant of Adam.

  • beta1 (float) – Coefficient for computing the running average of gradient values.

  • beta2 (float) – Coefficient for computing the running average of squared gradient values.

  • lr (float) – Learning rate.

  • initial_lr (float) – Initial learning rate.

  • weight_decay (float) – Factor for decoupled weight decay.

  • eps (float) – Term added to the denominator to improve numerical stability.

step(closure=None)[source]#

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.
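For completeness, a sketch of the closure form (reusing the model, x, and y from the example above); the closure recomputes the loss so the optimizer can re-evaluate it if needed:

    def closure():
        # Re-evaluate the model and return the loss, as expected by step(closure).
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        return loss

    loss = optimizer.step(closure)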

class composer.optim.decoupled_weight_decay.DecoupledSGDW(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)[source]#

Bases: torch.optim.sgd.SGD

SGD optimizer with the weight decay term decoupled from the learning rate.

Argument defaults are copied from torch.optim.SGD.

The standard SGD optimizer couples the weight decay term with the gradient calculation. This ties the optimal value of weight_decay to lr and can also hurt generalization in practice. For more details on why decoupling might be desirable, see Decoupled Weight Decay Regularization. A short usage sketch follows the parameter list below.

Parameters
  • params (list) – List of parameters to optimize or dicts defining parameter groups.

  • lr (float) – Learning rate (required).

  • momentum (float, optional) – Momentum factor. Default: 0.

  • dampening (float, optional) – Dampening factor applied to the momentum. Default: 0.

  • weight_decay (float, optional) – Decoupled weight decay factor. Default: 0.

  • nesterov (bool, optional) – Enables Nesterov momentum updates. Default: False.
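A minimal usage sketch for DecoupledSGDW, analogous to the DecoupledAdamW example above (placeholder model and data):

    import torch
    import torch.nn.functional as F
    from composer.optim.decoupled_weight_decay import DecoupledSGDW

    model = torch.nn.Linear(10, 2)  # placeholder model
    optimizer = DecoupledSGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    optimizer.zero_grad()
    F.cross_entropy(model(x), y).backward()
    optimizer.step()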

static sgdw(params, d_p_list, momentum_buffer_list, *, weight_decay, momentum, lr, initial_lr, dampening, nesterov)[source]#

Functional API that performs the SGDW algorithm computation with decoupled weight decay.

Parameters
  • params (list) – List of parameters to update.

  • d_p_list (list) – List of parameter gradients.

  • momentum_buffer_list (list) – List of momentum buffers.

  • weight_decay (float) – Decoupled weight decay factor.

  • momentum (float) – Momentum factor.

  • lr (float) – Learning rate.

  • initial_lr (float) – Initial learning rate.

  • dampening (float) – Dampening factor for the momentum update.

  • nesterov (bool) – Enables Nesterov momentum updates.

step(closure=None)[source]#

Performs a single optimization step.

Parameters

closure (callable, optional) – A closure that reevaluates the model and returns the loss.