# Source code for composer.models.transformer_hparams
# Copyright 2021 MosaicML. All Rights Reserved.
"""General `YAHP <https://docs.mosaicml.com/projects/yahp/en/stable/README.html>`_ interface for
ComposerTransformers."""
from abc import ABC
from dataclasses import dataclass
from typing import Dict, Optional
import yahp as hp
from composer.core.types import JSON
from composer.models.model_hparams import ModelHparams
# Public API of this module: only the hparams dataclass is exported.
__all__ = ["TransformerHparams"]
@dataclass
class TransformerHparams(ModelHparams, ABC):
    """Defines the necessary hyperparameters for a Transformer base module.

    Exactly one of ``pretrained_model_name`` or ``model_config`` must be
    supplied (see :meth:`validate`).

    Args:
        pretrained_model_name (Optional[str]): Pretrained model name to pull from Huggingface Model Hub.
            Default: ``None``.
        model_config (Dict[str, JSON]): A dictionary providing a HuggingFace model configuration.
        tokenizer_name (Optional[str]): The tokenizer used for this model,
            necessary to assert required model inputs. Default: ``None``.
        use_pretrained (bool, optional): Whether to initialize the model with the pretrained weights. Default: ``False``.
        gradient_checkpointing (bool, optional): Use gradient checkpointing. Default: ``False``.
    """

    tokenizer_name: Optional[str] = hp.optional("Tokenizer name to pull from Huggingface Model Hub.", default=None)
    pretrained_model_name: Optional[str] = hp.optional(
        doc="Pretrained model name to pull from Huggingface Model Hub.",
        default=None,
    )
    model_config: Dict[str, JSON] = hp.optional(doc="A dictionary providing a HuggingFace model configuration.",
                                                default_factory=dict)
    use_pretrained: bool = hp.optional("Whether to initialize the model with the pretrained weights.", default=False)
    gradient_checkpointing: bool = hp.optional("Whether to enable gradient checkpointing.", default=False)

    def validate(self):
        """Check that the configured hyperparameters are mutually consistent.

        Raises:
            ValueError: If neither or both of ``pretrained_model_name`` and
                ``model_config`` are provided, or if ``use_pretrained`` is set
                together with an explicit ``model_config``.
        """
        # Exactly one source for the model architecture must be given.
        if self.pretrained_model_name is None and not self.model_config:
            raise ValueError("One of pretrained_model_name or model_config needed.")
        if self.pretrained_model_name is not None and self.model_config:
            raise ValueError("Only one of pretrained_model_name or model_config can be provided.")
        # Pretrained weights come from a named checkpoint, not a raw config.
        if self.use_pretrained and self.model_config:
            raise ValueError("A model cannot load pretrained weights from configuration.")