# Copyright 2021 MosaicML. All Rights Reserved.
"""Hyperparameters for schedulers."""
from abc import ABC
from dataclasses import asdict, dataclass
from typing import List, Optional, Type
import yahp as hp
from composer.optim.scheduler import (ComposerScheduler, ConstantScheduler, CosineAnnealingScheduler,
CosineAnnealingWarmRestartsScheduler, CosineAnnealingWithWarmupScheduler,
ExponentialScheduler, LinearScheduler, LinearWithWarmupScheduler,
MultiStepScheduler, MultiStepWithWarmupScheduler, PolynomialScheduler,
StepScheduler)
# Public API of this module: the abstract base plus one hparams class per scheduler.
__all__ = [
    "SchedulerHparams",
    "StepSchedulerHparams",
    "MultiStepSchedulerHparams",
    "ConstantSchedulerHparams",
    "LinearSchedulerHparams",
    "ExponentialSchedulerHparams",
    "CosineAnnealingSchedulerHparams",
    "CosineAnnealingWarmRestartsSchedulerHparams",
    "PolynomialSchedulerHparams",
    "MultiStepWithWarmupSchedulerHparams",
    "LinearWithWarmupSchedulerHparams",
    "CosineAnnealingWithWarmupSchedulerHparams",
]
@dataclass
class SchedulerHparams(hp.Hparams, ABC):
    """Abstract base for every scheduler hyperparameter dataclass.

    Concrete subclasses declare the scheduler's constructor arguments as dataclass fields and
    set ``_scheduler_cls`` to the scheduler class they build. Scheduler parameters that are
    added to :class:`~composer.trainer.trainer_hparams.TrainerHparams` (e.g. via YAML or the
    CLI) are initialized in the training loop.
    """

    # Deliberately left unannotated so ``@dataclass`` treats it as a plain class attribute
    # rather than a field (and therefore it is not included in ``asdict``). Subclasses
    # override it with the scheduler class they construct.
    _scheduler_cls = None  # type: Optional[Type[ComposerScheduler]]

    def initialize_object(self) -> ComposerScheduler:
        """Construct the scheduler described by these hyperparameters.

        Every dataclass field is forwarded to ``_scheduler_cls`` as a keyword argument, so
        field names must match that class's constructor parameters.

        Returns:
            ComposerScheduler: The newly constructed scheduler.

        Raises:
            NotImplementedError: If ``_scheduler_cls`` has not been set on this class.
        """
        scheduler_cls = self._scheduler_cls
        if scheduler_cls is None:
            raise NotImplementedError(f"Cannot initialize {self} because `_scheduler_cls` is undefined.")
        scheduler_kwargs = asdict(self)
        # The static type ``ComposerScheduler`` declares no constructor arguments, so the
        # type checker would reject the keyword arguments below; the concrete class accepts them.
        return scheduler_cls(**scheduler_kwargs)  # type: ignore
@dataclass
class StepSchedulerHparams(SchedulerHparams):
    """Hyperparameters for the :class:`~.StepScheduler` scheduler.

    See :class:`~.StepScheduler` for documentation.

    Args:
        step_size (str): See :class:`~.StepScheduler`. Required.
        gamma (float, optional): See :class:`~.StepScheduler`. Default: ``0.1``.
    """

    # ``step_size`` has no default (``hp.required``), so it is documented above as required
    # rather than optional.
    step_size: str = hp.required(doc="Time between changes to the learning rate.")
    gamma: float = hp.optional(default=0.1, doc="Multiplicative decay factor.")

    # Scheduler class built by ``initialize_object``; not a dataclass field (unannotated).
    _scheduler_cls = StepScheduler
@dataclass
class MultiStepSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.MultiStepScheduler`.

    Refer to :class:`~.MultiStepScheduler` for the meaning of each argument.

    Args:
        milestones (List[str]): See :class:`~.MultiStepScheduler`. Required.
        gamma (float, optional): See :class:`~.MultiStepScheduler`. Default: ``0.1``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = MultiStepScheduler

    milestones: List[str] = hp.required(doc="Times at which the learning rate should change.")
    gamma: float = hp.optional(doc="Multiplicative decay factor.", default=0.1)
@dataclass
class ConstantSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.ConstantScheduler`.

    Refer to :class:`~.ConstantScheduler` for the meaning of each argument.

    Args:
        alpha (float, optional): See :class:`~.ConstantScheduler`. Default: ``1.0``.
        t_max (str, optional): See :class:`~.ConstantScheduler`. Default: ``"1dur"``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = ConstantScheduler

    alpha: float = hp.optional(doc="Learning rate multiplier to maintain while this scheduler is active.", default=1.0)
    t_max: str = hp.optional(doc="Duration of this scheduler.", default="1dur")
@dataclass
class LinearSchedulerHparams(SchedulerHparams):
    """Hyperparameters for the :class:`~.LinearScheduler` scheduler.

    See :class:`~.LinearScheduler` for documentation.

    Args:
        alpha_i (float, optional): See :class:`~.LinearScheduler`. Default: ``1.0``.
        alpha_f (float, optional): See :class:`~.LinearScheduler`. Default: ``0.0``.
        t_max (str, optional): See :class:`~.LinearScheduler`. Default: ``"1dur"``.
    """

    # ``doc`` is passed by keyword, matching every other ``hp.optional`` call in this module.
    alpha_i: float = hp.optional(doc="Initial learning rate multiplier.", default=1.0)
    alpha_f: float = hp.optional(doc="Final learning rate multiplier.", default=0.0)
    t_max: str = hp.optional(default="1dur", doc="Duration of this scheduler.")

    # Scheduler class built by ``initialize_object``; not a dataclass field (unannotated).
    _scheduler_cls = LinearScheduler
@dataclass
class ExponentialSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build an :class:`~.ExponentialScheduler`.

    Refer to :class:`~.ExponentialScheduler` for the meaning of each argument.

    Args:
        gamma (float): See :class:`~.ExponentialScheduler`. Required.
        decay_period (str, optional): See :class:`~.ExponentialScheduler`. Default: ``"1ep"``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = ExponentialScheduler

    gamma: float = hp.required(doc="Multiplicative decay factor.")
    decay_period: str = hp.optional(doc="Decay period.", default="1ep")
@dataclass
class CosineAnnealingSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.CosineAnnealingScheduler`.

    Refer to :class:`~.CosineAnnealingScheduler` for the meaning of each argument.

    Args:
        t_max (str, optional): See :class:`~.CosineAnnealingScheduler`. Default: ``"1dur"``.
        alpha_f (float, optional): See :class:`~.CosineAnnealingScheduler`. Default: ``0.0``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = CosineAnnealingScheduler

    t_max: str = hp.optional(doc="Duration of this scheduler.", default="1dur")
    alpha_f: float = hp.optional(doc="Learning rate multiplier to decay to.", default=0.0)
@dataclass
class CosineAnnealingWarmRestartsSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.CosineAnnealingWarmRestartsScheduler`.

    Refer to :class:`~.CosineAnnealingWarmRestartsScheduler` for the meaning of each argument.

    Args:
        t_0 (str, optional): See :class:`~.CosineAnnealingWarmRestartsScheduler`. Default: ``"1dur"``.
        alpha_f (float, optional): See :class:`~.CosineAnnealingWarmRestartsScheduler`. Default: ``0.0``.
        t_mult (float, optional): See :class:`~.CosineAnnealingWarmRestartsScheduler`. Default: ``1.0``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = CosineAnnealingWarmRestartsScheduler

    t_0: str = hp.optional(doc="The period of the first cycle.", default="1dur")
    alpha_f: float = hp.optional(doc="Learning rate multiplier to decay to.", default=0.0)
    t_mult: float = hp.optional(doc="The multiplier for the duration of successive cycles.", default=1.0)
@dataclass
class PolynomialSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.PolynomialScheduler`.

    Refer to :class:`~.PolynomialScheduler` for the meaning of each argument.

    Args:
        power (float): See :class:`~.PolynomialScheduler`. Required.
        t_max (str, optional): See :class:`~.PolynomialScheduler`. Default: ``"1dur"``.
        alpha_f (float, optional): See :class:`~.PolynomialScheduler`. Default: ``0.0``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = PolynomialScheduler

    power: float = hp.required(doc="The exponent to be used for the proportionality relationship.")
    t_max: str = hp.optional(doc="Duration of this scheduler.", default="1dur")
    alpha_f: float = hp.optional(doc="Learning rate multiplier to decay to.", default=0.0)
@dataclass
class MultiStepWithWarmupSchedulerHparams(SchedulerHparams):
    """Hyperparameters for the :class:`~.MultiStepWithWarmupScheduler` scheduler.

    See :class:`~.MultiStepWithWarmupScheduler` for documentation.

    Args:
        t_warmup (str): See :class:`~.MultiStepWithWarmupScheduler`. Required.
        milestones (List[str]): See :class:`~.MultiStepWithWarmupScheduler`. Required.
        gamma (float, optional): See :class:`~.MultiStepWithWarmupScheduler`. Default: ``0.1``.
    """

    t_warmup: str = hp.required(doc="Warmup time.")
    milestones: List[str] = hp.required(doc="Times at which the learning rate should change.")
    gamma: float = hp.optional(default=0.1, doc="Multiplicative decay factor.")

    # Scheduler class built by ``initialize_object``; not a dataclass field (unannotated).
    _scheduler_cls = MultiStepWithWarmupScheduler
@dataclass
class LinearWithWarmupSchedulerHparams(SchedulerHparams):
    """Hyperparameters for the :class:`~.LinearWithWarmupScheduler` scheduler.

    See :class:`~.LinearWithWarmupScheduler` for documentation.

    Args:
        t_warmup (str): See :class:`~.LinearWithWarmupScheduler`. Required.
        alpha_i (float, optional): See :class:`~.LinearWithWarmupScheduler`. Default: ``1.0``.
        alpha_f (float, optional): See :class:`~.LinearWithWarmupScheduler`. Default: ``0.0``.
        t_max (str, optional): See :class:`~.LinearWithWarmupScheduler`. Default: ``"1dur"``.
    """

    t_warmup: str = hp.required(doc="Warmup time.")
    # ``doc`` is passed by keyword, matching every other ``hp.optional`` call in this module.
    alpha_i: float = hp.optional(doc="Initial learning rate multiplier.", default=1.0)
    alpha_f: float = hp.optional(doc="Final learning rate multiplier.", default=0.0)
    t_max: str = hp.optional(default="1dur", doc="Duration of this scheduler.")

    # Scheduler class built by ``initialize_object``; not a dataclass field (unannotated).
    _scheduler_cls = LinearWithWarmupScheduler
@dataclass
class CosineAnnealingWithWarmupSchedulerHparams(SchedulerHparams):
    """Hyperparameters that build a :class:`~.CosineAnnealingWithWarmupScheduler`.

    Refer to :class:`~.CosineAnnealingWithWarmupScheduler` for the meaning of each argument.

    Args:
        t_warmup (str): See :class:`~.CosineAnnealingWithWarmupScheduler`. Required.
        t_max (str, optional): See :class:`~.CosineAnnealingWithWarmupScheduler`. Default: ``"1dur"``.
        alpha_f (float, optional): See :class:`~.CosineAnnealingWithWarmupScheduler`. Default: ``0.0``.
    """

    # Scheduler class constructed by ``initialize_object`` (class attribute, not a field).
    _scheduler_cls = CosineAnnealingWithWarmupScheduler

    t_warmup: str = hp.required(doc="Warmup time.")
    t_max: str = hp.optional(doc="Duration of this scheduler.", default="1dur")
    alpha_f: float = hp.optional(doc="Learning rate multiplier to decay to.", default=0.0)