DataSpec

class composer.DataSpec(dataloader, num_samples=None, num_tokens=None, device_transforms=None, split_batch=None, get_num_samples_in_batch=None, get_num_tokens_in_batch=None)

Specifications for operating and training on data.

An example of constructing a DataSpec object with a device_transforms callable (NormalizationFn) and then using it with Trainer:

>>> from composer import DataSpec, Trainer
>>> from composer.datasets.utils import NormalizationFn
>>> # Assumes model, optimizer, train_dataloader, and eval_dataloader are already defined.
>>> # Normalize image batches on the device with NormalizationFn
>>> CHANNEL_MEAN = (0.485 * 255, 0.456 * 255, 0.406 * 255)
>>> CHANNEL_STD = (0.229 * 255, 0.224 * 255, 0.225 * 255)
>>> device_transform_fn = NormalizationFn(mean=CHANNEL_MEAN, std=CHANNEL_STD)
>>> train_dspec = DataSpec(train_dataloader, device_transforms=device_transform_fn)
>>> # The same function can be used for the eval dataloader as well
>>> eval_dspec = DataSpec(eval_dataloader, device_transforms=device_transform_fn)
>>> # Pass the DataSpec objects to the Trainer
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dspec,
...     eval_dataloader=eval_dspec,
...     optimizers=optimizer,
...     max_duration="1ep",
... )

Parameters
  • dataloader (Iterable) – The dataloader, which can be any iterable that yields batches.

  • num_samples (int, optional) – The total number of samples in an epoch, across all ranks. This field is used by the Timestamp (training progress tracker). If not specified, len(dataloader.dataset) is used when that property is available; otherwise, the dataset is assumed to be unsized.

  • num_tokens (int, optional) – The total number of tokens in an epoch. This field is used by the Timestamp (training progress tracker).

  • device_transforms ((Batch) -> Batch, optional) – Function called by the Trainer to modify the batch once it has been moved onto the device. For example, this function can be used for GPU-based normalization. It can modify the batch in place, and it should return the modified batch. If not specified, the batch is not modified.

  • split_batch ((Batch, int) -> Sequence[Batch], optional) – Function called by the Trainer to split a batch (the first parameter) into the specified number of microbatches (the second parameter). If the dataloader yields batches that are not of type torch.Tensor, Mapping, Tuple, or List, this function must be specified.

  • get_num_samples_in_batch ((Batch) -> int, optional) –

    Function that is called by the Trainer to get the number of samples in the provided batch.

    By default, if the batch contains tensors that all have the same 0th dimension, the size of that dimension is returned. If the batch contains tensors whose 0th dimensions differ, this function must be specified.

  • get_num_tokens_in_batch ((Batch) -> int, optional) –

    Function that is called by the Trainer to get the number of tokens in the provided batch.

    By default, it returns 0, meaning that the number of tokens processed is not tracked as part of training progress. Specify this function to track the number of tokens processed during training.
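The callables described above are easiest to see on a batch type that the Trainer cannot split on its own. The following is a minimal sketch under stated assumptions: `TextBatch`, `PAD_ID`, and every function name are hypothetical, plain lists of ints stand in for tensors so the example runs with only the standard library, and `infer_num_samples` only mirrors the documented `num_samples` default (representing an unsized dataset as `None`). It is not Composer's internal code.

```python
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence

PAD_ID = 0  # hypothetical padding token id


@dataclass
class TextBatch:
    """Hypothetical custom batch type (not a Tensor, Mapping, Tuple, or List)."""
    input_ids: List[List[int]]  # one row of token ids per sample
    labels: List[int]           # one label per sample


def split_text_batch(batch: TextBatch, num_microbatches: int) -> Sequence[TextBatch]:
    """split_batch sketch: cut the batch into contiguous, near-equal microbatches."""
    n = len(batch.labels)
    if not 1 <= num_microbatches <= n:
        raise ValueError(f"cannot split {n} samples into {num_microbatches} microbatches")
    base, extra = divmod(n, num_microbatches)
    microbatches, start = [], 0
    for i in range(num_microbatches):
        end = start + base + (1 if i < extra else 0)  # first `extra` chunks get one more sample
        microbatches.append(TextBatch(batch.input_ids[start:end], batch.labels[start:end]))
        start = end
    return microbatches


def num_samples_in_text_batch(batch: TextBatch) -> int:
    """get_num_samples_in_batch sketch: one sample per label."""
    return len(batch.labels)


def num_tokens_in_text_batch(batch: TextBatch) -> int:
    """get_num_tokens_in_batch sketch: count non-padding tokens only."""
    return sum(tok != PAD_ID for row in batch.input_ids for tok in row)


def infer_num_samples(dataloader: Any, num_samples: Optional[int] = None) -> Optional[int]:
    """Hypothetical helper mirroring the documented num_samples default.

    Returns None to represent an unsized dataset (an assumption of this sketch).
    """
    if num_samples is not None:
        return num_samples
    try:
        return len(getattr(dataloader, "dataset", None))
    except TypeError:  # no `dataset` attribute, or the dataset has no __len__
        return None


batch = TextBatch(input_ids=[[5, 6, 0], [7, 0, 0], [8, 9, 10]], labels=[1, 0, 1])
microbatches = split_text_batch(batch, 2)
print([num_samples_in_text_batch(m) for m in microbatches])  # [2, 1]
print(num_tokens_in_text_batch(batch))  # 6
```

With Composer installed, these functions would be passed through the documented keyword arguments, e.g. `DataSpec(train_dataloader, split_batch=split_text_batch, get_num_samples_in_batch=num_samples_in_text_batch, get_num_tokens_in_batch=num_tokens_in_text_batch)`. Splitting into contiguous chunks keeps each sample's tokens and label in the same microbatch, so the per-microbatch sample and token counts add up to the counts for the full batch.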