composer.datasets.streaming.dataset#
The StreamingDataset class, used for building streaming iterable datasets.
Classes
StreamingDataset | A sharded, streaming, iterable dataset.
- class composer.datasets.streaming.dataset.StreamingDataset(remote, local, shuffle, decoders, max_retries=2, timeout=60, batch_size=None)[source]#
Bases: torch.utils.data.dataset.IterableDataset
A sharded, streaming, iterable dataset.
Features:
- StreamingDataset reads samples from binary .mds files that were written out by StreamingDatasetWriter.
- Supports downloading data from S3, SFTP, or the local filesystem.
- Supports multi-GPU and multi-node training, with smart local caching to minimize network bandwidth.
- Provides best-effort shuffling to preserve randomness when shuffle=True.
When batch_size is provided, worker indices will be constructed so that there is at most one incomplete batch at the end of each epoch. For example, if the DataLoader is reading over:

samples: [0, 1, 2, 3, 4, 5, 6, 7]
num_workers: 3
batch_size: 2
drop_last: True

but batch_size is not hinted to the StreamingDataset ahead of time, then the samples will by default be assigned like:

worker 0: [0, 1, 2]
worker 1: [3, 4, 5]
worker 2: [6, 7]

and will be read as batches like (with samples [2] and [5] dropped as incomplete):

batch 0: [0, 1]
batch 1: [3, 4]
batch 2: [6, 7]

The above is suboptimal because we could have dropped no samples. So when batch_size is provided as a hint, we assign samples like this:

worker 0: [0, 1, 2, 3]
worker 1: [4, 5]
worker 2: [6, 7]

which will be read as batches like:

batch 0: [0, 1]
batch 1: [4, 5]
batch 2: [6, 7]
batch 3: [2, 3]
- Parameters
remote (Optional[str]) – Download shards from this remote path or directory.
local (str) – Download shards to this local directory for caching.
shuffle (bool) – Whether to shuffle the samples. Note that if shuffle=False, the sample order is deterministic but dependent on the DataLoader's num_workers.
decoders (Dict[str, Callable[[bytes], Any]]) – For each sample field you wish to read, you must provide a decoder to convert the raw bytes to an object.
max_retries (int) – Number of download re-attempts before giving up. Default: 2.
timeout (float) – How long to wait for a shard to download before raising an exception. Default: 60 sec.
batch_size (Optional[int]) – Hint of the batch size that will be used on each device's DataLoader. Default: None.
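The decoders argument maps each field name to a callable that turns that field's raw bytes into a Python object. A minimal sketch of the transformation it performs, with an invented raw_sample for illustration (the real StreamingDataset applies the decoders internally when yielding samples):

```python
# A hypothetical raw sample as it would come out of a shard: bytes per field.
raw_sample = {
    "uid": b"000007",
    "data": (21).to_bytes(4, "big"),
}

# One decoder per field you want to read back as an object.
decoders = {
    "uid": lambda field_bytes: field_bytes.decode("utf-8"),
    "data": lambda field_bytes: int.from_bytes(field_bytes, "big"),
}

# Apply each field's decoder to its raw bytes.
decoded = {field: decode(raw_sample[field]) for field, decode in decoders.items()}
print(decoded)  # {'uid': '000007', 'data': 21}
```

Fields without a decoder entry (like "unused" in the example below) are simply not read back.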
To write the dataset:

>>> from composer.datasets.streaming import StreamingDatasetWriter
>>> samples = [
...     {
...         "uid": f"{ix:06}".encode("utf-8"),
...         "data": (3 * ix).to_bytes(4, "big"),
...         "unused": "blah".encode("utf-8"),
...     }
...     for ix in range(100)
... ]
>>> dirname = "remote"
>>> fields = ["uid", "data"]
>>> with StreamingDatasetWriter(dirname=dirname, fields=fields) as writer:
...     writer.write_samples(samples=samples)

To read the dataset:

>>> from composer.datasets.streaming import StreamingDataset
>>> remote = "remote"
>>> local = "local"
>>> decoders = {
...     "uid": lambda uid_bytes: uid_bytes.decode("utf-8"),
...     "data": lambda data_bytes: int.from_bytes(data_bytes, "big"),
... }
>>> dataset = StreamingDataset(remote=remote, local=local, shuffle=False, decoders=decoders)