Tip

This tutorial is available as a Jupyter notebook.

🩺 Image Segmentation#

In this notebook you will use Composer and PyTorch to segment pneumothorax (air around or outside of the lungs) from chest radiographic images. This dataset was originally released for a Kaggle competition by the Society for Imaging Informatics in Medicine (SIIM).

Disclaimer: This example is a minimal working baseline. Competitive results require training for far longer than the two epochs used here.

Tutorial Goals and Concepts Covered#

The goal of this tutorial is to provide an executable example of a computer vision project in Composer from the ground up.

We will cover:

  • installing relevant packages

  • downloading the SIIM dataset from kaggle

  • cleaning and resampling the dataset

  • splitting data for validation

  • visualizing model inputs

  • training a baseline model with Composer

  • using Composer methods

  • next steps

Let’s get started!

Setup#

First, we configure our environment.

Install Dependencies#

If you haven’t already, install the following dependencies, which this example needs:

[ ]:
%pip install kaggle pydicom git+https://github.com/qubvel/segmentation_models.pytorch opencv-python-headless jupyterlab-widgets

%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git

Kaggle Authentication#

To access the data you need a Kaggle account.

  • Accept the competition terms at https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/data

  • Download kaggle.json from https://www.kaggle.com/yourusername/account by clicking “Create new API token”.

  • Make the kaggle.json file available to this notebook using the following code cells.

[ ]:
from ipywidgets import FileUpload
from IPython.display import display
uploader = FileUpload(accept='.json', multiple=True)
display(uploader)
[ ]:
import os

kaggle_folder = os.path.join(os.path.expanduser("~"), ".kaggle")
os.makedirs(kaggle_folder, exist_ok=True)
kaggle_config_file = os.path.join(kaggle_folder, "kaggle.json")
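with open(kaggle_config_file, 'wb+') as output_file:
    # ipywidgets 7 exposes uploads as a dict keyed by filename; ipywidgets 8 as a tuple of dicts.
    uploads = uploader.value.values() if isinstance(uploader.value, dict) else uploader.value
    for upload in uploads:
        output_file.write(upload['content'])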
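
Before calling the Kaggle CLI, it can help to confirm that the uploaded file really is a usable token. The sketch below is an optional check, not part of the original notebook: it assumes the usual kaggle.json layout with "username" and "key" fields, and the helper name validate_kaggle_config is hypothetical. It also restricts the file to owner read/write, which is a sensible default for credentials.

```python
import json
import os
import stat
import tempfile


def validate_kaggle_config(path):
    """Check that a kaggle.json file has the expected keys and lock down its permissions."""
    with open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    missing = {"username", "key"} - set(config)
    if missing:
        raise ValueError(f"kaggle.json is missing keys: {sorted(missing)}")
    # Restrict the file to owner read/write, a sensible default for API credentials.
    os.chmod(path, stat.S_IRUSR | stat.S_IWUSR)
    return config["username"]


# Demonstration with a throwaway file and placeholder credentials (not real ones).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "kaggle.json")
    with open(demo_path, "w", encoding="utf-8") as f:
        json.dump({"username": "example-user", "key": "not-a-real-key"}, f)
    demo_username = validate_kaggle_config(demo_path)
```

In the notebook you would call validate_kaggle_config(kaggle_config_file) after the upload cell has written the file.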

Download and unzip the data#

[ ]:
!kaggle datasets download -d seesee/siim-train-test
!unzip -q siim-train-test.zip -d .
!ls

Flatten Image Directories#

The original dataset is oddly nested. We flatten it so the images are easier to access in our PyTorch dataset, moving each file from /siim/dicom-images-train/id/id/id.dcm to /siim/dicom-images-train/id.dcm.

[ ]:
from pathlib import Path
from tqdm.auto import tqdm

train_images = list(Path('siim/dicom-images-train').glob('*/*/*.dcm'))
for image in tqdm(train_images):
    image.replace(f'siim/dicom-images-train/{image.parts[-1]}')
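
To see what the flattening step does without touching the real data, here is a small sketch that builds a fake nested tree in a temporary directory and applies the same glob-and-replace pattern. The helper name flatten_dicom_dir and the IDs "a1" and "b2" are made up for illustration.

```python
import tempfile
from pathlib import Path


def flatten_dicom_dir(root):
    """Move every root/<id>/<id>/<name>.dcm file up to root/<name>.dcm."""
    root = Path(root)
    moved = []
    # Materialize the matches first so we do not modify the tree while globbing.
    for image in sorted(root.glob("*/*/*.dcm")):
        target = root / image.name
        image.replace(target)  # same Path.replace call as in the cell above
        moved.append(target.name)
    return moved


# Build a tiny fake tree with made-up IDs and flatten it.
with tempfile.TemporaryDirectory() as tmp:
    fake_root = Path(tmp) / "dicom-images-train"
    for image_id in ("a1", "b2"):
        nested = fake_root / image_id / image_id
        nested.mkdir(parents=True)
        (nested / f"{image_id}.dcm").write_bytes(b"")
    moved_names = flatten_dicom_dir(fake_root)
    flat_files = sorted(p.name for p in fake_root.glob("*.dcm"))
```

Note that the now-empty nested directories are left in place; only the files move.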

Project setup#

Imports#

[ ]:
import itertools
from ipywidgets import interact, fixed, IntSlider

import numpy as np
import pandas as pd
import torch
from torch import nn
import matplotlib.pyplot as plt
import cv2

# model
import segmentation_models_pytorch as smp

# data
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import draw_segmentation_masks, make_grid
from pydicom.filereader import dcmread
from sklearn.model_selection import StratifiedKFold

# transforms
from albumentations import ShiftScaleRotate, Resize, Compose

from torchmetrics import Metric
from torchmetrics.collections import MetricCollection

# composer
from composer import Trainer
from composer.models import ComposerModel
from composer.optim import DecoupledAdamW
from composer.metrics.metrics import Dice

Utils#

Here we define some utility functions to help with logging, decoding/encoding targets, and visualization.

[ ]:
class LossMetric(Metric):
    """Turns any torch.nn Loss Module into distributed torchmetrics Metric."""

    def __init__(self, loss, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.loss = loss
        self.add_state("sum_loss", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        """Update the state with new predictions and targets.
        """
        # Loss calculated over samples/batch, accumulate loss over all batches
        self.sum_loss += self.loss(preds, target)
        self.total_batches += 1

    def compute(self):
        """Aggregate state over all processes and compute the metric.
        """
        # Return average loss over entire validation dataset
        return self.sum_loss / self.total_batches

def rle2mask(rle, height=1024, width=1024, fill_value=1):
    mask = np.zeros((height, width), np.float32)
    mask = mask.reshape(-1)
    rle = np.array([int(s) for s in rle.strip().split(' ')])
    rle = rle.reshape(-1, 2)
    start = 0
    for index, length in rle:
        start = start+index
        end = start+length
        mask[start: end] = fill_value
        start = end
    mask = mask.reshape(width, height).T
    return mask

def mask2rle(mask):
    mask = mask.T.flatten()
    start = np.where(mask[1:] > mask[:-1])[0]+1
    end = np.where(mask[:-1] > mask[1:])[0]+1
    length = end-start
    rle = []
    for i in range(len(length)):
        if i == 0:
            rle.extend([start[0], length[0]])
        else:
            rle.extend([start[i]-end[i-1], length[i]])
    rle = ' '.join([str(r) for r in rle])
    return rle
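
The run-length format handled by rle2mask and mask2rle is easier to follow on a tiny mask. The sketch below is a dependency-free restatement of the same logic (the names decode_rle and encode_rle are hypothetical): pixels are read column by column, and each pair of numbers is the gap since the previous run ended followed by the run length.

```python
def decode_rle(rle, height, width):
    """Decode a relative-offset RLE string into a height x width mask (list of rows)."""
    flat = [0] * (height * width)
    numbers = [int(s) for s in rle.split()]
    position = 0
    # Pairs are (gap since the previous run ended, run length).
    for gap, length in zip(numbers[0::2], numbers[1::2]):
        start = position + gap
        for i in range(start, start + length):
            flat[i] = 1
        position = start + length
    # Pixels are stored column by column, so flat index = col * height + row.
    return [[flat[col * height + row] for col in range(width)] for row in range(height)]


def encode_rle(mask):
    """Encode a mask (list of rows) back into the same relative-offset RLE string."""
    height, width = len(mask), len(mask[0])
    flat = [mask[row][col] for col in range(width) for row in range(height)]
    parts, position, i = [], 0, 0
    while i < len(flat):
        if flat[i]:
            start = i
            while i < len(flat) and flat[i]:
                i += 1
            parts.extend([start - position, i - start])
            position = i
        else:
            i += 1
    return " ".join(str(p) for p in parts)


tiny_mask = [
    [0, 0, 0, 1],
    [0, 1, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 0, 0],
]
tiny_rle = encode_rle(tiny_mask)  # "5 2 5 1"
```

Here the first run starts at flat index 5 and covers 2 pixels (rows 1 and 2 of column 1); the second run starts 5 pixels after the first one ends and covers the single pixel at the top of column 3.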

Preprocessing and Data Science#

SIIM Dataset#

The SIIM dataset consists of:

  • dicom-images-train: 12954 labeled images in DICOM format.

  • dicom-images-test: 3205 unlabeled DICOM images for testing.

  • train-rle.csv: a label file mapping ImageId to EncodedPixels.

    • ImageIds map to image paths for DICOM format images.

    • EncodedPixels are run-length-encoded segmentation masks representing areas where pneumothorax has been labeled by an expert. A label of "-1" indicates the image was examined and no pneumothorax was found.

[ ]:
!ls siim
[ ]:
labels_df = pd.read_csv('siim/train-rle.csv')
labels_df.shape

Clean Data#

Of the ~13,000 images, only 3600 have masks. We will throw out some of the negative samples to better balance our dataset and speed up training.

[ ]:
labels_df[labels_df[" EncodedPixels"] != "-1"].shape, labels_df[labels_df[" EncodedPixels"] == "-1"].shape
[ ]:
def balance_labels(labels_df, extra_samples_without_mask=1500, random_state=1337):
    """
    Drop duplicate ImageIds and mark samples with masks.
    Sample len(df_with_mask) + extra_samples_without_mask unmasked samples to balance the dataset.
    """
    df = labels_df.drop_duplicates('ImageId')
    df_with_mask = df[df[" EncodedPixels"] != "-1"].copy(deep=True)
    df_with_mask['has_mask'] = 1
    df_without_mask = df[df[" EncodedPixels"] == "-1"].copy(deep=True)
    df_without_mask['has_mask'] = 0
    df_without_mask_sampled = df_without_mask.sample(len(df_with_mask)+extra_samples_without_mask, random_state=random_state)
    df = pd.concat([df_with_mask, df_without_mask_sampled])
    return df
[ ]:
df = balance_labels(labels_df)
df.shape
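
The balancing rule is simple enough to check on synthetic data. The following sketch mirrors the steps of balance_labels in plain Python (the function name balance_rows and the toy rows are invented for illustration): keep one row per image, keep every masked image, and sample as many unmasked images as there are masked ones plus a fixed extra.

```python
import random


def balance_rows(rows, extra_without_mask, seed=1337):
    """Keep every masked image plus len(masked) + extra_without_mask unmasked ones."""
    seen, unique = set(), []
    for image_id, encoded in rows:  # keep only the first row per image
        if image_id not in seen:
            seen.add(image_id)
            unique.append((image_id, encoded))
    with_mask = [r for r in unique if r[1] != "-1"]
    without_mask = [r for r in unique if r[1] == "-1"]
    n_negative = len(with_mask) + extra_without_mask
    sampled = random.Random(seed).sample(without_mask, n_negative)
    return with_mask, sampled


# Synthetic labels: 3 masked images (one listed twice) and 10 unmasked ones.
toy_rows = [("pos0", "5 2"), ("pos0", "9 1"), ("pos1", "3 4"), ("pos2", "1 1")]
toy_rows += [(f"neg{i}", "-1") for i in range(10)]
toy_pos, toy_neg = balance_rows(toy_rows, extra_without_mask=2)
```

With 3 masked images and an extra of 2, the result holds 3 masked and 5 unmasked samples. As in the pandas version, dropping duplicates keeps only the first mask of an image that has several.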

Create Cross Validation Splits#

Once cleaned and balanced, we’re left with only 6838 images. This will leave us with rather small training and validation sets once we split the data. To reduce the chance of validating on a poorly sampled validation set (one that is not representative of our unlabeled test data), we use StratifiedKFold to create 5 different 80%/20% train/eval splits.

Note: For datasets of this size, it’s good practice to train and evaluate on each split, but due to runtime constraints in this notebook we will only train on the first split, which contains 5470 training and 1368 eval samples.

[ ]:
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1337)
train_idx, eval_idx = list(kfold.split(df["ImageId"], df["has_mask"]))[0]
train_df, eval_df = df.iloc[train_idx], df.iloc[eval_idx]
train_df.shape, eval_df.shape
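
To illustrate what "stratified" buys us, here is a simplified sketch of the idea. It is not scikit-learn's exact algorithm, and stratified_folds is a hypothetical helper: indices are grouped by label, shuffled, and dealt round-robin into folds, so every fold keeps roughly the same share of masked samples.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits=5, seed=1337):
    """Assign each sample index to a fold so every fold keeps roughly the same label mix."""
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)
    rng = random.Random(seed)
    folds = [[] for _ in range(n_splits)]
    for label in sorted(by_label):
        indices = by_label[label]
        rng.shuffle(indices)
        # Deal indices round-robin so each fold gets an equal share of this label.
        for position, index in enumerate(indices):
            folds[position % n_splits].append(index)
    return folds


# Toy labels: 40 samples with a mask (1) and 60 without (0).
toy_labels = [1] * 40 + [0] * 60
toy_folds = stratified_folds(toy_labels)
toy_eval_idx = toy_folds[0]
toy_train_idx = [i for fold in toy_folds[1:] for i in fold]
toy_eval_positive_share = sum(toy_labels[i] for i in toy_eval_idx) / len(toy_eval_idx)
```

On this toy data the first fold gives an 80/20 split, and the eval fold has the same 40% share of masked samples as the full set.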

PyTorch#

PyTorch Dataset#

SIIMDataset is a standard PyTorch dataset that reads images and decodes labels from the SIIM label CSV. DICOM images are loaded as grayscale NumPy arrays, converted to RGB, and scaled. Labels are converted from RLE strings to binary segmentation masks.

[ ]:
class SIIMDataset(Dataset):
    def __init__(self,
                 labels_df,
                 transforms=None,
                 image_dir=Path('siim/dicom-images-train')):
        self.labels_df = labels_df
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        image_id = row.ImageId
        image_path = self.image_dir / f'{image_id}.dcm'
        image = dcmread(image_path).pixel_array # load dicom image
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert rgb so we can keep imagenet first layer weights
        image = (image / 255.).astype('float32') # scale (0.- 1.)

        rle = row[' EncodedPixels']
        if rle != '-1':
            mask = rle2mask(rle, 1024, 1024).astype('float32')
        else:
            mask = np.zeros([1024, 1024]).astype('float32')

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return (
            torch.from_numpy(image).permute(2, 0, 1),
            torch.from_numpy(mask).unsqueeze(0)
        )

    def __len__(self):
        return len(self.labels_df)

Transforms#

We use the albumentations library to resize and randomly scale/rotate our training images.

[ ]:
image_size = 512

train_transforms = Compose(
    [
        Resize(image_size, image_size),
        ShiftScaleRotate(
            shift_limit=0,
            scale_limit=0.1,
            rotate_limit=10, # rotate
            p=0.5,
            border_mode=cv2.BORDER_CONSTANT
        )
    ]
)

eval_transforms = Compose([Resize(image_size, image_size)])

DataLoaders#

[ ]:

train_batch_size = 32
val_batch_size = 32

train_dataloader = DataLoader(SIIMDataset(train_df, transforms=train_transforms),
                              batch_size=train_batch_size, shuffle=True, num_workers=2)
eval_dataloader = DataLoader(SIIMDataset(eval_df, transforms=eval_transforms),
                             batch_size=val_batch_size, shuffle=False, num_workers=2)

Visualize batch#

Areas of pneumothorax are highlighted in red; drag the slider to iterate through batches.

[ ]:
@interact(data_loader=fixed(train_dataloader), batch=IntSlider(min=0, max=len(train_dataloader)-1, step=1, value=0))
def show_batch(data_loader, batch):
    plt.rcParams['figure.figsize'] = [20, 15]

    images, masks = list(itertools.islice(data_loader, batch, batch+1))[0]
    masks_list = []
    for image, mask in zip(images, masks):
        masked = draw_segmentation_masks((image * 255).byte(),
                                    mask.bool(), alpha=0.5, colors='red')
        masks_list.append(masked)

    grid  = make_grid(masks_list, nrow=6)
    plt.imshow(grid.permute(1, 2, 0));

Composer#

Model#

Here we define a Composer model that wraps the segmentation_models_pytorch (smp) package. This lets us quickly create many different segmentation models made from common pre-trained PyTorch encoders.

  • We set defaults to create a Unet from an ImageNet pre-trained ResNet-34 with 3 input channels for our RGB (converted) inputs and 1 output channel.

  • We set the default loss to nn.BCEWithLogitsLoss() to classify each pixel of the output.

[ ]:
class SMPUNet(ComposerModel):
    def __init__(self,
                 encoder_name='resnet34',
                 encoder_weights='imagenet',
                 in_channels=3, classes=1,
                 loss=nn.BCEWithLogitsLoss()):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,     # use `imagenet` pre-trained weights for encoder initialization
            in_channels=in_channels,        # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=classes         # model output channels (number of classes in your dataset)
        )

        self.criterion = loss
        self.train_loss = LossMetric(loss)
        self.val_loss = LossMetric(loss)
        self.val_dice = Dice(num_classes=classes)

    def forward(self, batch):
        images, targets = batch
        return self.model(images)

    def loss(self, outputs, batch):
        _, targets = batch
        return self.criterion(outputs, targets)

    def get_metrics(self, is_train: bool = False):
        # Return a dict mapping metric names to torchmetrics objects.
        if is_train:
            return {'BCEWithLogitsLoss': self.train_loss}
        else:
            return {'BCEWithLogitsLoss': self.val_loss, 'Dice': self.val_dice}
[ ]:
model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)
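
The model reports two quantities: a per-pixel binary cross-entropy on the raw logits, and a Dice overlap score. The sketch below computes both by hand on four pixels to show what the numbers mean. It is a generic illustration under stated assumptions, not Composer's or PyTorch's implementation: predictions are thresholded at a logit of 0 (probability 0.5), two empty masks score 1.0 by convention, and details such as per-class averaging in Composer's Dice metric may differ.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits (numerically stable form)."""
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


def dice_score(logits, targets, threshold=0.0):
    """Dice overlap between thresholded predictions and a binary target mask."""
    preds = [1 if x > threshold else 0 for x in logits]
    intersection = sum(p * t for p, t in zip(preds, targets))
    denominator = sum(preds) + sum(targets)
    # Convention assumed here: two empty masks count as a perfect match.
    return 2 * intersection / denominator if denominator else 1.0


# Four flattened pixels: raw model outputs and the ground-truth mask.
toy_logits = [2.0, -1.0, 0.5, -3.0]
toy_targets = [1, 0, 0, 0]
toy_bce = bce_with_logits(toy_logits, toy_targets)
toy_dice = dice_score(toy_logits, toy_targets)
```

Two pixels are predicted positive and one of them is correct, so the Dice score is 2 * 1 / (2 + 1) = 2/3. The cross-entropy penalizes the confident-enough false positive at logit 0.5 most heavily.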

Trainer#

[ ]:
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    optimizers=optimizer,
    device='gpu',
    precision='amp',
    seed=1337
)
trainer.fit()

Algorithms#

Composer allows us to quickly experiment with algorithms that can speed up training or improve the quality of our model. Here we add CutOut and LabelSmoothing to the trainer.

Additionally, the Composer trainer has built-in support for automatic mixed precision training and gradient accumulation to help train quickly and simulate larger batch sizes.

[ ]:
from composer.algorithms import CutOut, LabelSmoothing

model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

algorithms = [CutOut(length=0.5), LabelSmoothing(smoothing=0.1)]

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    optimizers=optimizer,
    algorithms=algorithms,
    device='gpu',
    precision='amp',
    seed=1337
)
trainer.fit()
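
For intuition about what these two methods do to a batch, here is a generic, dependency-free sketch. It is not Composer's implementation, and the helper names, the num_classes=2 default, and the box placement rule are assumptions made for illustration: label smoothing pulls hard targets toward a uniform value, and cutout zeroes a randomly placed box in the input image.

```python
import random


def smooth_targets(targets, smoothing=0.1, num_classes=2):
    """Generic label smoothing: shrink hard targets toward the uniform value 1 / num_classes."""
    return [t * (1.0 - smoothing) + smoothing / num_classes for t in targets]


def cutout(image, length=0.5, seed=0):
    """Zero a randomly placed box covering `length` of each side of a 2D image (list of rows)."""
    height, width = len(image), len(image[0])
    box_h, box_w = int(height * length), int(width * length)
    rng = random.Random(seed)
    # Assumption: the box always lies fully inside the image.
    top = rng.randint(0, height - box_h)
    left = rng.randint(0, width - box_w)
    return [
        [0.0 if top <= r < top + box_h and left <= c < left + box_w else value
         for c, value in enumerate(row)]
        for r, row in enumerate(image)
    ]


toy_smoothed = smooth_targets([1.0, 0.0])
toy_image = [[1.0] * 4 for _ in range(4)]
toy_cut = cutout(toy_image)
toy_zeroed_pixels = sum(value == 0.0 for row in toy_cut for value in row)
```

With smoothing=0.1 and two classes, a target of 1 becomes roughly 0.95 and a target of 0 becomes 0.05. With length=0.5 on a 4x4 image, a 2x2 box (4 pixels) is zeroed.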

What next?#

You’ve now seen a from-scratch demonstration of using Composer in a computer vision project. But don’t stop here! If you’re interested, we recommend that you continue to experiment with:

  • training longer

  • different loss functions, architectures, transformations, and

  • different combinations of composer methods!

In addition, please continue to explore our tutorials!

Come get involved with MosaicML!#

We’d love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub#

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer#

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!