Tip

This tutorial is available as a Jupyter notebook.

📖 NLP Models

This tutorial demonstrates how to fine-tune a pretrained HuggingFace transformer using the Composer library! Composer provides an optimized training loop and lets you compose multiple methods that can accelerate training.

We will focus on fine-tuning a pretrained BERT-base model on the Stanford Sentiment Treebank v2 (SST-2) dataset. After fine-tuning, the BERT model should be able to determine whether a sentence has positive or negative sentiment.

Let's do this 🚀

Install Composer

To develop NLP models with Composer, we'll need to install Composer with the NLP dependencies. If you haven't already, run:

%pip install 'mosaicml[nlp]'

Defining a Composer Model

The first task is to create a Composer model. A Composer model defines four components for the Composer Trainer:

- forward() - parses the dataloader output for the model's forward function and extracts the necessary components of the model's output for loss calculation.
- loss() - computes the loss for the current batch using the model and dataloader outputs.
- validate() - parses the dataloader and model output for torchmetrics.
- metrics() - defines the torchmetrics to use during training/validation.

import transformers
from torchmetrics import Accuracy
from torchmetrics.collections import MetricCollection
from composer.models.base import ComposerModel
from composer.metrics import CrossEntropy

# Define a Composer Model
class ComposerBERT(ComposerModel):
    def __init__(self, model):
        super().__init__()
        self.module = model

        # Metrics
        self.train_loss = CrossEntropy()
        self.val_loss = CrossEntropy()

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

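    def forward(self, batch):
        # The HuggingFace model returns the logits and, when the batch
        # contains 'labels', the cross-entropy loss as well
        output = self.module(**batch)
        return output

    def loss(self, outputs, batch):
        # The loss was already computed by the HuggingFace model in forward()
        return outputs['loss']

    def validate(self, batch):
        # Separate the labels, run the model, and return (logits, labels)
        # so the metrics can compare predictions against targets
        labels = batch.pop('labels')
        output = self.forward(batch)
        output = output['logits']
        return (output, labels)

    def metrics(self, train: bool = False):
        # Track loss and accuracy separately for training and validation
        if train:
            return MetricCollection([self.train_loss, self.train_acc])
        return MetricCollection([self.val_loss, self.val_acc])

# Create a BERT sequence classification model using HuggingFace transformers
# num_labels=2 because SST-2 has two classes: negative and positive
model = transformers.AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)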

# Package as a composer model
composer_model = ComposerBERT(model)
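To make the division of labor between the four methods concrete, here is a minimal plain-Python sketch of how a training loop might call them. ToyModel, mse, training_loss, and evaluate are hypothetical names invented for this illustration: they mimic ComposerBERT's call pattern (forward() then loss() during training, validate() feeding the metrics during evaluation) and are not Composer's actual Trainer code, which also handles devices, optimizers, and torchmetrics objects.

```python
# Plain-Python sketch of the forward/loss/validate/metrics call pattern.
# ToyModel, mse, training_loss and evaluate are hypothetical illustration
# names; they are not part of Composer.


def mse(preds, targets):
    """Mean squared error between two equal-length lists."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(preds)


class ToyModel:
    """A one-weight linear model exposing the same four methods as ComposerBERT."""

    def __init__(self, weight):
        self.weight = weight

    def forward(self, batch):
        # Run the "model"; like a HuggingFace model, also return a loss
        # when the batch contains labels.
        logits = [self.weight * x for x in batch["inputs"]]
        outputs = {"logits": logits}
        if "labels" in batch:
            outputs["loss"] = mse(logits, batch["labels"])
        return outputs

    def loss(self, outputs, batch):
        # The loss was already computed inside forward().
        return outputs["loss"]

    def validate(self, batch):
        # Split off the labels (without mutating the batch) and
        # return (predictions, targets) for the metrics.
        inputs = {k: v for k, v in batch.items() if k != "labels"}
        return self.forward(inputs)["logits"], batch["labels"]

    def metrics(self, train=False):
        # Choose which metric functions to track for training vs. evaluation.
        return {"train_mse": mse} if train else {"val_mse": mse}


def training_loss(model, batch):
    # Training: forward() first, then loss() on its outputs.
    return model.loss(model.forward(batch), batch)


def evaluate(model, batch):
    # Evaluation: validate() feeds every metric from metrics(train=False).
    preds, targets = model.validate(batch)
    return {name: fn(preds, targets) for name, fn in model.metrics(train=False).items()}
```

For example, `training_loss(ToyModel(1.0), {"inputs": [1.0, 2.0], "labels": [2.0, 4.0]})` averages the squared errors 1 and 4, giving 2.5.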

Creating dataloaders

Next, we will download and tokenize the SST-2 dataset.

import datasets
from multiprocessing import cpu_count

# Create the BERT tokenizer that matches the pretrained model
tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')

def tokenize_function(sample):
    # Pad or truncate every sentence to exactly 256 tokens so that
    # examples can be stacked into fixed-size batches
    return tokenizer(
        text=sample['sentence'],
        padding="max_length",
        max_length=256,
        truncation=True
    )

# Download SST-2 (part of the GLUE benchmark) and tokenize every split in parallel;
# only the tokenizer outputs and the 'label' column are kept
sst2_dataset = datasets.load_dataset("glue", "sst2")
tokenized_sst2_dataset = sst2_dataset.map(tokenize_function,
                                          batched=True,
                                          num_proc=cpu_count(),
                                          batch_size=1000,
                                          remove_columns=['idx', 'sentence'])

# Select the predefined train and validation splits
train_dataset = tokenized_sst2_dataset["train"]
eval_dataset = tokenized_sst2_dataset["validation"]
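The padding="max_length" and truncation=True arguments make every tokenized sentence the same length. The sketch below imitates that behavior with a hypothetical whitespace tokenizer and a made-up vocabulary (TOY_VOCAB and toy_tokenize are illustration names only). The real BERT tokenizer splits words into WordPiece subwords, uses different token ids, and also returns token_type_ids.

```python
# Hypothetical toy vocabulary; the real bert-base-uncased tokenizer uses
# WordPiece subwords and different ids.
TOY_VOCAB = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
             "a": 4, "great": 5, "movie": 6, "boring": 7}


def toy_tokenize(sentence, max_length):
    """Mimic padding="max_length" and truncation=True for a single sentence."""
    words = sentence.lower().split()
    # Reserve two slots for the [CLS] and [SEP] special tokens, then truncate.
    words = words[: max_length - 2]
    tokens = ["[CLS]"] + words + ["[SEP]"]
    # Unknown words map to [UNK].
    input_ids = [TOY_VOCAB.get(t, TOY_VOCAB["[UNK]"]) for t in tokens]
    # 1 marks real tokens; 0 marks padding the model should ignore.
    attention_mask = [1] * len(input_ids)
    pad = max_length - len(input_ids)
    input_ids += [TOY_VOCAB["[PAD]"]] * pad
    attention_mask += [0] * pad
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

Because every example ends up exactly max_length tokens long, the collator can stack examples into fixed-size tensors without any further padding.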

Here, we will create a PyTorch DataLoader for each of the datasets generated in the previous block.

from torch.utils.data import DataLoader

# default_data_collator stacks the tokenized features into tensors and
# renames the 'label' column to 'labels', the argument the model expects
data_collator = transformers.data.data_collator.default_data_collator
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False, drop_last=False, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=16, shuffle=False, drop_last=False, collate_fn=data_collator)
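default_data_collator turns a list of per-example dictionaries into one batch dictionary and renames the label column to labels along the way. That rename is what lets ComposerBERT.validate() pop 'labels' and what makes the HuggingFace model compute a loss in forward(). The hypothetical toy_collate below sketches this behavior with plain lists instead of tensors; it is not the transformers implementation.

```python
def toy_collate(features):
    """Hypothetical pure-Python stand-in for default_data_collator."""
    batch = {}
    first = features[0]
    # The "label" column becomes "labels", the keyword argument
    # HuggingFace models use to compute a loss.
    if "label" in first:
        batch["labels"] = [f["label"] for f in features]
    # Every other field is gathered across examples (the real collator
    # stacks them into tensors instead of lists).
    for key in first:
        if key != "label":
            batch[key] = [f[key] for f in features]
    return batch


features = [
    {"input_ids": [2, 4, 3], "attention_mask": [1, 1, 1], "label": 1},
    {"input_ids": [2, 7, 3], "attention_mask": [1, 1, 1], "label": 0},
]
batch = toy_collate(features)
print(batch["labels"])  # prints [1, 0]
```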

To use the Composer Trainer, we need to define a split_batch function. This function defines how to split the dataloader output into several "microbatches". Microbatches are the chunks a batch is divided into for gradient accumulation: with N gradient accumulation steps, each batch is split into N microbatches.

from composer.core import DataSpec

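def split_batch_dict(batch, n_microbatches: int):
    # Chunk every tensor in the batch along its first (batch) dimension
    chunked = {k: v.chunk(n_microbatches) for k, v in batch.items()}
    # torch.chunk may return fewer chunks than requested, so count them
    num_chunks = len(list(chunked.values())[0])
    # Reassemble one dict per microbatch
    return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]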

train_dataspec = DataSpec(dataloader=train_dataloader,
                          split_batch=split_batch_dict)
eval_dataspec = DataSpec(dataloader=eval_dataloader,
                         split_batch=split_batch_dict)
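One detail of split_batch_dict is worth spelling out. torch.chunk makes every chunk ceil(batch_size / n_microbatches) rows long, so the last microbatch can be smaller and fewer chunks than requested may come back. That is why num_chunks is counted from the result instead of being taken from n_microbatches. The sketch below reproduces the same logic on plain Python lists; chunk is a hypothetical stand-in for torch.chunk.

```python
def chunk(values, n_chunks):
    """Pure-Python imitation of torch.chunk along the first dimension."""
    # Each chunk holds ceil(len / n_chunks) items; the last one may be shorter.
    size = max(1, -(-len(values) // n_chunks))
    return [values[i:i + size] for i in range(0, len(values), size)]


def split_batch_dict(batch, n_microbatches):
    """Same logic as the tutorial's split_batch_dict, on lists instead of tensors."""
    chunked = {k: chunk(v, n_microbatches) for k, v in batch.items()}
    num_chunks = len(list(chunked.values())[0])
    return [{k: v[idx] for k, v in chunked.items()} for idx in range(num_chunks)]


batch = {"input_ids": list(range(16)), "labels": list(range(100, 116))}
micro = split_batch_dict(batch, 3)
print([len(m["input_ids"]) for m in micro])  # prints [6, 6, 4]
```

Splitting 5 rows into 4 chunks likewise yields only 3 chunks of sizes 2, 2, and 1.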

Optimizers and Learning Rate Schedulers

The last setup step is to create an optimizer and a learning rate scheduler. We will use PyTorch's AdamW optimizer and linear learning rate scheduler since these are typically used to fine-tune BERT on tasks such as SST-2.

from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR

optimizer = AdamW(
    params=composer_model.parameters(),
    lr=3e-5, betas=(0.9, 0.98),
    eps=1e-6, weight_decay=3e-6
)
# Scale the learning rate linearly from 3e-5 down to 0 over 150 scheduler steps
linear_lr_decay = LinearLR(
    optimizer, start_factor=1.0,
    end_factor=0, total_iters=150
)
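With start_factor=1.0, end_factor=0, and total_iters=150, LinearLR multiplies the initial learning rate by a factor that falls linearly from 1 to 0 and then stays at 0. The hypothetical helper linear_lr below evaluates that schedule in closed form, assuming the hyperparameters above. It counts scheduler steps; whether one step corresponds to one batch depends on how the Trainer steps PyTorch schedulers, so check the Trainer documentation for your Composer version.

```python
def linear_lr(step, base_lr=3e-5, start_factor=1.0, end_factor=0.0, total_iters=150):
    """Learning rate after `step` scheduler steps, using LinearLR's closed form."""
    # Progress is clamped at 1.0, so the rate stays at end_factor * base_lr
    # once total_iters steps have passed.
    progress = min(step, total_iters) / total_iters
    factor = start_factor + (end_factor - start_factor) * progress
    return base_lr * factor


for step in (0, 75, 150, 300):
    print(step, linear_lr(step))
```

Halfway through (step 75) the learning rate is half its initial value, and from step 150 on it is 0.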

Composer Trainer

We will now create a Composer Trainer object and run our training! Trainer has many arguments that are described in our documentation, but let's discuss the less obvious arguments used below:

- max_duration - a string specifying how long to train, either in terms of batches (e.g. '10ba' is 10 batches) or epochs (e.g. '1ep' is 1 epoch).
- schedulers - a list of PyTorch learning rate schedulers that will be composed together.
- device - specifies whether training runs on CPU or GPU, using 'cpu' or 'gpu', respectively.
- train_subset_num_batches - the number of training batches to use in each epoch. This argument is optional but useful for quickly testing code.
- precision - whether to train in full precision ('fp32') or mixed precision ('amp'). Mixed precision can provide close to a 2x training speedup on hardware with fast half-precision support. If you get a P100, try precision='amp'!
- seed - sets the random seed for the training run so the results are reproducible!
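To see how max_duration and train_subset_num_batches interact, the sketch below counts training batches. parse_duration, batches_per_epoch, and total_train_batches are hypothetical helpers: parse_duration only understands the 'ba' and 'ep' suffixes described above, and the usage assumes the GLUE SST-2 split sizes of 67,349 training and 872 validation sentences.

```python
def parse_duration(duration):
    """Split a duration string such as '10ba' or '1ep' into (count, unit)."""
    for unit in ("ba", "ep"):
        if duration.endswith(unit):
            return int(duration[: -len(unit)]), unit
    raise ValueError(f"unsupported duration: {duration!r}")


def batches_per_epoch(num_samples, batch_size, drop_last=False, subset_num_batches=None):
    """Batches seen per epoch, optionally capped like train_subset_num_batches."""
    if drop_last:
        n = num_samples // batch_size  # the incomplete final batch is discarded
    else:
        n = (num_samples + batch_size - 1) // batch_size  # final batch may be smaller
    if subset_num_batches is not None:
        n = min(n, subset_num_batches)
    return n


def total_train_batches(duration, num_samples, batch_size, subset_num_batches=None):
    """Total training batches implied by a max_duration string."""
    count, unit = parse_duration(duration)
    if unit == "ba":
        return count
    return count * batches_per_epoch(num_samples, batch_size,
                                     subset_num_batches=subset_num_batches)


print(batches_per_epoch(67349, 16))              # prints 4210
print(total_train_batches("1ep", 67349, 16, 150))  # prints 150
```

With batch_size=16, a full SST-2 epoch would be 4,210 batches, so train_subset_num_batches=150 trains on roughly 3.6% of the training set. The 872 validation sentences form 55 batches, since drop_last=False keeps the final partial batch.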

import torch
from composer import Trainer

# Create Trainer Object
trainer = Trainer(
    model=composer_model,
    train_dataloader=train_dataspec,
    eval_dataloader=eval_dataspec,
    max_duration="1ep",
    optimizers=optimizer,
    schedulers=[linear_lr_decay],
    device='gpu' if torch.cuda.is_available() else 'cpu',
    train_subset_num_batches=150,
    precision='fp32',
    seed=17
)
# Start training
trainer.fit()

Visualizing Results

In our run, the model reached almost 86% validation accuracy after only 150 training batches! Let's visualize a few samples from the validation set to see how the model performs.

# Put the model in evaluation mode to disable dropout
composer_model.eval()

# eval_dataloader is not shuffled, so this batch holds the first 16 validation samples
eval_batch = next(iter(eval_dataloader))

# Move the batch to the GPU if one is available
eval_batch = {k: v.cuda() if torch.cuda.is_available() else v for k, v in eval_batch.items()}
with torch.no_grad():
    predictions = composer_model(eval_batch)["logits"].argmax(dim=1)

# Visualize only 5 samples
num_samples = 5

label = ['negative', 'positive']
for i, prediction in enumerate(predictions[:num_samples]):
    sentence = sst2_dataset["validation"][i]["sentence"]
    correct_label = label[sst2_dataset["validation"][i]["label"]]
    prediction_label = label[prediction.item()]
    print(f"Sample: {sentence}")
    print(f"Label: {correct_label}")
    print(f"Prediction: {prediction_label}")
    print()
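For this two-class problem, the accuracy figure reported above is the fraction of validation sentences whose predicted class (the argmax of the logits, as in the cell above) matches the label. The plain-Python sketch below computes that fraction from made-up logits. argmax and accuracy are hypothetical helpers standing in for torch's argmax and torchmetrics' Accuracy.

```python
LABEL_NAMES = ["negative", "positive"]


def argmax(row):
    """Index of the largest value in one row, like logits.argmax(dim=1) per row."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits, labels):
    """Fraction of rows whose argmax matches the integer label."""
    preds = [argmax(row) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Made-up logits for four sentences; the third prediction is wrong.
example_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.0, 0.0]]
example_labels = [0, 1, 0, 0]
print(accuracy(example_logits, example_labels))  # prints 0.75
print(LABEL_NAMES[argmax(example_logits[1])])    # prints positive
```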

Conclusion

This tutorial showed how to use the Composer Trainer to fine-tune a pretrained BERT model on a subset of the SST-2 dataset. We focused on Composer's basic functionality, but there are many more tools, such as easy-to-use gradient accumulation and multi-GPU training! Check out our documentation for many other features.