composer.datasets.streaming.format

The StreamingDatasetIndex format, which defines the shard and sample metadata for StreamingDataset.

Functions

bytes_to_sample_dict

Load a sample dict from bytes and field names.

get_index_basename

Get the basename for a streaming dataset index.

get_shard_basename

Get the basename for a streaming dataset shard.

sample_dict_to_bytes

Dump a sample dict to bytes, given field names.

Classes

StreamingDatasetIndex

Streaming Dataset index file, containing all the info about shards and samples.

class composer.datasets.streaming.format.StreamingDatasetIndex(samples_per_shard, bytes_per_shard, bytes_per_sample, fields)

Streaming Dataset index file, containing all the info about shards and samples.

The shards are binary buffers of samples concatenated together. All offset information for the whole dataset is contained in the index file. Workers read this file to determine which shards, and which parts of those shards, fall within their slice of the dataset.

Each sample is a dict of str to bytes. All samples must contain the same dict keys (fields). The field names are stored once, in the index file, rather than with every sample, for efficiency.
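Because the shards hold samples back to back and the index records every size, a sample's location can be computed from the index arrays alone. The sketch below assumes that each shard file contains only its samples, concatenated in ID order with no header or padding; locate_sample is a hypothetical helper, not part of this module.

```python
from typing import Sequence, Tuple

import numpy as np


def locate_sample(samples_per_shard: Sequence[int],
                  bytes_per_sample: Sequence[int],
                  sample_id: int) -> Tuple[int, int, int]:
    """Return (shard, begin, end): the byte range of a sample within its shard.

    Hypothetical helper; assumes shards are headerless concatenations of samples.
    """
    samples_per_shard = np.asarray(samples_per_shard, dtype=np.int64)
    bytes_per_sample = np.asarray(bytes_per_sample, dtype=np.int64)
    if not 0 <= sample_id < len(bytes_per_sample):
        raise IndexError(f'sample {sample_id} out of range')
    # Global ID of the first sample in each shard, e.g. [3, 2] -> [0, 3].
    shard_starts = np.concatenate([[0], np.cumsum(samples_per_shard)[:-1]])
    # The owning shard is the last one starting at or before sample_id;
    # side='right' skips empty shards that share the same start.
    shard = int(np.searchsorted(shard_starts, sample_id, side='right')) - 1
    first = int(shard_starts[shard])
    # Samples are packed back to back, so the offset is the total size of
    # the earlier samples in the same shard.
    begin = int(bytes_per_sample[first:sample_id].sum())
    end = begin + int(bytes_per_sample[sample_id])
    return shard, begin, end


print(locate_sample([3, 2], [10, 20, 30, 5, 7], 4))  # (1, 5, 12)
```

Sample 4 is the second sample of shard 1, so it starts after the 5 bytes of sample 3 and spans 7 bytes.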

Parameters
  • samples_per_shard (NDArray[np.int64]) – Number of samples in each shard.

  • bytes_per_shard (NDArray[np.int64]) – Size in bytes of each shard.

  • bytes_per_sample (NDArray[np.int64]) – Size in bytes of each sample across all shards.

  • fields (List[str]) – The names of the samples’ fields, in order.
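The four parameters describe the same data from different angles, so they must agree with each other. A minimal consistency check, assuming shards carry no header or padding (so a shard's size is the sum of its samples' sizes); check_index_arrays is a hypothetical helper, not something StreamingDatasetIndex is documented to run.

```python
from typing import List, Sequence

import numpy as np


def check_index_arrays(samples_per_shard: Sequence[int],
                       bytes_per_shard: Sequence[int],
                       bytes_per_sample: Sequence[int],
                       fields: List[str]) -> None:
    """Raise ValueError if the index arrays contradict each other (hypothetical)."""
    samples_per_shard = np.asarray(samples_per_shard, dtype=np.int64)
    bytes_per_shard = np.asarray(bytes_per_shard, dtype=np.int64)
    bytes_per_sample = np.asarray(bytes_per_sample, dtype=np.int64)
    if len(samples_per_shard) != len(bytes_per_shard):
        raise ValueError('need one sample count and one byte size per shard')
    if int(samples_per_shard.sum()) != len(bytes_per_sample):
        raise ValueError('bytes_per_sample must have one entry per sample')
    # Split the per-sample sizes at shard boundaries and total each chunk
    # (assumes no per-shard header or padding bytes).
    chunks = np.split(bytes_per_sample, np.cumsum(samples_per_shard)[:-1])
    if [int(c.sum()) for c in chunks] != bytes_per_shard.tolist():
        raise ValueError('shard sizes do not match the sum of their sample sizes')
    # Fields become dict keys, so they must be unique.
    if len(set(fields)) != len(fields):
        raise ValueError('field names must be unique')


check_index_arrays([3, 2], [60, 12], [10, 20, 30, 5, 7], ['text', 'label'])
```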

dump(fp)

Dump a StreamingDatasetIndex to the file.

Parameters

fp (file) – The file to write.

dumps()

Dump a StreamingDatasetIndex to raw bytes.

Returns

bytes – The serialized form.

get_partition(world, batch_size=None)

Get the shards and sample range of a given partition of the dataset.

When batch_size is provided, the per-worker sample ranges are constructed so that there is at most one incomplete batch at the end of each epoch. For example, if the DataLoader is reading over:

samples: [0, 1, 2, 3, 4, 5, 6, 7]
num_workers: 3
batch_size: 2
drop_last: True

but batch_size is not passed to the StreamingDataset as a hint ahead of time, then the samples will, by default, be assigned as follows:

worker 0: [0, 1, 2]
worker 1: [3, 4, 5]
worker 2: [6, 7]

and will be read as the following batches (samples 2 and 5 are dropped, because each is left alone in an incomplete batch):

batch 0: [0, 1]
batch 1: [3, 4]
batch 2: [6, 7]

This is suboptimal: the 8 samples divide evenly into batches of 2, so no samples needed to be dropped. When batch_size is provided as a hint, the samples are instead assigned like this:

worker 0: [0, 1, 2, 3]
worker 1: [4, 5]
worker 2: [6, 7]

which will be read as the following batches, with no samples dropped:

batch 0: [0, 1]
batch 1: [4, 5]
batch 2: [6, 7]
batch 3: [2, 3]

Parameters
  • world (World) – Context about workers, devices, and nodes.

  • batch_size (Optional[int]) – Hint of the batch size that each device’s DataLoader will use.

Returns
  • shards (List[int]) – The shards that this partition overlaps.

  • shards_to_download (List[int]) – The shards that this worker should download (a subset of shards).

  • min_id (int) – The lowest sample ID of this partition.

  • max_id (int) – The highest sample ID of this partition.
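The two assignments in the example above can be reproduced with a simplified model. This is a sketch, not the implementation of get_partition: it assumes a single device, gives each worker one contiguous range of sample IDs, and models a DataLoader that batches within each worker, applies drop_last per worker, and interleaves worker batches round-robin. partition and read_batches are hypothetical helpers.

```python
from typing import List, Optional


def partition(num_samples: int, num_workers: int,
              batch_size: Optional[int] = None) -> List[range]:
    """Split sample IDs [0, num_samples) into one contiguous range per worker."""
    if batch_size is None:
        # Without a hint, the unit of work is a single sample.
        unit, num_units = 1, num_samples
    else:
        # With a hint, the unit of work is a whole batch; only the last
        # unit (the dataset's tail) can be smaller than batch_size.
        unit = batch_size
        num_units = (num_samples + batch_size - 1) // batch_size
    base, extra = divmod(num_units, num_workers)
    ranges, first_unit = [], 0
    for worker in range(num_workers):
        # The first `extra` workers each take one additional unit.
        count = base + (1 if worker < extra else 0)
        begin = min(first_unit * unit, num_samples)
        end = min((first_unit + count) * unit, num_samples)
        ranges.append(range(begin, end))
        first_unit += count
    return ranges


def read_batches(ranges: List[range], batch_size: int,
                 drop_last: bool = True) -> List[List[int]]:
    """Simulate per-worker batching with round-robin interleaving."""
    per_worker = []
    for r in ranges:
        ids = list(r)
        batches = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]
        if drop_last and batches and len(batches[-1]) < batch_size:
            batches.pop()  # each worker drops its own incomplete tail batch
        per_worker.append(batches)
    out = []
    rounds = max((len(b) for b in per_worker), default=0)
    for i in range(rounds):
        for batches in per_worker:
            if i < len(batches):
                out.append(batches[i])
    return out


print(read_batches(partition(8, 3), batch_size=2))
# [[0, 1], [3, 4], [6, 7]]
print(read_batches(partition(8, 3, batch_size=2), batch_size=2))
# [[0, 1], [4, 5], [6, 7], [2, 3]]
```

The design choice that matters is the unit of work: splitting by whole batches instead of by samples means that, in this model, only the final batch of the dataset can be incomplete.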

classmethod load(fp)

Load a StreamingDatasetIndex from a file handle.

Parameters

fp (file) – The file to read.

Returns

cls – The loaded object.

classmethod loads(data)

Load a StreamingDatasetIndex from raw bytes.

Parameters

data (bytes) – The serialized form.

Returns

cls – The loaded object.
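The dump/dumps and load/loads pairs are inverses: whatever dumps() produces, loads() turns back into an equivalent index. The sketch below shows one way the index arrays and field names could be packed into a flat byte string and read back. The layout (an int64 header of counts, the three int64 arrays, then length-prefixed UTF-8 field names) is an assumption for illustration and is not necessarily the byte format StreamingDatasetIndex writes; dumps_index and loads_index are hypothetical names.

```python
import io
from typing import List, Sequence, Tuple

import numpy as np

INT = np.dtype('<i8')  # assumed: little-endian int64 for every integer


def dumps_index(samples_per_shard: Sequence[int], bytes_per_shard: Sequence[int],
                bytes_per_sample: Sequence[int], fields: List[str]) -> bytes:
    """Serialize index arrays with a hypothetical layout."""
    encoded = [f.encode('utf-8') for f in fields]
    counts = [len(samples_per_shard), len(bytes_per_sample), len(fields)]
    parts = [
        np.array(counts, dtype=INT).tobytes(),          # header
        np.asarray(samples_per_shard, dtype=INT).tobytes(),
        np.asarray(bytes_per_shard, dtype=INT).tobytes(),
        np.asarray(bytes_per_sample, dtype=INT).tobytes(),
        np.array([len(e) for e in encoded], dtype=INT).tobytes(),
        b''.join(encoded),                               # field names
    ]
    return b''.join(parts)


def loads_index(data: bytes) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[str]]:
    """Inverse of dumps_index."""
    fp = io.BytesIO(data)

    def read(count: int) -> np.ndarray:
        return np.frombuffer(fp.read(count * INT.itemsize), dtype=INT)

    num_shards, num_samples, num_fields = (int(x) for x in read(3))
    samples_per_shard = read(num_shards)
    bytes_per_shard = read(num_shards)
    bytes_per_sample = read(num_samples)
    field_sizes = read(num_fields)
    fields = [fp.read(int(n)).decode('utf-8') for n in field_sizes]
    return samples_per_shard, bytes_per_shard, bytes_per_sample, fields


blob = dumps_index([3, 2], [60, 12], [10, 20, 30, 5, 7], ['text', 'label'])
print(loads_index(blob)[3])  # ['text', 'label']
```

A file-handle variant of this sketch would simply write the blob to, or read it from, a binary file.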

composer.datasets.streaming.format.bytes_to_sample_dict(data, keys)

Load a sample dict from bytes and field names.

Parameters
  • data (bytes) – The encoded sample data.

  • keys (List[str]) – The field names, in the same order as the keys passed to sample_dict_to_bytes().

Returns

Dict[str, bytes] – The decoded sample dict.
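The requirement that keys match the encoding order suggests that only field sizes and values, not field names, are stored per sample. A minimal decoder under an assumed layout: one little-endian int32 size per field, followed by the field values concatenated in key order. This layout and the name decode_sample are assumptions for illustration, not the documented byte format of bytes_to_sample_dict.

```python
from typing import Dict, List

import numpy as np

SIZE_DTYPE = np.dtype('<i4')  # assumed: little-endian int32 size per field


def decode_sample(data: bytes, keys: List[str]) -> Dict[str, bytes]:
    """Split an encoded sample back into a dict, one entry per key."""
    header_size = len(keys) * SIZE_DTYPE.itemsize
    sizes = np.frombuffer(data[:header_size], dtype=SIZE_DTYPE)
    sample, offset = {}, header_size
    for key, size in zip(keys, sizes.tolist()):
        # Values follow the size header in the same order as `keys`.
        sample[key] = data[offset:offset + size]
        offset += size
    if offset != len(data):
        raise ValueError('field sizes do not account for every byte')
    return sample


encoded = b'\x03\x00\x00\x00\x01\x00\x00\x00' + b'catA'
print(decode_sample(encoded, ['text', 'label']))  # {'text': b'cat', 'label': b'A'}
```

Passing the keys in the wrong order would silently attach each value to the wrong field name, which is why the order must match the encoder's.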

composer.datasets.streaming.format.get_index_basename()

Get the basename for a streaming dataset index.

Returns

str – Basename of file.

composer.datasets.streaming.format.get_shard_basename(shard)

Get the basename for a streaming dataset shard.

Parameters

shard (int) – Shard index.

Returns

str – Basename of file.
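Returning basenames rather than full paths lets the same names be joined onto a local directory or a remote prefix. The sketch below illustrates a common convention for shard names; the file names and the '.bin' extension are hypothetical placeholders, not the values these functions return.

```python
def example_index_basename() -> str:
    # Hypothetical placeholder; not necessarily the name composer uses.
    return 'index.bin'


def example_shard_basename(shard: int) -> str:
    # Zero-padding to six digits keeps lexicographic order equal to numeric
    # order (up to 999999), so a sorted listing enumerates shards in sequence.
    return f'{shard:06d}.bin'


root = 'my_dataset'  # hypothetical local directory or remote prefix
print(f'{root}/{example_index_basename()}')    # my_dataset/index.bin
print(f'{root}/{example_shard_basename(12)}')  # my_dataset/000012.bin
```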

composer.datasets.streaming.format.read_array(fp, count, dtype)

Read count items of type dtype from the file handle, advancing its position.

Parameters
  • fp (BufferedIOBase) – File handle.

  • count (int) – Number of items to read.

  • dtype (type) – Item datatype.

Returns

np.ndarray – The read array.
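Reading a fixed-size array from a binary stream amounts to reading count times the item size in bytes and reinterpreting them. A sketch of that idea with NumPy; read_array_sketch is a hypothetical re-implementation for illustration, and the EOF check is an added assumption that the documented function may not perform.

```python
import io

import numpy as np


def read_array_sketch(fp: io.BufferedIOBase, count: int, dtype) -> np.ndarray:
    """Read `count` items of `dtype` from a binary handle, advancing it."""
    nbytes = count * np.dtype(dtype).itemsize
    data = fp.read(nbytes)  # moves the handle forward by len(data) bytes
    if len(data) != nbytes:
        raise EOFError(f'expected {nbytes} bytes, got {len(data)}')
    # frombuffer returns a read-only view; call .copy() to get a writable array.
    return np.frombuffer(data, dtype=dtype, count=count)


fp = io.BytesIO(np.arange(5, dtype=np.int64).tobytes())
print(read_array_sketch(fp, 3, np.int64))  # [0 1 2]
print(fp.tell())                           # 24
```

Because the handle advances, consecutive calls read consecutive arrays, which is how a header followed by several arrays can be parsed in sequence.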

composer.datasets.streaming.format.sample_dict_to_bytes(obj, keys)

Dump a sample dict to bytes, given field names.

Parameters
  • obj (Dict[str, bytes]) – The sample dict to encode.

  • keys (List[str]) – The field names, in the order in which they are encoded.

Returns

bytes – The encoded sample bytes.
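Since the field names live in the index, an encoded sample only needs enough information to split its bytes back into fields. A minimal encoder under an assumed layout: one little-endian int32 size per field, followed by the values concatenated in key order. The layout and the name encode_sample are assumptions for illustration, not the documented output of sample_dict_to_bytes.

```python
from typing import Dict, List

import numpy as np

SIZE_DTYPE = np.dtype('<i4')  # assumed: little-endian int32 size per field


def encode_sample(obj: Dict[str, bytes], keys: List[str]) -> bytes:
    """Encode the fields named in `keys`, in that order (hypothetical layout)."""
    values = [obj[key] for key in keys]  # KeyError if a field is missing
    sizes = np.array([len(v) for v in values], dtype=SIZE_DTYPE)
    # Size header first, then the raw values back to back.
    return sizes.tobytes() + b''.join(values)


# The dict's own key order is irrelevant; `keys` alone fixes the layout.
print(encode_sample({'label': b'A', 'text': b'cat'}, ['text', 'label']))
# b'\x03\x00\x00\x00\x01\x00\x00\x00catA'
```

In this sketch, keys present in obj but absent from keys are silently ignored, and the encoded size is 4 bytes per field plus the total length of the values.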