🩺 Image Segmentation#
In this notebook you will use Composer and PyTorch to segment pneumothorax (air around or outside of the lungs) from chest radiographic images. This dataset was originally released for a Kaggle competition by the Society for Imaging Informatics in Medicine (SIIM).
Disclaimer: This example represents a minimal working baseline. In order to get competitive results this notebook must run for a long time.
We will cover:
- installing relevant packages
- downloading the SIIM dataset from Kaggle
- cleaning and resampling the dataset
- splitting data for validation
- visualizing model inputs
- training a baseline model with Composer
- using Composer methods
- next steps
Setup#
Let's get started and configure our environment.
Install Dependencies#
If you haven't already, let's install the following dependencies, which are needed for this example:
[ ]:
%pip install mosaicml kaggle pydicom git+https://github.com/qubvel/segmentation_models.pytorch opencv-python-headless jupyterlab-widgets
Kaggle Authentication#
To access the data you need a Kaggle account:
- accept the competition terms at https://www.kaggle.com/c/siim-acr-pneumothorax-segmentation/data
- download kaggle.json from https://www.kaggle.com/yourusername/account by clicking "Create new API token"
- upload the kaggle.json file using the following code cells.
[ ]:
from ipywidgets import FileUpload
from IPython.display import display
uploader = FileUpload(accept='.json', multiple=True)
display(uploader)
[ ]:
import os
kaggle_folder = os.path.join(os.path.expanduser("~"), ".kaggle")
os.makedirs(kaggle_folder, exist_ok=True)
kaggle_config_file = os.path.join(kaggle_folder, "kaggle.json")
with open(kaggle_config_file, 'wb+') as output_file:
    for uploaded_filename in uploader.value:
        content = uploader.value[uploaded_filename]['content']
        output_file.write(content)
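The cell above indexes `uploader.value` by filename, which matches the dict that ipywidgets 7 exposes. As far as we know, ipywidgets 8 instead returns a tuple of dicts, each with a `content` entry. The sketch below is one way to handle both shapes; the helper names `uploaded_contents` and `write_kaggle_config` are hypothetical, and the stand-in values replace a real `FileUpload` widget. It also restricts the file to owner-only permissions, since the Kaggle CLI warns when the token is readable by other users.

```python
import os
import tempfile


def uploaded_contents(value):
    """Yield raw bytes for each uploaded file.

    Assumption: `value` is either a dict keyed by filename (ipywidgets 7 style)
    or a tuple/list of dicts (ipywidgets 8 style); both carry a 'content' entry.
    """
    entries = value.values() if isinstance(value, dict) else value
    for entry in entries:
        yield bytes(entry["content"])  # bytes() also accepts a memoryview


def write_kaggle_config(value, folder):
    """Write the first uploaded file to <folder>/kaggle.json, readable only by the owner."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, "kaggle.json")
    with open(path, "wb") as output_file:
        output_file.write(next(uploaded_contents(value)))
    os.chmod(path, 0o600)
    return path


# Stand-in values (made up) instead of a real FileUpload widget.
fake_token = b'{"username": "example", "key": "not-a-real-key"}'
old_style = {"kaggle.json": {"content": fake_token}}
new_style = ({"name": "kaggle.json", "content": memoryview(fake_token)},)

demo_folder = tempfile.mkdtemp()
config_path = write_kaggle_config(new_style, demo_folder)
```

In the notebook itself you would pass `uploader.value` and the `kaggle_folder` defined above instead of the stand-ins.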
Download and unzip the data#
[ ]:
!kaggle datasets download -d seesee/siim-train-test
!unzip -q siim-train-test.zip -d .
!ls
Flatten Image Directories#
The original dataset is oddly nested, so we flatten it to make the images easier to access in our PyTorch dataset. Each image moves from /siim/dicom-images-train/id/id/id.dcm to /siim/dicom-images-train/id.dcm.
[ ]:
from pathlib import Path
from tqdm.auto import tqdm
train_images = list(Path('siim/dicom-images-train').glob('*/*/*.dcm'))
for image in tqdm(train_images):
    image.replace(f'siim/dicom-images-train/{image.parts[-1]}')
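To check the flattening logic without touching the real download, here is a small sketch that builds a throwaway directory with the same nesting and flattens it. The function name `flatten_dicom_dir` and the image ids are made up for illustration, and the overwrite check is an extra precaution that the cell above does not perform.

```python
import tempfile
from pathlib import Path


def flatten_dicom_dir(root):
    """Move every root/<a>/<b>/<name>.dcm up to root/<name>.dcm and return the new paths."""
    root = Path(root)
    moved = []
    for image in sorted(root.glob("*/*/*.dcm")):
        target = root / image.name
        if target.exists():
            # Refuse to silently overwrite a file that was already flattened.
            raise FileExistsError(f"{target} already exists")
        image.replace(target)
        moved.append(target)
    return moved


# Build a tiny nested tree: <root>/<id>/<id>/<id>.dcm (empty placeholder files).
demo_root = Path(tempfile.mkdtemp()) / "dicom-images-train"
for image_id in ("id_a", "id_b"):
    nested = demo_root / image_id / image_id
    nested.mkdir(parents=True)
    (nested / f"{image_id}.dcm").write_bytes(b"")

flattened = flatten_dicom_dir(demo_root)
flat_names = sorted(path.name for path in demo_root.glob("*.dcm"))
```

The now-empty nested folders are left in place, as they are in the cell above.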
Project setup#
Imports#
[ ]:
import itertools
from ipywidgets import interact, fixed, IntSlider
import numpy as np
import pandas as pd
import torch
from torch import nn
import matplotlib.pyplot as plt
import cv2
from pydicom.filereader import dcmread
# model
import segmentation_models_pytorch as smp
# data
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import draw_segmentation_masks, make_grid
from sklearn.model_selection import StratifiedKFold
# transforms
from albumentations import ShiftScaleRotate, Resize, Compose
from torchmetrics import Metric
from torchmetrics.collections import MetricCollection
from composer import Trainer
from composer.models import ComposerModel
from composer.optim import DecoupledAdamW
from composer.metrics.metrics import Dice
Utils#
Here we define some utility functions to help with logging, decoding/encoding targets and visualization.
[ ]:
class LossMetric(Metric):
    """Turns any torch.nn Loss Module into distributed torchmetrics Metric."""

    def __init__(self, loss, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.loss = loss
        self.add_state("sum_loss", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("total_batches", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds, target):
        """Update the state with new predictions and targets.
        """
        # Loss calculated over samples/batch, accumulate loss over all batches
        self.sum_loss += self.loss(preds, target)
        self.total_batches += 1

    def compute(self):
        """Aggregate state over all processes and compute the metric.
        """
        # Return average loss over entire validation dataset
        return self.sum_loss / self.total_batches


def rle2mask(rle, height=1024, width=1024, fill_value=1):
    mask = np.zeros((height, width), np.float32)
    mask = mask.reshape(-1)
    rle = np.array([int(s) for s in rle.strip().split(' ')])
    rle = rle.reshape(-1, 2)
    start = 0
    for index, length in rle:
        start = start+index
        end = start+length
        mask[start: end] = fill_value
        start = end
    mask = mask.reshape(width, height).T
    return mask


def mask2rle(mask):
    mask = mask.T.flatten()
    start = np.where(mask[1:] > mask[:-1])[0]+1
    end = np.where(mask[:-1] > mask[1:])[0]+1
    length = end-start
    rle = []
    for i in range(len(length)):
        if i == 0:
            rle.extend([start[0], length[0]])
        else:
            rle.extend([start[i]-end[i-1], length[i]])
    rle = ' '.join([str(r) for r in rle])
    return rle
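The run-length format handled by `rle2mask` stores pairs of numbers, where each first number is an offset counted from the end of the previous run and each second number is the run length. The following pure-Python sketch shows that convention on a flat 12-pixel toy mask. The function names are hypothetical, and the sketch skips the column-major reshape that the NumPy versions apply to 2D masks.

```python
def decode_relative_rle(rle, size):
    """Decode 'offset length offset length ...' into a flat 0/1 list of `size` pixels.

    Each offset is counted from the end of the previous run, as in rle2mask.
    """
    values = [int(token) for token in rle.split()]
    flat = [0] * size
    position = 0
    for offset, length in zip(values[0::2], values[1::2]):
        start = position + offset
        position = start + length
        flat[start:position] = [1] * length
    return flat


def encode_relative_rle(flat):
    """Inverse of decode_relative_rle for a flat 0/1 sequence."""
    pairs = []
    previous_end = 0
    index = 0
    while index < len(flat):
        if flat[index]:
            start = index
            while index < len(flat) and flat[index]:
                index += 1
            pairs += [start - previous_end, index - start]
            previous_end = index
        else:
            index += 1
    return " ".join(str(value) for value in pairs)


# "2 3 4 1": skip 2 pixels, fill 3, skip 4 more, fill 1.
flat = decode_relative_rle("2 3 4 1", size=12)
round_trip = encode_relative_rle(flat)
```

Unlike the comparison-based `mask2rle`, this loop also picks up runs that touch the first or last pixel, which can matter for masks that reach the image border.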
Preprocessing and Data Science#
SIIM Dataset#
The SIIM dataset consists of:
- dicom-images-train: 12954 labeled images in DICOM format
- dicom-images-test: 3205 unlabeled DICOM images for testing
- train-rle.csv: a label file mapping ImageId to EncodedPixels
ImageIds map to image paths for DICOM format images.
EncodedPixels are run-length encoded segmentation masks representing areas where pneumothorax has been labeled by an expert. A label of "-1" indicates the image was examined and no pneumothorax was found.
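One detail that is easy to miss is that the label column is accessed as `" EncodedPixels"`, with a leading space, throughout this notebook. The sketch below reads a few made-up rows in that layout with the standard `csv` module and groups the masks by image. The sample ids and run lengths are invented, and it assumes an image can appear on more than one row, which the later `drop_duplicates('ImageId')` call suggests.

```python
import csv
import io

# Made-up rows in the layout of train-rle.csv; note the space after the comma in the header.
sample_csv = """ImageId, EncodedPixels
img_001,-1
img_002,10 3 5 2
img_002,40 4
img_003,-1
"""

masks_per_image = {}
for row in csv.DictReader(io.StringIO(sample_csv)):
    rle = row[" EncodedPixels"].strip()  # the key keeps the leading space
    runs = masks_per_image.setdefault(row["ImageId"], [])
    if rle != "-1":  # "-1" means the image was examined and has no mask
        runs.append(rle)

images_with_mask = [image_id for image_id, runs in masks_per_image.items() if runs]
```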
[ ]:
!ls siim
[ ]:
labels_df = pd.read_csv('siim/train-rle.csv')
labels_df.shape
Clean Data#
Of the ~13,000 labeled samples, only about 3,600 have masks. We will throw out some of the negative samples to better balance our dataset and speed up training.
[ ]:
labels_df[labels_df[" EncodedPixels"] != "-1"].shape, labels_df[labels_df[" EncodedPixels"] == "-1"].shape
[ ]:
def balance_labels(labels_df, extra_samples_without_mask=1500, random_state=1337):
    """
    Drop duplicates and mark samples with masks.
    Sample len(df_with_mask) + extra_samples_without_mask unmasked samples to balance dataset.
    """
    df = labels_df.drop_duplicates('ImageId')
    df_with_mask = df[df[" EncodedPixels"] != "-1"].copy(deep=True)
    df_with_mask['has_mask'] = 1
    df_without_mask = df[df[" EncodedPixels"] == "-1"].copy(deep=True)
    df_without_mask['has_mask'] = 0
    df_without_mask_sampled = df_without_mask.sample(len(df_with_mask)+extra_samples_without_mask, random_state=random_state)
    df = pd.concat([df_with_mask, df_without_mask_sampled])
    return df
[ ]:
df = balance_labels(labels_df)
df.shape
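The balancing rule is: keep every image that has a mask, then sample as many unmasked images as there are masked ones, plus a fixed surplus. Here is the same idea on toy data using only the standard library. The function name `balance_ids` is hypothetical, and the `min` guard is an addition for small inputs; the pandas version above does not cap the sample size.

```python
import random


def balance_ids(labels, extra_without_mask=1, seed=1337):
    """labels maps image id -> has_mask (bool). Keep all positives and sample negatives."""
    with_mask = [image_id for image_id, has_mask in labels.items() if has_mask]
    without_mask = [image_id for image_id, has_mask in labels.items() if not has_mask]
    # Cap the request so sampling never asks for more negatives than exist.
    n_negatives = min(len(without_mask), len(with_mask) + extra_without_mask)
    rng = random.Random(seed)
    return with_mask + rng.sample(without_mask, n_negatives)


# Toy labels: 3 images with a mask, 9 without.
labels = {f"img_{i}": i < 3 for i in range(12)}
balanced = balance_ids(labels, extra_without_mask=1)
n_positive = sum(labels[image_id] for image_id in balanced)
```

With `n` masked images the balanced set holds `2 * n + extra_without_mask` samples, as long as enough unmasked images are available.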
Create Cross Validation Splits#
Once cleaned and balanced, we're left with only 6838 images. This will leave us with rather small training and validation sets once we split the data. To reduce the chance of validating on a poorly sampled validation set (one that is not representative of our unlabeled test data), we use StratifiedKFold to create 5 different 80%/20% train/eval splits.
Disclaimer: For datasets of this size, it's good practice to train and evaluate on each split, but due to runtime constraints in this notebook we will only train on the first split, which contains 5470 training and 1368 eval samples.
[ ]:
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=1337)
train_idx, eval_idx = list(kfold.split(df["ImageId"], df["has_mask"]))[0]
train_df, eval_df = df.iloc[train_idx], df.iloc[eval_idx]
train_df.shape, eval_df.shape
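To see what "stratified" means here, the sketch below deals the samples of each class round-robin into five folds, so every fold keeps roughly the same share of masked images. This is a simplified illustration, not scikit-learn's actual algorithm, and `stratified_folds` is a hypothetical helper.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_splits=5, seed=1337):
    """Assign each sample index to one of n_splits folds, class by class."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)
    rng = random.Random(seed)
    folds = [[] for _ in range(n_splits)]
    for indices in by_class.values():
        rng.shuffle(indices)
        # Deal the shuffled indices of this class evenly across the folds.
        for position, index in enumerate(indices):
            folds[position % n_splits].append(index)
    return folds


labels = [1] * 20 + [0] * 30  # toy has_mask column: 40% positive
folds = stratified_folds(labels)
eval_idx = sorted(folds[0])  # the first fold is the eval set
train_idx = sorted(index for fold in folds[1:] for index in fold)
eval_positive_share = sum(labels[i] for i in eval_idx) / len(eval_idx)
```

Using the first fold for evaluation and the remaining four for training gives the 80%/20% split described above.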
PyTorch#
PyTorch Dataset#
SIIMDataset is a standard PyTorch dataset that reads images and decodes labels from the SIIM label CSV. DICOM images are loaded as grayscale NumPy arrays, converted to RGB, and scaled. Labels are converted from RLE strings to binary segmentation masks.
[ ]:
class SIIMDataset(Dataset):
    def __init__(self,
                 labels_df,
                 transforms=None,
                 image_dir=Path('siim/dicom-images-train')):
        self.labels_df = labels_df
        self.image_dir = image_dir
        self.transforms = transforms

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        image_id = row.ImageId
        image_path = self.image_dir / f'{image_id}.dcm'
        image = dcmread(image_path).pixel_array # load dicom image
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) # convert rgb so we can keep imagenet first layer weights
        image = (image / 255.).astype('float32') # scale (0.- 1.)
        rle = row[' EncodedPixels']
        if rle != '-1':
            mask = rle2mask(rle, 1024, 1024).astype('float32')
        else:
            mask = np.zeros([1024, 1024]).astype('float32')
        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']
        return (
            torch.from_numpy(image).permute(2, 0, 1),
            torch.from_numpy(mask).unsqueeze(0)
        )

    def __len__(self):
        return len(self.labels_df)
Transforms#
We use the albumentations library to resize our images and to randomly scale and rotate the training images.
[ ]:
image_size = 512
train_transforms = Compose(
    [
        Resize(image_size, image_size),
        ShiftScaleRotate(
            shift_limit=0,
            scale_limit=0.1,
            rotate_limit=10, # rotate
            p=0.5,
            border_mode=cv2.BORDER_CONSTANT
        )
    ]
)
eval_transforms = Compose([Resize(image_size, image_size)])
DataLoaders#
[ ]:
train_batch_size = 32
val_batch_size = 32
train_dataloader = DataLoader(SIIMDataset(train_df, transforms=train_transforms),
                              batch_size=train_batch_size, shuffle=True, num_workers=2)
eval_dataloader = DataLoader(SIIMDataset(eval_df, transforms=eval_transforms),
                             batch_size=val_batch_size, shuffle=False, num_workers=2)
Visualize batch#
Areas of pneumothorax are highlighted in red. Drag the slider to iterate through batches.
[ ]:
@interact(data_loader=fixed(train_dataloader), batch=IntSlider(min=0, max=len(train_dataloader)-1, step=1, value=0))
def show_batch(data_loader, batch):
    plt.rcParams['figure.figsize'] = [20, 15]
    images, masks = list(itertools.islice(data_loader, batch, batch+1))[0]
    masks_list = []
    for image, mask in zip(images, masks):
        masked = draw_segmentation_masks((image * 255).byte(),
                                         mask.bool(), alpha=0.5, colors='red')
        masks_list.append(masked)
    grid = make_grid(masks_list, nrow=6)
    plt.imshow(grid.permute(1, 2, 0));
Composer#
Model#
Here we define a Composer model that wraps the segmentation_models_pytorch (smp) package. This lets us quickly create many different segmentation models built from common pre-trained PyTorch encoders.
We set defaults to create a Unet from an ImageNet pre-trained ResNet34 with 3 input channels for our RGB (converted) inputs and 1 output channel.
We set the default loss to nn.BCEWithLogitsLoss() to classify each pixel of the output.
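For intuition about what this loss computes per pixel, here is a pure-Python sketch of binary cross-entropy on a raw logit, written in the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))`. It illustrates the formula only and is not a substitute for the PyTorch module; the function names are hypothetical.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one pixel, given a raw logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def mean_bce_with_logits(logits, targets):
    """Average over all pixels, corresponding to a 'mean' reduction."""
    losses = [bce_with_logits(x, z) for x, z in zip(logits, targets)]
    return sum(losses) / len(losses)


def naive_bce(logit, target):
    """Textbook form via the sigmoid; overflows for large negative logits."""
    probability = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(probability) + (1 - target) * math.log(1 - probability))


confident_right = bce_with_logits(10.0, 1.0)   # large logit, positive pixel: small loss
confident_wrong = bce_with_logits(-10.0, 1.0)  # large negative logit, positive pixel: large loss
```

The stable form avoids overflowing `exp` for extreme logits, which is the reason to feed raw logits to the loss rather than applying a sigmoid first.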
[ ]:
class SMPUNet(ComposerModel):
    def __init__(self,
                 encoder_name='resnet34',
                 encoder_weights='imagenet',
                 in_channels=3, classes=1,
                 loss=nn.BCEWithLogitsLoss()):
        super().__init__()
        self.model = smp.Unet(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights, # use `imagenet` pre-trained weights for encoder initialization
            in_channels=in_channels, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
            classes=classes # model output channels (number of classes in your dataset)
        )
        self.criterion = loss
        self.train_loss = LossMetric(loss)
        self.val_loss = LossMetric(loss)
        self.val_dice = Dice(num_classes=classes)

    def forward(self, batch):
        images, targets = batch
        return self.model(images)

    def loss(self, outputs, batch):
        _, targets = batch
        return self.criterion(outputs, targets)

    def metrics(self, train: bool = False):
        # Use the `train` argument here: `self.train` is the nn.Module method and is always truthy.
        return self.train_loss if train else MetricCollection([self.val_loss, self.val_dice])

    def validate(self, batch):
        images, targets = batch
        return self.model(images), targets
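The model reports a Dice score on the validation set. As a reminder of what that number measures, here is a minimal pure-Python sketch of the Dice coefficient, `2 * |P and T| / (|P| + |T|)`, on flat lists. The thresholding at logit 0 and the `eps` smoothing term are assumptions made for this illustration; Composer's `Dice` metric may handle these details differently.

```python
def dice_coefficient(logits, targets, threshold=0.0, eps=1e-7):
    """Dice overlap between thresholded predictions and 0/1 targets.

    A logit above `threshold` (0.0 corresponds to probability 0.5) counts as foreground.
    """
    predictions = [1 if x > threshold else 0 for x in logits]
    intersection = sum(p * t for p, t in zip(predictions, targets))
    total = sum(predictions) + sum(targets)
    # eps keeps the score defined (and equal to 1) when both masks are empty.
    return (2 * intersection + eps) / (total + eps)


# Two pixels predicted as foreground, one of which is truly foreground.
score = dice_coefficient([2.0, -1.0, 0.5, -3.0], [1, 0, 0, 0])
empty_score = dice_coefficient([-1.0, -1.0], [0, 0])
```

Because many images in this dataset have empty masks, how an empty prediction on an empty target is scored can noticeably affect the average.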
[ ]:
model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)
Trainer#
[ ]:
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    optimizers=optimizer,
    device='gpu',
    precision='amp',
    seed=1337
)
trainer.fit()
Methods#
Composer allows us to quickly experiment with algorithms that can speed up training or improve the quality of our model. Here we add CutOut and LabelSmoothing. Additionally, the Composer trainer has built-in support for automatic mixed precision training and gradient accumulation to help train quickly and simulate larger batch sizes.
[ ]:
from composer.algorithms import CutOut, LabelSmoothing
model = SMPUNet() # define unet model
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)
algorithms = [CutOut(length=0.5), LabelSmoothing(smoothing=0.1)]
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    optimizers=optimizer,
    algorithms=algorithms,
    device='gpu',
    precision='amp',
    seed=1337
)
trainer.fit()
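To give a feel for what a cutout-style augmentation does to an input, the sketch below zeroes one rectangular region of a small 2D image stored as nested lists. It is an illustration of the general technique, not Composer's implementation. Treating `length` as a fraction of the image height and width is an assumption about how `CutOut(length=0.5)` interprets values between 0 and 1.

```python
import random


def cutout(image, length=0.5, rng=None):
    """Return a copy of a 2D image (list of rows) with one rectangular region zeroed.

    `length` is treated as a fraction of the height and width (an assumption).
    """
    rng = rng or random.Random()
    height, width = len(image), len(image[0])
    box_h, box_w = int(length * height), int(length * width)
    # Pick a random center, then clip the box to the image bounds.
    center_y, center_x = rng.randrange(height), rng.randrange(width)
    top, left = max(0, center_y - box_h // 2), max(0, center_x - box_w // 2)
    bottom, right = min(height, top + box_h), min(width, left + box_w)
    masked = [row[:] for row in image]
    for y in range(top, bottom):
        for x in range(left, right):
            masked[y][x] = 0.0
    return masked


image = [[1.0] * 8 for _ in range(8)]
masked = cutout(image, length=0.5, rng=random.Random(0))
n_zeroed = sum(value == 0.0 for row in masked for value in row)
```

For segmentation it is worth checking whether the augmentation also hides labeled regions, since a pneumothorax that is masked out of the input still appears in the target.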
Next steps#
train longer
try different loss functions, architectures, transformations
try different combinations of composer methods!