Tip

This tutorial is available as a Jupyter notebook.

🎢 FaceSynthetics with Streaming Dataloader#

Why wait for your data to download when you can stream it instead? Let’s see how to do so with Composer.

Streaming is useful for multi-node setups where workers don’t have persistent storage and each element of the dataset must be downloaded exactly once. In this tutorial, we’ll demonstrate a streaming approach to loading our datasets, using Microsoft’s FaceSynthetics dataset as an example.

Tutorial Goals and Concepts Covered#

The goal of this tutorial is to showcase how to prepare and use Composer’s streaming data loading tools. It will consist of a few steps:

  1. Obtaining the dataset

  2. Preparing the dataset for streaming

    1. (Optionally) uploading the dataset to a server

  3. Streaming the dataset to the local machine

  4. Training a model using these datasets

Let’s get started!

Setup#

Let’s start by making sure the right packages are installed and imported.

First, let’s make sure we’ve installed our dependencies; note that mmcv-full will take some time to unpack. To speed things up, we have included mmcv, mmsegmentation and many other useful computer vision libraries in the mosaicml/pytorch_vision Docker Image.

[ ]:
%pip install mmsegmentation "mmcv-full==1.5.0"

%pip install mosaicml
# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install git+https://github.com/mosaicml/composer.git
[ ]:
import os
import time
import torch
import struct
import shutil
import requests

from PIL import Image
from io import BytesIO
from zipfile import ZipFile
from torch.utils.data import DataLoader
from typing import Iterator, Tuple, Dict
from torchvision import transforms as tf

We’ll be using Composer’s streaming dataset writer, as well as the Composer DeepLabV3 model, which should help improve our performance even on the small, hundred-image dataset.

[ ]:
from composer.datasets.streaming import StreamingDatasetWriter, StreamingDataset
from composer.models.deeplabv3 import composer_deeplabv3
[ ]:
from composer import Trainer
from composer.optim import DecoupledAdamW

Global settings#

For this tutorial, it makes the most sense to organize our global settings here rather than distribute them throughout the cells in which they’re used.

[ ]:
# the location of our dataset
in_root = "./dataset"

# the location of the "remote" streaming dataset.
# Upload `out_root` to your cloud storage provider of choice.
out_root = "./sdl"
out_train = "./sdl/train"
out_test = "./sdl/test"

# the location to download the streaming dataset during training
local = './local'
local_train = './local/train'
local_test = './local/test'

# toggle shuffling in dataloader
shuffle_train = True
shuffle_test = False

# number of possible values a pixel in the annotation image can take
num_classes = 20

# shard size limit, in bytes (1 << 25 bytes = 32 MiB per shard)
shard_size_limit = 1 << 25

# show a progress bar while downloading
use_tqdm = True

# ratio of training data to test data
training_ratio = 0.9

# training batch size
batch_size = 2 # a small batch size so the example runs on modest hardware;
               # increase this if your machine can handle it.

# training hardware parameters
device = "gpu" if torch.cuda.is_available() else "cpu"

# number of training epochs
train_epochs = "3ep" # increase the number of epochs for greater accuracy

# number of images in the dataset (training + test)
num_images = 100 # can be 100, 1_000, or 100_000

# location to download the dataset zip file
dataset_archive = "./dataset.zip"

# remote dataset URL
URL = f"https://facesyntheticspubwedata.blob.core.windows.net/iccv-2021/dataset_{num_images}.zip"
[ ]:

# upload locations for the dataset splits
# (set these to remote URIs if you want to upload to cloud storage)
upload_train_location = None
upload_test_location = None

Getting the dataset#

[ ]:
if not os.path.exists(dataset_archive):
    # download the zipped dataset and write it to disk
    response = requests.get(URL)
    with open(dataset_archive, "wb") as dataset_file:
        dataset_file.write(response.content)

    # extract the images and annotations into `in_root`
    with ZipFile(dataset_archive, 'r') as myzip:
        myzip.extractall(in_root)
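
Note that `response.content` buffers the entire archive in memory before writing it to disk. That's fine for the 100-image archive, but if you set `num_images` to one of the larger values, a chunked download along these lines may be preferable (a sketch using the same `URL` and `dataset_archive` as above):

[ ]:
# a sketch of a chunked download for the larger archives; avoids holding
# the whole zip in memory at once
if not os.path.exists(dataset_archive):
    with requests.get(URL, stream=True) as response:
        response.raise_for_status()
        with open(dataset_archive, "wb") as dataset_file:
            for chunk in response.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                dataset_file.write(chunk)

    with ZipFile(dataset_archive, 'r') as myzip:
        myzip.extractall(in_root)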

Next, we’ll prepare the dataset and write it into binary files that the streaming dataloader can read.

Preparing the dataset#

The dataset consists of a directory of images with names in the form 123456.png, 123456_seg.png, and 123456_ldmks.png. For this example, we’ll use the images as inputs and their segmentation annotations as labels, ignoring the landmarks for now.
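
As an optional sanity check, listing a few of the extracted files should confirm this layout (assuming the archive unpacked directly into `in_root`, which is what the generator below expects):

[ ]:
# optional: confirm the extracted file layout described above
print(sorted(os.listdir(in_root))[:6])

The generator below walks these image/annotation pairs in index order and yields each sample as a dict of raw bytes.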

[ ]:
def each(dirname: str, start_ix: int = 0, end_ix: int = num_images) -> Iterator[Dict[str, bytes]]:
    """Yield each sample as a dict of raw bytes: index 'i', image 'x', annotation 'y'."""
    for i in range(start_ix, end_ix):
        image = '%s/%06d.png' % (dirname, i)
        annotation = '%s/%06d_seg.png' % (dirname, i)

        with open(image, 'rb') as x, open(annotation, 'rb') as y:
            yield {
                'i': struct.pack('>q', i),  # sample index as a big-endian int64
                'x': x.read(),              # raw PNG bytes of the input image
                'y': y.read(),              # raw PNG bytes of the segmentation mask
            }
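
To see what one serialized sample looks like, we can pull the first element from the generator (optional; this uses only the `each` function and `in_root` defined above):

[ ]:
# optional: inspect the first serialized sample
sample = next(each(in_root, 0, 1))
print(list(sample.keys()))                  # ['i', 'x', 'y']
print(struct.unpack('>q', sample['i']))     # (0,)
print(len(sample['x']), len(sample['y']))   # sizes of the raw PNG bytes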

Below, we’ll set up the logic for writing our starting dataset to files that can be read using a streaming dataloader.

For more information on the StreamingDatasetWriter, check out the API reference.

[ ]:
def write_datasets() -> None:
    os.makedirs(out_train, exist_ok=True)
    os.makedirs(out_test, exist_ok=True)

    fields = ['i', 'x', 'y']

    num_training_images = int(num_images * training_ratio)

    # write the training split
    start_ix, end_ix = 0, num_training_images
    with StreamingDatasetWriter(out_train, fields, shard_size_limit, remote=upload_train_location) as out:
        out.write_samples(each(in_root, start_ix, end_ix),
                          use_tqdm=use_tqdm,
                          total=end_ix-start_ix)

    # write the test split
    start_ix, end_ix = end_ix, num_images
    with StreamingDatasetWriter(out_test, fields, shard_size_limit, remote=upload_test_location) as out:
        out.write_samples(each(in_root, start_ix, end_ix),
                          use_tqdm=use_tqdm,
                          total=end_ix-start_ix)

Now that we’ve written the datasets to out_root, and (optionally) uploaded them to a cloud storage provider, we are ready to stream them.
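
As an alternative to setting `upload_train_location` and `upload_test_location` above, you can copy the written shards to cloud storage yourself. For example, with the AWS CLI configured (a sketch; `s3://my-bucket` is a placeholder for your own bucket):

[ ]:
# hypothetical example: upload the shard directories to S3 (placeholder bucket),
# then use the resulting URIs as your remote locations for streaming
# !aws s3 cp --recursive ./sdl s3://my-bucket/sdl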

[ ]:
# stream from the local copies unless you set upload locations above
remote_train = out_train if upload_train_location is None else upload_train_location
remote_test  = out_test if upload_test_location is None else upload_test_location

Loading the Data#

We extend Composer’s StreamingDataset to deserialize the binary data and convert the segmentation labels into class-index tensors.

For more information on the StreamingDataset parent class, check out the API reference.

[ ]:
class FaceSynthetics(StreamingDataset):
    def __init__(self,
                 remote: str,
                 local: str,
                 shuffle: bool,
                 batch_size: int,
                ) -> None:
        # decoders map each field's raw bytes back into usable objects
        decoders = {
            'i': lambda data: struct.unpack('>q', data),
            'x': lambda data: Image.open(BytesIO(data)),
            'y': lambda data: Image.open(BytesIO(data)),
        }
        super().__init__(local=local, remote=remote, shuffle=shuffle, decoders=decoders, batch_size=batch_size)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        obj = super().__getitem__(i)
        x = tf.functional.to_tensor(obj['x'])
        y = tf.functional.pil_to_tensor(obj['y'])[0].to(torch.int64)
        y[y == 255] = 19  # fold the sentinel pixel value 255 into the last class index (num_classes - 1)
        return x, y

Putting It All Together#

We’re now ready to write the streamable dataset, if we haven’t already.

[ ]:
if not os.path.exists(out_train):
    write_datasets()

Once that’s done, we can instantiate our streaming datasets and wrap them in standard dataloaders for training!

[ ]:
dataset_train = FaceSynthetics(remote_train, local_train, shuffle_train, batch_size=batch_size)
dataset_test  = FaceSynthetics(remote_test, local_test, shuffle_test, batch_size=batch_size)

train_dataloader = DataLoader(dataset_train, batch_size=batch_size)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size)
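
Before kicking off training, it can be worth pulling a single batch to verify that everything is wired up (optional; note that the first access will download the needed shards from `remote_train` into `local_train`):

[ ]:
# optional: fetch one batch to check shapes and dtypes before training
images, labels = next(iter(train_dataloader))
print(images.shape, images.dtype)  # (batch_size, 3, H, W), float32 in [0, 1]
print(labels.shape, labels.dtype)  # (batch_size, H, W), int64 class indices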

Train with the Streaming Dataloaders#

Now all that’s left to do is train! Doing so with Composer should look pretty familiar by now.

[ ]:
# Create a DeepLabV3 model, and an optimizer for it
model = composer_deeplabv3(
    num_classes=num_classes,
    backbone_arch='resnet101',
    backbone_weights='IMAGENET1K_V2',
    sync_bn=False)
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3)

# Create a trainer object with our model, optimizer, and streaming dataloaders
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    max_duration=train_epochs,
    optimizers=optimizer,
    device=device
)

# Train!
start_time = time.perf_counter()
trainer.fit()
end_time = time.perf_counter()
print(f"It took {end_time - start_time:0.4f} seconds to train")

Cleanup#

That’s it. No need to hang on to the files created by the tutorial…

[ ]:
shutil.rmtree(out_root, ignore_errors=True)
shutil.rmtree(in_root, ignore_errors=True)
if os.path.exists(dataset_archive):
    os.remove(dataset_archive)

What next?#

You’ve now seen an in-depth look at how to prepare and use streaming datasets with Composer.

To continue learning about Composer, please explore our other tutorials!

Come get involved with MosaicML!#

We’d love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub#

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack#

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer#

Is there a bug you noticed or a feature you’d like? File an issue or make a pull request!