composer.models.transformer_shared

The ComposerModel base interface for Transformers.

Classes

ComposerTransformer

The ComposerModel base interface for Transformers.

class composer.models.transformer_shared.ComposerTransformer(module, config, model_inputs, gradient_checkpointing=False)[source]

Bases: composer.models.base.ComposerModel

The ComposerModel base interface for Transformers.

Works with Hugging Face Transformers.

Parameters
  • module (PreTrainedModel) – An instance of PreTrainedModel that contains the forward pass function.

  • config (PretrainedConfig) – The PretrainedConfig object that stores information about the model hyperparameters.

  • model_inputs (set) – The set of dictionary keys required as inputs to the model's forward function.

  • gradient_checkpointing (bool, optional) – Use gradient checkpointing. Default: False.
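
The following is a minimal construction sketch; the Hugging Face checkpoint name and the input-key set are assumptions for illustration. Because loss() is not implemented on this base class, real usage goes through a task-specific subclass (see the sketch under loss() below).

    # A minimal sketch, assuming the Hugging Face `transformers` package and
    # the 'bert-base-uncased' checkpoint; adapt the model and keys to your task.
    from transformers import AutoConfig, AutoModelForMaskedLM

    from composer.models.transformer_shared import ComposerTransformer

    config = AutoConfig.from_pretrained('bert-base-uncased')
    module = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')

    # Keys that dataloader batches will provide to the model's forward().
    model_inputs = {'input_ids', 'attention_mask', 'token_type_ids', 'labels'}

    model = ComposerTransformer(
        module=module,
        config=config,
        model_inputs=model_inputs,
        gradient_checkpointing=False,
    )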

forward(batch)[source]

Run the forward pass of the model.

Parameters

batch (Batch) – A dictionary (Dict[str, Tensor]) of inputs that the model expects, as found in ComposerTransformer.get_model_inputs().

Returns

output – A dictionary of model outputs, as a Mapping. It will include the loss if labels is passed as an input.
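
A short usage sketch, assuming the model and config built in the construction example above; the batch contents and shapes are made up for illustration.

    import torch

    # Illustrative batch matching the model_inputs set above.
    batch = {
        'input_ids': torch.randint(0, config.vocab_size, (8, 128)),
        'attention_mask': torch.ones(8, 128, dtype=torch.long),
        'token_type_ids': torch.zeros(8, 128, dtype=torch.long),
        'labels': torch.randint(0, config.vocab_size, (8, 128)),
    }

    outputs = model.forward(batch)
    print(outputs['loss'])  # present because 'labels' was in the batch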

get_model_inputs()[source]

Returns a set of inputs that the model expects in the forward pass.

If an algorithm wants to interact with the model inputs (for instance, popping the labels for a custom loss function, or adding attention head masks for head pruning), it must access self.get_model_inputs().

Returns

model_inputs – The set of keys that are expected in the Mapping used to compute the forward pass.
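
For example, an algorithm can use the returned set to check a batch before the forward pass (a sketch reusing the model and batch from the examples above):

    # Verify that the batch provides every input the model expects.
    expected = model.get_model_inputs()
    missing = expected - set(batch.keys())
    assert not missing, f'batch is missing model inputs: {missing}'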

loss(outputs, batch)[source]

Computes the loss from the model outputs and the batch.

We don't implement this for the generic Transformer abstraction, since loss functions are model- and objective-specific. A single model architecture could use a myriad of loss functions, which are better left expressed by the user.

Parameters
  • outputs (Mapping) – The dictionary output from the model. It could contain the loss as computed by Hugging Face, or algorithms can pop the labels from the input in case they modify the loss function.

  • batch (Batch) – The set of ground truth labels to compute the loss against.

Raises

NotImplementedError – A model-specific and task-specific loss function must be written.
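
A hedged sketch of such a task-specific subclass; the class name and the cross-entropy fallback are assumptions for illustration, not Composer's built-in behavior.

    import torch.nn.functional as F

    from composer.models.transformer_shared import ComposerTransformer

    class MyMLMTransformer(ComposerTransformer):
        # Hypothetical masked-language-modeling subclass.

        def loss(self, outputs, batch):
            # Reuse the loss Hugging Face computed when 'labels' was in the batch.
            if outputs.get('loss') is not None:
                return outputs['loss']
            # Otherwise compute token-level cross-entropy against the
            # ground-truth labels carried in `batch`.
            logits = outputs['logits']
            return F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                batch['labels'].view(-1),
            )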

validate(batch)[source]

Runs the validation step.

Parameters

batch (Batch) – A dictionary (Dict[str, Tensor]) of inputs that the model expects, as found in ComposerTransformer.get_model_inputs().

Returns

Tuple[Mapping, None] – A tuple containing the output from the forward pass. This is fed directly into ComposerModel.metrics().
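
A minimal evaluation sketch, assuming the model and batch from the examples above:

    import torch

    model.eval()  # ComposerModel subclasses torch.nn.Module
    with torch.no_grad():
        outputs, targets = model.validate(batch)  # targets is None here
    # `outputs` is then passed on to the metrics from ComposerModel.metrics().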