composer.algorithms.alibi.alibi

Core ALiBi classes and functions.

Functions

apply_alibi

Removes position embeddings and replaces the attention function and attention mask as per ALiBi.

Classes

Alibi

ALiBi (Attention with Linear Biases; Press et al., 2021) dispenses with position embeddings and instead directly biases attention matrices such that nearby tokens attend to one another more strongly.

class composer.algorithms.alibi.alibi.Alibi(position_embedding_attribute, attention_module_name, attr_to_replace, alibi_attention, mask_replacement_function=None, heads_per_layer=None, max_sequence_length=8192, train_sequence_length_scaling=0.25)

Bases: composer.core.algorithm.Algorithm

ALiBi (Attention with Linear Biases; Press et al., 2021) dispenses with position embeddings and instead directly biases attention matrices such that nearby tokens attend to one another more strongly.

ALiBi yields excellent extrapolation to unseen sequence lengths compared to other position embedding schemes. We leverage this extrapolation capability by training with shorter sequence lengths, which reduces the memory and computation load.
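To make the biasing concrete, here is a minimal pure-Python sketch of the per-head slopes and the resulting bias for a single head, following the geometric slope scheme described by Press et al. (2021). The function names `alibi_slopes` and `alibi_bias` are illustrative and not part of Composer's API, and folding the causal mask into the bias as `-inf` is an assumption of this sketch.

```python
from __future__ import annotations

import math


def alibi_slopes(n_heads: int) -> list[float]:
    """Per-head slopes using the geometric scheme from Press et al. (2021)."""

    def power_of_2_slopes(n: int) -> list[float]:
        # For n heads (n a power of 2): 2^(-8/n), 2^(-16/n), ..., 2^(-8).
        start = 2 ** (-8 / n)
        return [start ** (i + 1) for i in range(n)]

    if math.log2(n_heads).is_integer():
        return power_of_2_slopes(n_heads)
    # Otherwise take the slopes of the closest smaller power of 2 and fill the
    # remaining heads with every other slope from the next power of 2.
    closest = 2 ** math.floor(math.log2(n_heads))
    extra = alibi_slopes(2 * closest)[0::2][: n_heads - closest]
    return power_of_2_slopes(closest) + extra


def alibi_bias(seq_len: int, slope: float) -> list[list[float]]:
    """Bias for one head: query i attending to key j <= i gets -slope * (i - j).

    Future positions (j > i) are set to -inf, i.e. this sketch also applies a causal mask.
    """
    return [
        [-slope * (i - j) if j <= i else float("-inf") for j in range(seq_len)]
        for i in range(seq_len)
    ]
```

For 8 heads the slopes are 1/2, 1/4, ..., 1/256, so every head penalizes distance linearly, but at a different rate. The bias is added to the attention scores before the softmax, which is why no learned position embedding is needed.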

This algorithm runs on INIT to modify the model, before the model has been moved to accelerators. It also runs on AFTER_DATALOADER to modify the shape of a batch of data, after the model and data have been moved to accelerators.

See the Method Card for more details.

Example:

from composer.algorithms import Alibi
from composer.trainer import Trainer

alibi = Alibi(position_embedding_attribute="module.transformer.wpe",
              attention_module_name="transformers.models.gpt2.modeling_gpt2.GPT2Attention",
              attr_to_replace="_attn",
              alibi_attention="composer.algorithms.alibi.gpt2_alibi._attn",
              mask_replacement_function="composer.algorithms.alibi.gpt2_alibi.enlarge_mask",
              max_sequence_length=8192)

# `model` and `train_dataloader` are assumed to be defined elsewhere
# (e.g., a GPT-2 model and a dataloader of tokenized text).
trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  max_duration="1ep",
                  algorithms=[alibi])
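The string arguments in this example (`attention_module_name`, `alibi_attention`, `mask_replacement_function`) are dotted import paths. A minimal sketch of how such a path can be turned into a Python object using only the standard library follows. `resolve_dotted_path` is a hypothetical helper, and Composer's own resolution logic may differ.

```python
import importlib
from typing import Any


def resolve_dotted_path(path: str) -> Any:
    """Import the longest importable module prefix of `path`, then walk the remaining attributes."""
    parts = path.split(".")
    # Try "a.b.c", then "a.b", then "a" until one of them imports as a module.
    for split in range(len(parts), 0, -1):
        module_name = ".".join(parts[:split])
        try:
            obj = importlib.import_module(module_name)
        except ImportError:
            continue
        # The rest of the path is attribute access, e.g. a class or function name.
        for attr in parts[split:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"Could not import any prefix of {path!r}")
```

Under this scheme, "transformers.models.gpt2.modeling_gpt2.GPT2Attention" would resolve to the `GPT2Attention` class once the `transformers.models.gpt2.modeling_gpt2` module is imported.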
Parameters
  • position_embedding_attribute (str) – Attribute for position embeddings. For example, in HuggingFace’s GPT-2, the position embeddings are 'transformer.wpe'.

  • attention_module_name (str) – Module/class that will have its self-attention function replaced. For example, in HuggingFace’s GPT-2, the self-attention module is 'transformers.models.gpt2.modeling_gpt2.GPT2Attention'.

  • attr_to_replace (str) – Attribute of the self-attention module that the ALiBi self-attention function will replace. For example, in HuggingFace’s GPT-2, the self-attention function is '_attn'.

  • alibi_attention (str) – Path to the new self-attention function in which ALiBi is implemented. Used to replace {attention_module}.{attr_to_replace}. Example: 'composer.algorithms.alibi._gpt2_alibi._attn'.

  • mask_replacement_function (Union[str, None]) – Path to a function that replaces the model’s attention mask. This can be necessary when evaluating on sequence lengths longer than the model was initialized to accommodate. Takes positional arguments module and max_sequence_length. For example, 'composer.algorithms.alibi._gpt2_alibi.enlarge_mask'. Default: None, which means the model’s default attention mask is not modified.

  • heads_per_layer (int, optional) – Number of attention heads per layer.

  • max_sequence_length (int) – Maximum sequence length that the model will be able to accept. This is sometimes necessary for evaluating on sequence lengths longer than the model was initialized to accommodate. Default: 8192.

  • train_sequence_length_scaling (float, optional) – Amount by which to scale the training sequence length. One batch of training data will be reshaped from shape \((sequence\_length, batch)\) to \((sequence\_length \times train\_sequence\_length\_scaling, \frac{batch}{train\_sequence\_length\_scaling})\). Default: 0.25.
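The reshaping controlled by train_sequence_length_scaling can be sketched in plain Python. This sketch assumes a batch-first layout (one list of token ids per sequence) and a scale for which `sequence_length * scale` evenly divides `sequence_length`. `scale_sequence_length` is a hypothetical name, and the actual algorithm operates on tensors rather than Python lists.

```python
from __future__ import annotations


def scale_sequence_length(batch: list[list[int]], scale: float) -> list[list[int]]:
    """Split each sequence into shorter ones so the total token count is unchanged."""
    seq_len = len(batch[0])
    new_len = int(seq_len * scale)
    if new_len < 1 or seq_len % new_len != 0:
        raise ValueError("sequence_length * scale must be a positive divisor of sequence_length")
    # Flatten row by row, then re-chunk into rows of the shorter length. Because
    # new_len divides seq_len, no new row straddles two original sequences.
    flat = [tok for row in batch for tok in row]
    return [flat[i : i + new_len] for i in range(0, len(flat), new_len)]
```

With a scale of 0.25, a batch of 2 sequences of length 8 becomes 8 sequences of length 2, so the number of tokens per batch is unchanged while each attention computation covers a quarter of the original length.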

composer.algorithms.alibi.alibi.apply_alibi(model, heads_per_layer, max_sequence_length, position_embedding_attribute, attention_module, attr_to_replace, alibi_attention, mask_replacement_function=None, optimizers=None)

Removes position embeddings and replaces the attention function and attention mask as per ALiBi. Note that most of the training speedup from using ALiBi comes from being able to train on shorter sequence lengths; this function does not scale the training sequence length as the Alibi algorithm does, so little speedup will be observed from using it alone. See the Method Card for more details. This function should be called after the model is instantiated and before training begins.
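The two model edits described above can be illustrated with toy objects. Everything in this sketch is hypothetical (`ToyAttention`, `ToyModel`, `alibi_attn` and `toy_apply_alibi` are not Composer code). It zeroes a stand-in position-embedding table, which is one way to remove its contribution without changing shapes, and binds a replacement function to each matching module instance with `types.MethodType`, so that `module._attn(...)` calls the new function.

```python
from types import MethodType


class ToyAttention:
    """Hypothetical stand-in for an attention class such as GPT2Attention."""

    def _attn(self, scores):
        # Original behavior: return the scores unchanged.
        return list(scores)


def alibi_attn(self, scores):
    # Hypothetical replacement: add a per-module bias to every score.
    return [s + self.alibi_bias for s in scores]


class ToyModel:
    def __init__(self):
        self.wpe = [[0.3, -0.1], [0.7, 0.2]]  # stand-in position-embedding table
        self.blocks = [ToyAttention(), ToyAttention()]


def toy_apply_alibi(model, attention_cls, attr_to_replace, new_attention, bias):
    # 1. Zero the position embeddings so they no longer contribute.
    model.wpe = [[0.0] * len(row) for row in model.wpe]
    # 2. Bind the new attention function to each matching module instance.
    for module in model.blocks:
        if isinstance(module, attention_cls):
            module.alibi_bias = bias
            setattr(module, attr_to_replace, MethodType(new_attention, module))
    return model
```

Binding per instance leaves the original class untouched, so other models that use the same class are unaffected.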

Example:

import composer.functional as cf

from composer.algorithms.alibi.gpt2_alibi import _attn
from composer.algorithms.alibi.gpt2_alibi import enlarge_mask
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

# `model` is assumed to be an already-instantiated GPT-2 model.
cf.apply_alibi(model=model,
               heads_per_layer=12,
               max_sequence_length=8192,
               position_embedding_attribute="module.transformer.wpe",
               attention_module=GPT2Attention,
               attr_to_replace="_attn",
               alibi_attention=_attn,
               mask_replacement_function=enlarge_mask)
Parameters
  • model (Module) – Model to transform.

  • heads_per_layer (int) – Number of attention heads per layer.

  • max_sequence_length (int) – See Alibi.

  • position_embedding_attribute (str) – See Alibi.

  • attention_module (Module) – Module/class that will have its self-attention function replaced. For example, in HuggingFace’s GPT-2, the self-attention module is transformers.models.gpt2.modeling_gpt2.GPT2Attention.

  • attr_to_replace (str) – See Alibi.

  • alibi_attention (Callable) – New self-attention function in which ALiBi is implemented. Used to replace {attention_module}.{attr_to_replace}. Example: composer.algorithms.alibi._gpt2_alibi._attn.

  • mask_replacement_function (Callable[[Module, int], Module], optional) – Function to replace the model’s attention mask. This can be necessary for evaluating on sequence lengths longer than the model was initialized to accommodate. Takes positional arguments module and max_sequence_length. For example, composer.algorithms.alibi._gpt2_alibi.enlarge_mask. Default: None, which means the model’s default attention mask is not modified.

  • optimizers (Optimizers, optional) –

    Existing optimizers bound to model.parameters(). All optimizers that have already been constructed with model.parameters() must be specified here so that they optimize the correct parameters. Default: None.

    If the optimizer(s) are constructed after calling this function, it is safe to omit this parameter. These optimizers will see the correct model parameters.
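A mask_replacement_function receives the module and max_sequence_length as positional arguments and returns the module. The sketch below assumes a module that stores its causal mask as a nested list in an attribute named `bias`. The class, the attribute name and `enlarge_mask_sketch` are hypothetical, and this is not Composer's `enlarge_mask`.

```python
from __future__ import annotations


def causal_mask(n: int) -> list[list[bool]]:
    # True where query position i may attend to key position j (j <= i).
    return [[j <= i for j in range(n)] for i in range(n)]


class ToyAttention:
    """Hypothetical module whose causal mask was sized for `n_positions` at init."""

    def __init__(self, n_positions: int = 4):
        self.bias = causal_mask(n_positions)


def enlarge_mask_sketch(module, max_sequence_length: int):
    # Rebuild the mask at the larger size; the positional arguments
    # (module, max_sequence_length) follow the contract documented above.
    module.bias = causal_mask(max_sequence_length)
    return module
```

Because the mask is rebuilt rather than sliced, the module can accept sequences up to max_sequence_length even though it was created for fewer positions.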

Returns

None