composer.datasets.streaming.dataset#

The StreamingDataset class, used for building streaming iterable datasets.

Classes

StreamingDataset

A sharded, streaming, iterable dataset.

class composer.datasets.streaming.dataset.StreamingDataset(remote, local, shuffle, decoders, max_retries=2, timeout=60, batch_size=None)[source]#

Bases: torch.utils.data.dataset.IterableDataset

A sharded, streaming, iterable dataset.

Features:

  • StreamingDataset reads samples from binary .mds files that were written out by StreamingDatasetWriter.

  • Supports downloading data from S3, SFTP, or the local filesystem.

  • Supports multi-GPU and multi-node training, with smart local caching to minimize network bandwidth.

  • Also provides best-effort shuffling to preserve randomness when shuffle=True.

When batch_size is provided, worker indices will be constructed so that there is at most one incomplete batch at the end of each epoch. For example, if the DataLoader is reading over:

samples: [0, 1, 2, 3, 4, 5, 6, 7]
num_workers: 3
batch_size: 2
drop_last: True

but batch_size is not hinted to the StreamingDataset ahead of time, then by default the samples will be assigned like:

worker 0: [0, 1, 2]
worker 1: [3, 4, 5]
worker 2: [6, 7]

and will be read as batches like (with samples [2] and [5] dropped as incomplete):

batch 0: [0, 1]
batch 1: [3, 4]
batch 2: [6, 7]

The above is suboptimal because we could have dropped no samples. So when batch_size is provided as a hint, we assign samples like this:

worker 0: [0, 1, 2, 3]
worker 1: [4, 5]
worker 2: [6, 7]

which will be read as batches like:

batch 0: [0, 1]
batch 1: [4, 5]
batch 2: [6, 7]
batch 3: [2, 3]
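
As a rough illustration, the batch-aware assignment above can be reproduced by handing out whole batches of indices to the workers, so that at most one batch in the epoch is incomplete. This is a minimal sketch (assign_samples is a hypothetical helper, not the library's actual partitioning code):

>>> import math
>>> def assign_samples(num_samples, num_workers, batch_size):
...     # Hypothetical sketch: give each worker a whole number of batches,
...     # so that at most one batch across the epoch is incomplete.
...     num_batches = math.ceil(num_samples / batch_size)
...     per_worker = [num_batches // num_workers] * num_workers
...     for i in range(num_batches % num_workers):
...         per_worker[i] += 1  # spread leftover batches over the early workers
...     assignments, start = [], 0
...     for n in per_worker:
...         end = min(start + n * batch_size, num_samples)
...         assignments.append(list(range(start, end)))
...         start = end
...     return assignments
>>> assign_samples(8, 3, 2)
[[0, 1, 2, 3], [4, 5], [6, 7]]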
Parameters
  • remote (Optional[str]) – Download shards from this remote path or directory.

  • local (str) – Download shards to this local directory for caching.

  • shuffle (bool) – Whether to shuffle the samples. Note that if shuffle=False, the sample order is deterministic but dependent on the DataLoader's num_workers.

  • decoders (Dict[str, Callable[[bytes], Any]]) – For each sample field you wish to read, you must provide a decoder to convert the raw bytes to an object.

  • max_retries (int) – Number of download re-attempts before giving up. Default: 2.

  • timeout (float) – How long to wait for a shard to download before raising an exception. Default: 60 seconds.

  • batch_size (Optional[int]) – Hint the batch_size that will be used on each device's DataLoader. Default: None.

To write the dataset:
>>> from composer.datasets.streaming import StreamingDatasetWriter
>>> samples = [
...     {
...         "uid": f"{ix:06}".encode("utf-8"),
...         "data": (3 * ix).to_bytes(4, "big"),
...         "unused": "blah".encode("utf-8"),
...     }
...     for ix in range(100)
... ]
>>> dirname = "remote"
>>> fields = ["uid", "data"]
>>> with StreamingDatasetWriter(dirname=dirname, fields=fields) as writer:
...     writer.write_samples(samples=samples)

To read the dataset:
>>> from composer.datasets.streaming import StreamingDataset
>>> remote = "remote"
>>> local = "local"
>>> decoders = {
...     "uid": lambda uid_bytes: uid_bytes.decode("utf-8"),
...     "data": lambda data_bytes: int.from_bytes(data_bytes, "big"),
... }
>>> dataset = StreamingDataset(remote=remote, local=local, shuffle=False, decoders=decoders)
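
To iterate over the dataset with a PyTorch DataLoader (a sketch; per the batch_size note above, passing the same batch_size to both the StreamingDataset and the DataLoader yields the best sample partitioning):
>>> from torch.utils.data import DataLoader
>>> dataset = StreamingDataset(remote=remote, local=local, shuffle=False,
...                            decoders=decoders, batch_size=4)
>>> dataloader = DataLoader(dataset, batch_size=4, num_workers=2)
>>> for batch in dataloader:
...     uids, values = batch["uid"], batch["data"]  # decoded fields, batched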
download()[source]#

Download and assimilate missing shards.
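
Shards can also be fetched eagerly before iteration begins, for example to warm the local cache ahead of training (a sketch using the dataset constructed above):
>>> dataset.download()  # fetch any shards not yet present in the local cache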