Tip

This tutorial is available as a Jupyter notebook.


🎢 FaceSynthetics with Streaming Dataloader

In this notebook, we'll demonstrate a streaming approach to loading our datasets, using Microsoft's FaceSynthetics dataset as an example.

Streaming is useful for multi-node setups where workers don't have persistent storage and each element of the dataset must be downloaded exactly once.
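To make "downloaded exactly once" concrete, here is a minimal sketch, assuming the dataset has already been split into numbered shard files and that each worker knows its global rank and the world size. The helper `shards_for_rank` and its parameters are hypothetical illustrations, not Composer's API; Composer's streaming dataset handles this partitioning internally, and its actual scheme may differ.

```python
from typing import List


def shards_for_rank(num_shards: int, rank: int, world_size: int) -> List[int]:
    """Return the shard indices that worker `rank` should download (hypothetical helper).

    Shards are dealt out round-robin, so every shard is owned by exactly
    one worker and no shard is fetched twice.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return list(range(rank, num_shards, world_size))


if __name__ == "__main__":
    num_shards, world_size = 10, 4
    assignments = {r: shards_for_rank(num_shards, r, world_size) for r in range(world_size)}
    for r, shards in assignments.items():
        print(f"rank {r}: {shards}")
    # Every shard appears exactly once across all ranks.
    all_shards = sorted(s for shards in assignments.values() for s in shards)
    assert all_shards == list(range(num_shards))
```

Round-robin keeps the per-worker load within one shard of each other; a contiguous split would satisfy the exactly-once property equally well.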

This tutorial consists of a few steps:

1. obtaining the dataset
2. preparing the dataset for streaming
   a. (optionally) uploading the dataset to a server
3. streaming the dataset to the local machine
4. training a model using these datasets

First, let's make sure we've installed our dependencies. Note that mmcv-full can take some time to unpack; to speed things up, the mosaicml/pytorch_vision Docker image already includes mmcv, mmsegmentation, and many other useful computer vision libraries.

[ ]:
%pip install mosaicml mmsegmentation mmcv mmcv-full
[ ]:
import os
import time
import torch
import struct
import shutil
import requests

from PIL import Image
from io import BytesIO
from zipfile import ZipFile
import torch.utils.data as td
from typing import Iterator, Tuple, Dict
from torchvision import transforms as tf

We'll be using Composer's streaming dataset writer, as well as Composer's DeepLabV3 model, whose ImageNet-pretrained ResNet-101 backbone should help performance even on the small 100-image dataset.

[ ]:
from composer.datasets.streaming import StreamingDatasetWriter, StreamingDataset
[ ]:
from composer import Trainer
from composer.models import composer_deeplabv3
from composer.optim import DecoupledAdamW

Global settings

[ ]:
# the location of our dataset
in_root = "./dataset"

# the location of the "remote" streaming dataset.
# Upload `out_root` to your cloud storage provider of choice.
out_root = "./sdl"
out_train = "./sdl/train"
out_test = "./sdl/test"

# the location to download the streaming dataset during training
local = './local'
local_train = './local/train'
local_test = './local/test'

# toggle shuffling in dataloader
shuffle_train = True
shuffle_test = False

# number of segmentation classes, i.e. the distinct values a pixel in the annotation image can take
num_classes = 20

# shard size limit, in bytes (1 << 25 = 32 MiB)
shard_size_limit = 1 << 25

# show a progress bar while downloading
use_tqdm = True

# ratio of training data to test data
training_ratio = 0.9

# training batch size
batch_size = 2 # the smallest workable batch size, since BatchNorm layers need more than
               # one sample per batch in training mode; increase this if your machine can handle it.

# training hardware parameters
device = "gpu" if torch.cuda.is_available() else "cpu"

# number of training epochs
train_epochs = "3ep" # increase the number of epochs for greater accuracy

# number of images in the dataset (training + test)
num_images = 100 # can be 100, 1_000, or 100_000

# location to download the dataset zip file
dataset_archive = "./dataset.zip"

# remote dataset URL
URL = f"https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_{num_images}.zip"
[ ]:

# upload location for the dataset splits (change this if you want to upload to a different location)
upload_train_location = None
upload_test_location = None
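`shard_size_limit = 1 << 25` is 2^25 = 33,554,432 bytes, i.e. 32 MiB. The sketch below illustrates what a size limit like this means for a writer: samples are appended to the current shard until the next one would push it over the limit, at which point a new shard begins. The greedy packing is an assumption made for illustration; `pack_into_shards` is a hypothetical helper, and `StreamingDatasetWriter`'s real on-disk format and byte accounting (headers, offsets) may differ.

```python
from typing import Iterable, List

SHARD_SIZE_LIMIT = 1 << 25  # 33,554,432 bytes = 32 MiB, same value as shard_size_limit


def pack_into_shards(sample_sizes: Iterable[int], limit: int = SHARD_SIZE_LIMIT) -> List[List[int]]:
    """Greedily group sample indices into shards of at most `limit` bytes (hypothetical helper).

    A sample larger than `limit` still gets a shard of its own rather than being dropped.
    """
    shards: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for idx, size in enumerate(sample_sizes):
        # Close the current shard if this sample would push it past the limit.
        if current and current_bytes + size > limit:
            shards.append(current)
            current, current_bytes = [], 0
        current.append(idx)
        current_bytes += size
    if current:
        shards.append(current)
    return shards


if __name__ == "__main__":
    # A toy 10-byte limit keeps the grouping easy to read.
    print(pack_into_shards([4, 4, 4, 9, 1, 12], limit=10))  # [[0, 1], [2], [3, 4], [5]]
```

A smaller limit produces more, smaller shards, which gives finer-grained downloads at the cost of more files to track.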

Getting the dataset

[ ]:
if not os.path.exists(dataset_archive):
    response = requests.get(URL)
    response.raise_for_status()  # fail early instead of saving an HTTP error page as the zip
    with open(dataset_archive, "wb") as dataset_file:
        dataset_file.write(response.content)

    with ZipFile(dataset_archive, 'r') as myzip:
        myzip.extractall(in_root)

Next, we'll make the directories for our binary streaming dataset files.

Preparing the dataset#

The dataset consists of a directory of images with names in the form 123456.png, 123456_seg.png, and 123456_ldmks.png. For this example, we'll only use the images with segmentation annotations as labels and ignore the landmarks for now.

[ ]:
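def each(dirname: str, start_ix: int = 0, end_ix: int = num_images) -> Iterator[Dict[str, bytes]]:
    """Yield one raw sample per image index in [start_ix, end_ix)."""
    for i in range(start_ix, end_ix):
        image = '%s/%06d.png' % (dirname, i)            # e.g. ./dataset/000042.png
        annotation = '%s/%06d_seg.png' % (dirname, i)   # matching segmentation mask

        with open(image, 'rb') as x, open(annotation, 'rb') as y:
            yield {
                'i': struct.pack('>q', i),  # index as a big-endian signed 64-bit int
                'x': x.read(),              # PNG-encoded image bytes
                'y': y.read(),              # PNG-encoded mask bytes
            }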
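`each` stores the sample index as 8 raw bytes using `struct.pack('>q', i)`: `>` selects big-endian byte order and `q` a signed 64-bit integer. The round trip below shows the encoding and the matching decoding step; note that `struct.unpack` always returns a tuple, so the value itself is element `[0]`. The helpers `encode_index` and `decode_index` are hypothetical names used only for this illustration.

```python
import struct


def encode_index(i: int) -> bytes:
    # '>' = big-endian byte order, 'q' = signed 64-bit integer (8 bytes)
    return struct.pack('>q', i)


def decode_index(data: bytes) -> int:
    # struct.unpack returns a tuple even for a single field, so take element 0
    return struct.unpack('>q', data)[0]


if __name__ == "__main__":
    raw = encode_index(42)
    print(len(raw), raw.hex(), decode_index(raw))  # 8 000000000000002a 42
```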
[ ]:
def write_datasets() -> None:
    os.makedirs(out_train, exist_ok=True)
    os.makedirs(out_test, exist_ok=True)

    fields = ['i', 'x', 'y']

    num_training_images = int(num_images * training_ratio)

    start_ix, end_ix = 0, num_training_images
    with StreamingDatasetWriter(out_train, fields, shard_size_limit, remote=upload_train_location) as out:
        out.write_samples(each(in_root, start_ix, end_ix),
                          use_tqdm=use_tqdm,
                          total=end_ix-start_ix)
    start_ix, end_ix = end_ix, num_images
    with StreamingDatasetWriter(out_test, fields, shard_size_limit, remote=upload_test_location) as out:
        out.write_samples(each(in_root, start_ix, end_ix),
                          use_tqdm=use_tqdm,
                          total=end_ix-start_ix)
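`write_datasets` splits by index order rather than at random: the first `int(num_images * training_ratio)` images become the training set and the rest become the test set. A minimal sketch of that arithmetic, with `split_ranges` as a hypothetical helper mirroring the two `start_ix, end_ix` assignments:

```python
from typing import Tuple


def split_ranges(num_images: int, training_ratio: float) -> Tuple[range, range]:
    """Return the train and test index ranges used by write_datasets (hypothetical helper)."""
    num_training_images = int(num_images * training_ratio)  # same truncation as write_datasets
    return range(0, num_training_images), range(num_training_images, num_images)


if __name__ == "__main__":
    for n in (100, 1_000, 100_000):
        train, test = split_ranges(n, 0.9)
        print(n, len(train), len(test))
    # 100 90 10
    # 1000 900 100
    # 100000 90000 10000
```

Because the split is contiguous, any ordering in the source files carries over into the split; shuffling happens only at load time, via `shuffle_train`.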

Now that we've written the datasets to out_root, and (optionally) uploaded them to a cloud storage provider, we are ready to stream them.

[ ]:
# stream from the upload location if one was set; otherwise stream directly from out_root
remote_train = upload_train_location if upload_train_location is not None else out_train # replace this with your URL for cloud streaming
remote_test  = upload_test_location if upload_test_location is not None else out_test

Loading the Data

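We extend Composer's StreamingDataset to deserialize the binary data, turning each image into a float tensor and each segmentation mask into a tensor of integer class indices (with the mask value 255 remapped to class 19).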

[ ]:
class FaceSynthetics(StreamingDataset):
    def __init__(self,
                 remote: str,
                 local: str,
                 shuffle: bool,
                 batch_size: int,
                ) -> None:
        decoders = {
            'i': lambda data: struct.unpack('>q', data)[0],  # unpack returns a 1-tuple
            'x': lambda data: Image.open(BytesIO(data)),
            'y': lambda data: Image.open(BytesIO(data)),
        }
        super().__init__(local=local, remote=remote, shuffle=shuffle, decoders=decoders, batch_size=batch_size)

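    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        obj = super().__getitem__(i)
        # image: PIL image -> float tensor of shape (C, H, W) scaled to [0, 1]
        x = tf.functional.to_tensor(obj['x'])
        # mask: keep the first channel as an (H, W) tensor of int64 class indices
        y = tf.functional.pil_to_tensor(obj['y'])[0].to(torch.int64)
        # remap the 255 label value to class 19, the last of the num_classes = 20 classes
        y[y == 255] = 19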
        return x, y
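The mask handling in `__getitem__` is easy to misread, so here is a dependency-free sketch of the same remapping on a toy mask stored as nested lists instead of a tensor. The code above implicitly assumes that every other mask value already lies in 0 to 18, so that after remapping there are exactly `num_classes = 20` labels; the hypothetical `check_labels` helper makes that assumption explicit.

```python
from typing import List

NUM_CLASSES = 20     # mirrors num_classes in the global settings
RAW_VALUE = 255      # mask value remapped in FaceSynthetics.__getitem__
REMAPPED_CLASS = 19  # class index it is mapped to


def remap_mask(mask: List[List[int]]) -> List[List[int]]:
    """Pure-Python equivalent of `y[y == 255] = 19` on a 2-D mask."""
    return [[REMAPPED_CLASS if v == RAW_VALUE else v for v in row] for row in mask]


def check_labels(mask: List[List[int]], num_classes: int = NUM_CLASSES) -> None:
    """Raise if any pixel is not a valid class index in [0, num_classes) (hypothetical helper)."""
    bad = {v for row in mask for v in row if not 0 <= v < num_classes}
    if bad:
        raise ValueError(f"out-of-range label values: {sorted(bad)}")


if __name__ == "__main__":
    raw = [[0, 3, 255], [18, 255, 7]]
    fixed = remap_mask(raw)
    check_labels(fixed)
    print(fixed)  # [[0, 3, 19], [18, 19, 7]]
```

The labels stay as class indices rather than one-hot vectors; if a one-hot form were needed, it could be derived from these indices afterwards.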
[ ]:
def get_dataloaders() -> Tuple[td.DataLoader, td.DataLoader]:
    dataset_train = FaceSynthetics(remote_train, local_train, shuffle_train, batch_size=batch_size)
    dataset_test  = FaceSynthetics(remote_test, local_test, shuffle_test, batch_size=batch_size)

    train_dataloader = td.DataLoader(dataset_train, batch_size=batch_size)
    test_dataloader = td.DataLoader(dataset_test, batch_size=batch_size)

    return train_dataloader, test_dataloader

Training the Model

[ ]:
def make_trainer() -> Trainer:
    train_dataloader, test_dataloader = get_dataloaders()
    model = composer_deeplabv3(
        num_classes=num_classes,
        backbone_arch='resnet101',
        backbone_weights='IMAGENET1K_V2',
        sync_bn=False)
    optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

    return Trainer(
        model=model,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        max_duration=train_epochs,
        optimizers=optimizer,
        device=device
    )

Putting it all Together

[ ]:
if not os.path.exists(out_train):
    write_datasets()
[ ]:
trainer = make_trainer()
[ ]:
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

Cleanup

[ ]:
shutil.rmtree(out_root, ignore_errors=True)
shutil.rmtree(in_root, ignore_errors=True)
shutil.rmtree(local, ignore_errors=True)  # copies streamed down during training
if os.path.exists(dataset_archive):
    os.remove(dataset_archive)

Next Steps

Congrats! We've trained our FaceSynthetics model on a streaming dataset!

Now that we're done, we can explore some additional speedups and performance improvements, like:

  • training against a full dataset

  • using Composer's suite of speedup algorithms

  • building a multi-gpu trainer for shared streaming

Happy training!