Sequence Length Warmup
Tags: Method, Autoregressive Language Modeling, Masked Language Modeling, NLP, Warmup, Curriculum, Speedup, Decreased Wall Clock Time
tl;dr
Sequence Length Warmup warms up the sequence length (number of tokens) from a min_seq_length to a max_seq_length over some duration of training. The underlying motivation is that sequence length is a proxy for the difficulty of an example. Sequence Length Warmup is able to reduce training time by ~1.5x while still achieving the same loss as baseline models.
Hyperparameters
duration - The fraction of training that the warmup should be applied for.
min_seq_length - The initial sequence length.
max_seq_length - The final sequence length. Used for the rest of training.
step_size - The number of tokens to increase the sequence length by at each step. Multiples of 8 are preferred in order to enable hardware acceleration.
truncate - How the sequence length adjustment is achieved. False reshapes the data tensor, creating new samples out of the extra tokens. True truncates the tensor, discarding the extra tokens.
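As an illustration of how these hyperparameters interact, the warmup schedule can be sketched as a small helper. Note that seq_length_at is a hypothetical function for illustration, not part of Composer's API; it assumes a linear ramp from min_seq_length to max_seq_length, rounded down to a multiple of step_size:

```python
def seq_length_at(frac_of_training, duration=0.3, min_seq_length=8,
                  max_seq_length=1024, step_size=8):
    """Hypothetical sketch: sequence length at a given fraction of training."""
    if frac_of_training >= duration:
        # Warmup is over; use the full sequence length for the rest of training.
        return max_seq_length
    # Linearly interpolate between min and max over the warmup duration.
    span = max_seq_length - min_seq_length
    raw = min_seq_length + span * (frac_of_training / duration)
    # Round down to a multiple of step_size for hardware-friendly shapes.
    return min_seq_length + int((raw - min_seq_length) // step_size) * step_size
```

For example, with the defaults above, training starts at a sequence length of 8, reaches 512 halfway through the warmup, and stays at 1024 after 30% of training.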
Applicable Settings
Sequence Length Warmup as implemented in Composer applies to language modeling tasks, including autoregressive language modeling and masked language modeling.
Effects
Our experiments found that Sequence Length Warmup can speed up training by a factor of ~1.5x while achieving the same loss. The authors of the original paper claim that Sequence Length Warmup reduces the outliers in the variance term of the Adam optimizer (Kingma and Ba), which permits training with larger batch sizes and larger learning rates without divergence.
Implementation Details
Warmup Implementation
We implement this as a processing step during the forward pass, where we can either:
Truncate the tensor at the sequence length specified by the warmup schedule.
Reshape the tensor to the sequence length specified by the warmup, which allocates the extra tokens along the batch dimension.
Example when truncate = True and seq_len = 8:
Original Input (2 samples):
We choose to go to the moon. We choose to go to the moon in this decade and do the other things, not because they are easy, but because they are hard, because that goal will serve to organize and measure the best of our energies and skills.
It is for these reasons that I regard the decision last year to shift our efforts in space from low to high gear as among the most important decisions that will be made during my incumbency in the office of the Presidency.
Transformed Inputs (2 samples):
We choose to go to the moon.
It is for these reasons that I regard
Example when truncate = False and seq_len = 8:
Original Input (2 samples):
We choose to go to the moon. We choose to go to the moon in this decade and do the other things, not because they are easy, but because they are hard, because that goal will serve to organize and measure the best of our energies and skills, because that challenge
It is for these reasons that I regard the decision last year to shift our efforts in space from low to high gear as among the most important decisions that will be made during my incumbency in the office of the Presidency.
Transformed Inputs (14 samples):
We choose to go to the moon.
We choose to go to the moon in
this decade and do the other things
not because they are easy, but because
they are hard, because that goal will
serve to organize and measure the best of
our energies and skills, because that challenge
It is for these reasons that I regard
the decision last year to shift our efforts
in space from low to high gear as
among the most important decisions that will be
made during my incumbency in the office
of the Presidency. In the last 24
hours we have seen facilities now being created
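The two transformations illustrated above can be sketched directly on batched tensors. adjust_seq_length below is a hypothetical helper, not Composer's implementation; it assumes each tensor in the batch has shape (batch_size, seq_len):

```python
import torch

def adjust_seq_length(batch, curr_seq_len, truncate=True):
    """Hypothetical sketch of the two sequence-length adjustment modes."""
    if truncate:
        # Keep only the first curr_seq_len tokens of each sample;
        # the extra tokens are discarded.
        return {k: v[:, :curr_seq_len] for k, v in batch.items()}
    # Reshape: fold the extra tokens into new samples along the batch dimension.
    out = {}
    for k, v in batch.items():
        batch_size, seq_len = v.shape
        n_samples = (batch_size * seq_len) // curr_seq_len
        flat = v.reshape(-1)[: n_samples * curr_seq_len]
        out[k] = flat.reshape(n_samples, curr_seq_len)
    return out
```

With truncate=True a batch of 2 samples stays 2 samples, each cut to curr_seq_len tokens; with truncate=False the same tokens are repacked into more, shorter samples, as in the 2-samples-to-14-samples example above.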
Avoiding Out-Of-Memory Errors
Sequence Length Warmup starts with a small sequence length and gradually increases it. As a result, PyTorch must repeatedly expand its memory cache with new buffers as progressively larger tensors are streamed in.
To address this, we create dummy inputs to the model, perform a forward and backward pass, and zero out the gradients, all without taking any scheduler or optimizer steps. This lets PyTorch allocate buffers for the maximum possible sequence length up front and helps avoid downstream out-of-memory errors.
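A minimal sketch of this pre-allocation trick, assuming a model that maps token IDs to a differentiable output. preallocate_memory is a hypothetical helper for illustration, not Composer's implementation:

```python
import torch

def preallocate_memory(model, vocab_size, batch_size, max_seq_length, device="cpu"):
    """Hypothetical sketch: blank forward/backward pass at the maximum
    sequence length so the caching allocator reserves buffers up front."""
    # Dummy token IDs of the largest shape the model will ever see.
    dummy = torch.randint(0, vocab_size, (batch_size, max_seq_length), device=device)
    loss = model(dummy).sum()  # forward pass on dummy inputs
    loss.backward()            # backward pass allocates gradient buffers
    # Discard the dummy gradients; no optimizer or scheduler step is taken.
    model.zero_grad(set_to_none=True)
```

Because no optimizer step runs, model weights are untouched; only the allocator's cached buffers grow to their final size.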
Suggested Hyperparameters
We swept the duration from 0.0 to 0.9 in increments of 0.1 on the GPT-2 52M model, and found that running the sequence length warmup for 30% of training leads to the fastest wall-clock time to reach the same loss. This corroborates the hyperparameters suggested in the paper.
Considerations
Sequence length warmup is a form of curriculum learning, a category of techniques that present samples in a structured or organized order, such as by difficulty. Accordingly, it may compose poorly with other curriculum learning techniques such as batch-size warmup, which is used in the GPT-3 paper.
Composition
This method composes well with ALiBi (Press et al., 2021), a method that enables good extrapolation from shorter training sequence lengths to longer evaluation sequence lengths.
Attribution
Curriculum Learning: A Regularization Method for Efficient and Stable Billion-Scale GPT Model Pre-Training by Conglong Li, Minjia Zhang, and Yuxiong He. Posted to arXiv in 2021.
Code
- class composer.algorithms.seq_length_warmup.SeqLengthWarmup(duration: float = 0.3, min_seq_length: int = 8, max_seq_length: int = 1024, step_size: int = 8, truncate: bool = True)[source]
Progressively increases the sequence length during training.
Changes the sequence length of all tensors in the input batch. The sequence length increases from min_seq_length to max_seq_length in steps of step_size during the first duration fraction of training. The sequence length is then kept at max_seq_length for the rest of training.
Tensors are either truncated (truncate=True) or reshaped to create new examples from the extra tokens (truncate=False).
Note
step_size should be a multiple of eight for GPUs.
Note
Variable input lengths can create CUDA OOM errors. To avoid this, we follow PyTorch notes and pre-allocate the memory with a blank forward and backward pass.
- Parameters
duration (float) – Fraction of total training over which to apply the sequence length warmup.
min_seq_length (int) – Minimum sequence length to start the warmup.
max_seq_length (int) – Maximum sequence length to stop the warmup.
step_size (int) – Step size of the sequence length.
truncate (bool) – Whether to truncate tensors (True) or reshape extra tokens into new examples (False).
- apply(event: composer.core.event.Event, state: composer.core.state.State, logger: composer.core.logging.logger.Logger) -> Optional[int] [source]
Applies on Event.TRAINING_START to allocate the PyTorch cache, or on Event.AFTER_DATALOADER to apply the sequence length warmup to the input batch.
- Parameters
event (Event) – The current event.
state (State) – The current state.
logger (Logger) – A logger to use for logging algorithm-specific metrics.
- Returns
int or None – exit code that is stored in Trace and made accessible for debugging.
- match(event: composer.core.event.Event, state: composer.core.state.State) -> bool [source]
Sequence Length Warmup matches on two events:
Event.TRAINING_START, in order to run a blank forward and backward pass and allocate the PyTorch cache.
Event.AFTER_DATALOADER, in order to apply the sequence length warmup before the forward pass.
- Parameters
event (Event) – The current event.
state (State) – The current state.
- Returns
bool – True if this algorithm should run now.
- composer.algorithms.seq_length_warmup.apply_seq_length_warmup(batch: Dict[str, torch.Tensor], curr_seq_len: int, truncate: bool)[source]
Progressively increases the sequence length during training.
Changes the sequence length of all tensors in the provided dictionary to curr_seq_len, by either truncating the tensors (truncate=True) or reshaping the tensors to create new examples from the extra tokens (truncate=False).
The schedule for curr_seq_len over training time should be managed outside of this function.
- Parameters
batch (Dict[str, Tensor]) – The input batch to the model, as a dictionary of tensors.
curr_seq_len (int) – The sequence length to apply to all tensors in the batch.
truncate (bool) – Whether to truncate tensors (True) or reshape extra tokens into new examples (False).
- Returns
batch – a Mapping of input tensors to the model, where all tensors have curr_seq_len in the second dimension.