composer.datasets

DataloaderHparams contains the torch.utils.data.DataLoader settings that are common across both the training and evaluation datasets (see the sketch after this list):

  • num_workers

  • prefetch_factor

  • persistent_workers

  • pin_memory

  • timeout
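For illustration, these shared settings might be specified as follows. This is a minimal sketch: the field names come from the list above, but the keyword-argument constructor and the example values are assumptions rather than the library's documented defaults.

    from composer.datasets import DataloaderHparams

    # Shared settings applied to every dataloader the trainer builds.
    # NOTE: passing these fields as keyword arguments is an assumption.
    dataloader_hparams = DataloaderHparams(
        num_workers=8,            # worker subprocesses per dataloader
        prefetch_factor=2,        # batches prefetched per worker
        persistent_workers=True,  # keep workers alive between epochs
        pin_memory=True,          # page-locked memory for faster host-to-GPU copies
        timeout=0,                # seconds to wait for a batch (0 = wait forever)
    )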

Each DatasetHparams is then responsible for returning a DataloaderSpec, which is a NamedTuple of dataset-specific settings such as the following (see the sketch after this list):

  • dataset

  • drop_last

  • shuffle

  • collate_fn
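As an illustration, the spec for an MNIST training set could be assembled roughly as follows. This is a hedged sketch: DataloaderSpec may define additional fields and defaults beyond the four listed above, and constructing it directly with keyword arguments is an assumption.

    from torchvision import datasets, transforms

    from composer.datasets import DataloaderSpec

    # Build the dataset itself, then wrap the dataset-specific settings in a
    # DataloaderSpec so the trainer can construct the DataLoader later.
    train_dataset = datasets.MNIST(
        "/tmp/mnist",
        train=True,
        download=True,
        transform=transforms.ToTensor(),
    )
    train_spec = DataloaderSpec(
        dataset=train_dataset,
        drop_last=False,
        shuffle=True,
        collate_fn=None,  # None falls back to the default collate function
    )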

This indirection (instead of creating the dataloaders directly at the start) is needed because, for multi-GPU training, each dataloader requires the global rank to initialize its torch.utils.data.distributed.DistributedSampler.

As a result, our trainer uses the DataloaderSpec and DataloaderHparams to create the dataloaders after DDP has forked the processes.
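Conceptually, once each DDP process knows its global rank, the trainer combines the two objects along these lines. This is a simplified sketch of the idea rather than the trainer's actual code; it assumes torch.distributed has already been initialized on every rank and that the batch size is supplied separately.

    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def build_dataloader(spec, hparams, batch_size):
        # The DistributedSampler needs the global rank and world size, which are
        # only available after DDP has forked (or spawned) the worker processes.
        sampler = DistributedSampler(
            spec.dataset,
            num_replicas=dist.get_world_size(),
            rank=dist.get_rank(),
            shuffle=spec.shuffle,
            drop_last=spec.drop_last,
        )
        return DataLoader(
            spec.dataset,
            batch_size=batch_size,
            sampler=sampler,  # a sampler replaces shuffle= on the DataLoader itself
            drop_last=spec.drop_last,
            collate_fn=spec.collate_fn,
            num_workers=hparams.num_workers,
            prefetch_factor=hparams.prefetch_factor,
            persistent_workers=hparams.persistent_workers,
            pin_memory=hparams.pin_memory,
            timeout=hparams.timeout,
        )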

Base Classes and Hyperparameters

DataloaderHparams

Hyperparameters to initialize a torch.utils.data.DataLoader.

DataloaderSpec

Specification for initializing a dataloader.

DatasetHparams

Abstract base class for hyperparameters to initialize a dataset.

Datasets

MNISTDatasetHparams

Defines an instance of the MNIST dataset for image classification.

CIFAR10DatasetHparams

Defines an instance of the CIFAR-10 dataset for image classification.

ImagenetDatasetHparams

Defines an instance of the ImageNet dataset for image classification.

LMDatasetHparams

Defines a generic dataset class for autoregressive language models.

SyntheticDatasetHparams

Defines an instance of a synthetic dataset for classification.

BratsDatasetHparams

Defines an instance of the BraTS dataset for image segmentation.