composer.algorithms.sam.sam#

Functions

ensure_tuple

Converts x into a tuple.

Classes

Algorithm

Base class for algorithms.

Event

Enum to represent events in the training loop.

Logger

Logger routes metrics to the LoggerCallback.

SAM

Adds sharpness-aware minimization (Foret et al., 2020) by wrapping an existing optimizer with a SAMOptimizer.

SAMOptimizer

Wraps an optimizer with sharpness-aware minimization (Foret et al., 2020).

State

The state of the trainer.

class composer.algorithms.sam.sam.SAM(rho=0.05, epsilon=1e-12, interval=1)[source]#

Bases: composer.core.algorithm.Algorithm

Adds sharpness-aware minimization (Foret et al., 2020) by wrapping an existing optimizer with a SAMOptimizer.

Parameters
  • rho (float, optional) – The neighborhood size parameter of SAM. Must be greater than 0. Default: 0.05.

  • epsilon (float, optional) – A small value added to the gradient norm for numerical stability. Default: 1e-12.

  • interval (int, optional) – SAM will run once every interval steps. A value of 1 causes SAM to run on every step. Steps on which SAM runs take roughly twice as long to complete. Default: 1.
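
As a usage sketch, the algorithm is passed to the Composer Trainer like any other algorithm; the model, dataloader, and duration below are placeholders:

```python
from composer import Trainer
from composer.algorithms import SAM

trainer = Trainer(
    model=model,                        # placeholder: your ComposerModel
    train_dataloader=train_dataloader,  # placeholder: your DataLoader
    max_duration="10ep",
    algorithms=[SAM(rho=0.05, epsilon=1e-12, interval=1)],
)
trainer.fit()
```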

apply(event, state, logger)[source]#

Applies SAM by wrapping the base optimizer with the SAM optimizer.

Parameters
  • event (Event) – The current event.

  • state (State) – The current trainer state.

  • logger (Logger) – The training logger.
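
In effect, apply replaces each optimizer held by the trainer state with a SAM-wrapped version. A rough sketch of that wrapping, assuming the hyperparameters are stored on the algorithm instance (illustrative only, not the library's verbatim code):

```python
# Sketch: wrap every existing optimizer in the trainer state with a SAMOptimizer.
# Assumes state.optimizers holds the trainer's optimizer(s) and that self.rho,
# self.epsilon, and self.interval mirror the constructor arguments.
state.optimizers = tuple(
    SAMOptimizer(
        base_optimizer=optimizer,
        rho=self.rho,
        epsilon=self.epsilon,
        interval=self.interval,
    )
    for optimizer in ensure_tuple(state.optimizers)
)
```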

match(event, state)[source]#

Run on Event.INIT.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.
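
Since the optimizer only needs to be wrapped once, match fires at trainer initialization; a minimal sketch of an equivalent check (not necessarily the exact source):

```python
def match(self, event: Event, state: State) -> bool:
    # SAM wraps the optimizer a single time, when the trainer initializes.
    return event == Event.INIT
```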

class composer.algorithms.sam.sam.SAMOptimizer(base_optimizer, rho=0.05, epsilon=1e-12, interval=1, **kwargs)[source]#

Bases: torch.optim.optimizer.Optimizer

Wraps an optimizer with sharpness-aware minimization (Foret et al., 2020). See SAM for details.

Implementation based on https://github.com/davda54/sam
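
The SAM update alternates an ascent step to the worst-case neighbor w + e(w) with a descent step applied at the original weights. A self-contained sketch of that two-step scheme, modeled on the referenced davda54/sam repository (names and structure here are illustrative, not the library's exact SAMOptimizer):

```python
import torch


class SAMSketch(torch.optim.Optimizer):
    """Illustrative sketch of the SAM update; not the library's exact SAMOptimizer."""

    def __init__(self, base_optimizer, rho=0.05, epsilon=1e-12):
        self.base_optimizer = base_optimizer
        super().__init__(base_optimizer.param_groups, dict(rho=rho, epsilon=epsilon))

    @torch.no_grad()
    def first_step(self):
        # Ascent step: perturb the weights to w + e(w), the approximate worst case
        # in a rho-ball, using the gradients from the first forward/backward pass.
        grads = [p.grad.norm(p=2) for group in self.param_groups
                 for p in group["params"] if p.grad is not None]
        grad_norm = torch.norm(torch.stack(grads), p=2)
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + group["epsilon"])
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                 # climb toward the local worst case
                self.state[p]["e_w"] = e_w  # remember the perturbation

    @torch.no_grad()
    def second_step(self):
        # Descent step: undo the perturbation, then let the base optimizer update
        # the original weights using gradients computed at w + e(w).
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None and "e_w" in self.state[p]:
                    p.sub_(self.state[p]["e_w"])
        self.base_optimizer.step()
```

Each SAM step therefore needs two forward/backward passes: one to compute the ascent perturbation and one at the perturbed weights to compute the actual update, which is why steps on which SAM runs take roughly twice as long. The library's SAMOptimizer presumably coordinates these two sub-steps internally and, per the interval parameter, only applies the perturbation every interval steps; the sketch above shows only the core computation.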