Tip

This tutorial is available as a Jupyter notebook.

🪣 Training without Local Storage

In practice, a lot of deep learning training happens in environments where servers lack persistent local storage. Thankfully, Composer is designed to natively support training in such contexts.

Without local persistent storage, checkpoints and datasets will need to be downloaded from the cloud, and all checkpoints, logs, metrics, and other artifacts will need to be backed up directly to the cloud.

Composer can automatically load checkpoints from cloud storage, convert traditional datasets into a format that can be streamed in when training, and asynchronously back up checkpoints and other artifacts without blocking the training loop. This tutorial will illustrate how.

Tutorial Goals and Concepts Covered

The goal of this tutorial is to give you an in-depth look at a training workflow that does not require persistent local storage. To that end, we'll walk through how to convert a local training workflow for an MNIST classifier into one that does not require persistent disks.

In reality, small models that converge quickly on small datasets (like MNIST) would likely not require streaming datasets, since it's usually fast enough to download all data at the start of training. However, this is a simple example that highlights the steps involved in using these features in Composer.

We'll cover:

  1. Prerequisites

  2. The local training workflow

  3. Storing and loading checkpoints and logs with the cloud

  4. Switching to streaming datasets

  5. Putting it all together

Let's get started!

## 1. Prerequisites

This tutorial requires access to an AWS S3 Bucket. If AWS credentials are not already available in your environment, then you will need to obtain an AWS_ACCESS_KEY_ID and an AWS_SECRET_ACCESS_KEY that have permission to upload to and download from an S3 Bucket.
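
Before going further, it can help to check whether the two variables named above are visible to the notebook process. The following is a minimal sketch; `missing_aws_env_vars` is a hypothetical helper, not part of Composer or boto3. A missing variable is not necessarily an error, since credentials may also come from other sources such as a shared credentials file or an instance role.

```python
import os

# Environment variables named in the paragraph above.
REQUIRED_AWS_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")


def missing_aws_env_vars(environ=None):
    """Return the required AWS variables that are unset or empty."""
    env = os.environ if environ is None else environ
    return [name for name in REQUIRED_AWS_VARS if not env.get(name)]


missing = missing_aws_env_vars()
if missing:
    print("Not set in the environment:", ", ".join(missing))
else:
    print("AWS credential variables are set")
```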

Here, we'll define all configuration variables:

[ ]:
s3_bucket_name = 'my-bucket'  # The S3 bucket to use
# Give all objects in the bucket a prefix, allowing the
# bucket to be shared across training runs
bucket_prefix = 'composer-diskless-training-tutorial'

# If necessary, uncomment the following lines to set AWS credentials
# Do NOT include quotes
# %env AWS_ACCESS_KEY_ID ***
# %env AWS_SECRET_ACCESS_KEY ***
# %env AWS_DEFAULT_REGION us-west-2

# Also define local (temporary) folders that will be used for
# staging datasets, checkpoints, and log files before uploading
data_dir = '/tmp/data'
tensorboard_log_dir = '/tmp/tb_logs'
checkpoint_dir = '/tmp/checkpoints'

In addition, if you haven't already, install Composer with streaming support:

[ ]:
%pip install 'mosaicml[streaming,tensorboard]'

# To install from source instead of the last release, comment the command above and uncomment the following one.
# %pip install 'mosaicml[streaming,tensorboard] @ git+https://github.com/mosaicml/composer.git'

## 2. The Local Training Workflow

Let's first define our local training code for MNIST. This code downloads all data before training starts. We include the TensorBoard logger, which we will use to visualize results.

[ ]:
import shutil
import os

import torch.utils.data
from torch.optim import SGD
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from composer import Trainer
from composer.loggers import TensorboardLogger
from composer.utils.reproducibility import seed_all
from composer.models.classify_mnist import mnist_model


# Configure the trainer

# Model and optimizer
def get_model_and_optimizer():
    # Set the seed before creating the model
    # for consistent initialization
    seed_all(42)
    model = mnist_model(num_classes=10)
    optimizer = SGD(model.parameters(), lr=0.01)
    return model, optimizer

model, optimizer = get_model_and_optimizer()

# Datasets
batch_size = 2048
train_dataset = MNIST(
    root=os.path.join(data_dir, 'imagefolder'),
    train=True,
    download=True,
    transform=ToTensor(),
)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
)
eval_dataset = MNIST(
    root=os.path.join(data_dir, 'imagefolder'),
    train=False,
    download=True,
    transform=ToTensor(),
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
)

# TensorBoard Logger (for visualizing results)
def get_tensorboard_logger():
    shutil.rmtree(tensorboard_log_dir, ignore_errors=True)
    return TensorboardLogger(log_dir=tensorboard_log_dir, flush_interval=1)


# Clean up the checkpoint directory (if it already exists)
shutil.rmtree(checkpoint_dir, ignore_errors=True)

# Create the trainer
trainer = Trainer(
    model=model,
    max_duration='2ep',
    # Make training fast: Terminate each epoch after 5 batches
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    # Make evaluation fast: Evaluate on only five batches
    eval_subset_num_batches=5,
    # Flush log files every batch
    loggers=get_tensorboard_logger(),
    save_folder=checkpoint_dir,
    save_interval="1ba",
    run_name='local_training_run',
)
[ ]:
# Train!
trainer.fit()

# Close the trainer
trainer.close()

We trained our model. Now let's verify that we have our checkpoints and can visualize our TensorBoard logs.

[ ]:
print("Checkpoint Files")
!ls -al {checkpoint_dir}

# Visualize TensorBoard Logs
%load_ext tensorboard
import tensorboard.notebook

# NOTE: TensorBoard can take a few moments to appear
# If it does not show up after ~30 seconds, run this cell again
%tensorboard --logdir {tensorboard_log_dir}

## 3. Storing and loading checkpoints and logs with the cloud

Now that we have verified that our local training workflow works as intended, let's configure Composer to back up our checkpoints and TensorBoard TF Event files to the cloud.

First, a brief overview of Composer architecture:

  • Logger: Composer includes a centralized logger, which passes logged data to each LoggerDestination (more on that below). Logged data can be either metrics or artifacts. The logger is similar to Python's built-in logging.getLogger(...) but is designed to log structured metrics and artifacts in addition to just text.

  • LoggerDestination: Where logs are sent is specified via the loggers argument of the Trainer constructor. The centralized Logger (above) passes all metrics and artifacts to each LoggerDestination, which is responsible for handling and storing the data. For example, Composer includes LoggerDestinations for logging to files, TensorBoard, Weights & Biases, CometML, and Object Stores like S3. Not all LoggerDestinations support storing all types of data; for example, the ObjectStoreLogger only supports logging artifacts, whereas the FileLogger only supports logging metrics. Others, such as the WandBLogger, support both.

  • Metrics: A metric is a scalar, such as accuracy, that can be logged. Usually you would want to plot metrics over time (e.g., to see how accuracy improves over batches).

  • Artifacts: An artifact is a file generated throughout training, such as a checkpoint or log file. Each file is a separate artifact.

  • ObjectStore: The abstract ObjectStore class provides an API for uploading and downloading files, such as checkpoints. Composer includes object store implementations for S3, SFTP, and Libcloud. You can also write your own implementation by extending the base class if you are using a custom backend.

We'll use these components together to back up our checkpoints and TensorBoard TF Event files to the cloud. Internally, the CheckpointSaver callback and TensorboardLogger pass all generated files to the Logger as artifacts by calling Logger.file_artifact. The Logger then passes these files to each LoggerDestination (specified via the loggers argument of the Trainer constructor). A LoggerDestination, which can implement the log_file_artifact method, is responsible for uploading the file to the cloud.
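
The fan-out described above can be pictured with a toy model. This is only an illustrative sketch: `ToyLogger` and `ToyDestination` are hypothetical stand-ins, and the real Composer methods take additional arguments that are omitted here. It shows a single call on the logger reaching every registered destination.

```python
import os
import tempfile


class ToyDestination:
    """Stand-in for a LoggerDestination; records what it receives."""

    def __init__(self, name):
        self.name = name
        self.received = []

    def log_file_artifact(self, artifact_name, file_path):
        # A real destination would upload or store the file here.
        self.received.append((artifact_name, os.path.basename(file_path)))


class ToyLogger:
    """Stand-in for the centralized Logger: fans out to every destination."""

    def __init__(self, destinations):
        self.destinations = list(destinations)

    def file_artifact(self, artifact_name, file_path):
        for destination in self.destinations:
            destination.log_file_artifact(artifact_name, file_path)


first, second = ToyDestination("first"), ToyDestination("second")
logger = ToyLogger([first, second])

with tempfile.TemporaryDirectory() as tmp:
    checkpoint = os.path.join(tmp, "ep0-ba1-rank0")
    with open(checkpoint, "wb") as f:
        f.write(b"fake checkpoint bytes")
    # What a checkpoint saver would do after writing a file locally.
    logger.file_artifact("my_run/checkpoints/ep0-ba1-rank0", checkpoint)

print(first.received == second.received)
```

Because the producer of the file only talks to the logger, adding or removing a destination does not require touching the code that writes checkpoints or log files.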

Here, our "cloud" will be an S3 bucket. To upload checkpoints and TensorBoard TF Event files to the bucket, we'll add the ObjectStoreLogger with the S3ObjectStore backend to our list of logger destinations. This class asynchronously uploads artifacts to an object store without blocking the training loop.

[ ]:
from composer.loggers import ObjectStoreLogger
from composer.utils.object_store import S3ObjectStore

# Clean the directories from the previous training run
shutil.rmtree(checkpoint_dir, ignore_errors=True)

model, optimizer = get_model_and_optimizer()

def get_object_store_logger():
    return ObjectStoreLogger(
        object_store_cls=S3ObjectStore,
        # Keyword arguments passed to the S3ObjectStore constructor
        object_store_kwargs={
            'bucket': s3_bucket_name,
            'prefix': bucket_prefix,
        },
        # In Jupyter, we set use_procs to False, since subprocesses do not work
        # well within notebooks. Outside of Jupyter, it is recommended to let
        # use_procs default to True for performance
        use_procs=False,
    )

# Create the trainer
trainer = Trainer(
    model=model,
    max_duration='2ep',
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    eval_subset_num_batches=5,
    loggers=[
        get_object_store_logger(),
        get_tensorboard_logger(),
    ],
    save_folder=checkpoint_dir,
    save_interval="1ba",
    # Because we are uploading checkpoints to the cloud, we set
    # save_num_checkpoints_to_keep to 0 to delete them locally
    # after uploading to save disk space!
    save_num_checkpoints_to_keep=0,
    run_name='cloud_training_run',
)
[ ]:
# Train!
trainer.fit()

# Close the trainer, which will block until all files have been uploaded
trainer.close()

# Remove all local TensorBoard traces, since they were uploaded to the cloud
shutil.rmtree(tensorboard_log_dir, ignore_errors=True)
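
To make the "upload in the background, block on close" behavior concrete, here is a minimal sketch built on a worker thread and a queue. It is not how ObjectStoreLogger is implemented; `BackgroundUploader` is a hypothetical class, and a temporary local directory stands in for the S3 bucket.

```python
import os
import queue
import shutil
import tempfile
import threading


class BackgroundUploader:
    """Copies files to a 'remote' directory on a worker thread."""

    def __init__(self, remote_dir):
        self.remote_dir = remote_dir
        self._queue = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def enqueue(self, local_path, object_name):
        # Returns immediately; the copy happens on the worker thread.
        self._queue.put((local_path, object_name))

    def _run(self):
        while True:
            item = self._queue.get()
            if item is None:  # Sentinel: no more work
                break
            local_path, object_name = item
            destination = os.path.join(self.remote_dir, object_name)
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copyfile(local_path, destination)

    def close(self):
        # Block until every queued upload has finished.
        self._queue.put(None)
        self._worker.join()


with tempfile.TemporaryDirectory() as local_dir, tempfile.TemporaryDirectory() as remote_dir:
    uploader = BackgroundUploader(remote_dir)
    for batch in range(3):
        path = os.path.join(local_dir, f"ba{batch}")
        with open(path, "w") as f:
            f.write(f"checkpoint {batch}")
        uploader.enqueue(path, f"run/checkpoints/ba{batch}")  # Does not block
    uploader.close()  # Blocks, like trainer.close() in the cell above
    uploaded = sorted(os.listdir(os.path.join(remote_dir, "run", "checkpoints")))

print(uploaded)
```

The training loop only pays the cost of putting an item on the queue; the wait is deferred to `close()`, which is why local files should not be removed until it returns.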

Great! We trained our model again, and this time, uploaded the checkpoints and TensorBoard logs to our S3 bucket. Let's verify that our files exist in the bucket.

[ ]:
import boto3

s3 = boto3.client('s3')

def print_objects_in_bucket(bucket: str, prefix: str):
    response = s3.list_objects(Bucket=bucket, Prefix=prefix)
    # 'Contents' is missing from the response when no objects match
    keys = [obj['Key'] for obj in response.get('Contents', [])]
    keys.sort()
    for k in keys:
        print(k)

print_objects_in_bucket(s3_bucket_name, bucket_prefix)

Let's also take a look at visualizing TensorBoard from our S3 Bucket:

[ ]:
import tensorboard.notebook

%tensorboard --logdir s3://{s3_bucket_name}/{bucket_prefix}/tensorboard_logs

We can also resume training from a cloud checkpoint, without having to first download it.

To do so, we'll need to set the load_object_store argument of the Trainer constructor to our cloud_logger, and the load_path to the object name of our checkpoint file within the bucket.

[ ]:
# Define the cloud logger
cloud_logger = get_object_store_logger()
tensorboard_logger = get_tensorboard_logger()

model, optimizer = get_model_and_optimizer()

# Create the trainer
trainer = Trainer(
    model=model,
    max_duration='4ep',
    # Load the latest checkpoint
    load_path='cloud_training_run/checkpoints/latest-rank0',
    # Load from the cloud logger
    load_object_store=cloud_logger,
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    eval_subset_num_batches=5,
    loggers=[
        cloud_logger,
        tensorboard_logger,
    ],
    save_folder=checkpoint_dir,
    save_interval="1ba",
    save_num_checkpoints_to_keep=0,
    run_name='cloud_training_run',
)
[ ]:
# Train!
trainer.fit()

# Close the trainer, which will block until all files have been uploaded
trainer.close()

# Remove all local TensorBoard traces, since they were uploaded to the cloud
shutil.rmtree(tensorboard_log_dir, ignore_errors=True)

Let's take another look at TensorBoard, which should show another two epochs (four total):

[ ]:
import tensorboard.notebook

%tensorboard --logdir s3://{s3_bucket_name}/{bucket_prefix}/tensorboard_logs

As an alternative to setting the load_path and load_object_store, we could instead set autoresume=True. See our tutorial on autoresumption for more information about how this feature works.

## 4. Switching to Streaming Datasets

So far, we covered how to store and load checkpoints and other artifacts with cloud storage. But we relied on all of our training data being available at the start. While it works to download all data to a local folder before training, this technique isn't scalable for large datasets. Training wouldn't begin until all data is downloaded, causing accelerators to idle.

To overcome this bottleneck, we can switch to streaming datasets. With streaming datasets, the dataset must be split into multiple shard files, each of which contains a subset of the samples. Shards are downloaded asynchronously (in separate subprocesses) while we train. Instead of waiting until all data is downloaded, we can begin training as soon as the first shard is downloaded. This feature lets training begin almost immediately!
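
The idea can be sketched with the standard library alone. This toy version is an assumption-laden illustration, not Composer's streaming implementation: it uses a background thread rather than subprocesses, a sleep instead of a real download, and the hypothetical helpers `make_shards` and `stream_samples`.

```python
import queue
import threading
import time


def make_shards(samples, shard_size):
    """Split samples into consecutive shards of at most shard_size items."""
    return [samples[i:i + shard_size] for i in range(0, len(samples), shard_size)]


def stream_samples(shards, download_delay=0.01):
    """Yield samples as shards arrive from a background 'download' thread."""
    ready = queue.Queue()

    def download_all():
        for shard in shards:
            time.sleep(download_delay)  # Pretend to fetch the shard
            ready.put(shard)
        ready.put(None)  # Sentinel: all shards delivered

    threading.Thread(target=download_all, daemon=True).start()
    while True:
        shard = ready.get()
        if shard is None:
            return
        yield from shard


shards = make_shards(list(range(10)), shard_size=4)
stream = stream_samples(shards)
first_sample = next(stream)  # Available after one shard, not all three
remaining = list(stream)
print(len(shards), first_sample, remaining)
```

The consumer waits only for the first shard before it receives a sample; later shards keep arriving while earlier samples are being processed.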

For an in-depth walk-through on how to use streaming datasets, see our tutorial on training FaceSynthetics with a Streaming Dataloader. Here, we'll encode the MNIST dataset in the Composer streaming format and upload it to our S3 bucket.

[ ]:
# Write the streaming dataset
import os
import torch
import io
from typing import Dict, cast, Iterable
import struct
from tqdm.auto import tqdm

from composer.datasets.streaming import StreamingDatasetWriter

train_out_folder = os.path.join(data_dir, 'streaming', 'train')
eval_out_folder = os.path.join(data_dir, 'streaming', 'eval')

train_remote = S3ObjectStore(
    bucket=s3_bucket_name,
    prefix=bucket_prefix + '/train',
)

eval_remote = S3ObjectStore(
    bucket=s3_bucket_name,
    prefix=bucket_prefix + '/eval',
)


shutil.rmtree(train_out_folder, ignore_errors=True)
shutil.rmtree(eval_out_folder, ignore_errors=True)
os.makedirs(train_out_folder, exist_ok=True)
os.makedirs(eval_out_folder, exist_ok=True)

fields = ['i', 'x', 'y']

def encode_sample(i: int, x: torch.Tensor, y: int) -> Dict[str, bytes]:
    """Encode a (x,y) sample into a dictionary."""
    x_buffer = io.BytesIO()
    torch.save(x, x_buffer)
    x_buffer.seek(0)
    # See https://docs.python.org/3/library/struct.html#format-characters for
    # struct format characters
    return {
        # The (optional) sample index, encoded as a uint64 (Q format code)
        'i': struct.pack('Q', i),
        # The sample input, in bytes
        'x': x_buffer.read(),
        # The class index, encoded as a uint64 (Q format code)
        'y': struct.pack('Q', y),
    }

print("Writing training dataset")
with StreamingDatasetWriter(train_out_folder, fields, remote=train_remote) as out:
    for i, (x, y) in tqdm(
        enumerate(cast(Iterable, train_dataset)),
        total=len(train_dataset),
    ):
        out.write_sample(encode_sample(i, x, y))
print(f"Uploaded training dataset to s3://{s3_bucket_name}/{bucket_prefix}/train")

print("Writing evaluation dataset")
with StreamingDatasetWriter(eval_out_folder, fields, remote=eval_remote) as out:
    for i, (x, y) in tqdm(
        enumerate(cast(Iterable, eval_dataset)),
        total=len(eval_dataset),
    ):
        out.write_sample(encode_sample(i, x, y))
print(f"Uploaded evaluation dataset to s3://{s3_bucket_name}/{bucket_prefix}/eval")

Let's take a look at the shard files:

[ ]:
print("Training shard files")
print_objects_in_bucket(s3_bucket_name, f'{bucket_prefix}/train')

print("Evaluation shard files")
print_objects_in_bucket(s3_bucket_name, f'{bucket_prefix}/eval')

Now, let's define our MNISTStreamingDataset subclass so we can decode the samples from bytes into an (x, y) tuple, similar to a Torchvision-style dataset:

[ ]:
from composer.datasets.streaming import StreamingDataset
from typing import Tuple

# Define the decoders for the dataset
def decode_x(x: bytes):
    x_buffer = io.BytesIO(x)
    x = torch.load(x_buffer)
    return x

def decode_y(y: bytes) -> int:
    y, = struct.unpack('Q', y)
    assert isinstance(y, int)
    return y

decoders = {
    'x': decode_x,
    'y': decode_y,
}

# Create a custom subclass of StreamingDataset to automatically unpack
# samples into tuples
class MNISTStreamingDataset(StreamingDataset):
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # Overriding __getitem__ to unpack the dictionary
        # from encode_sample into an (x, y) tuple
        sample_dict = super().__getitem__(idx)
        return sample_dict['x'], sample_dict['y']
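
As a quick sanity check on the integer fields, the round trip through `struct` can be tried in isolation. This sketch needs no tensors. The `'<Q'` variant is a suggestion rather than what the tutorial code uses: it pins the byte order, which only matters if the machine that writes the dataset and the machine that reads it differ in endianness.

```python
import struct

label = 7
packed = struct.pack('Q', label)  # Native byte order, as in encode_sample
(unpacked,) = struct.unpack('Q', packed)

# Explicit little-endian variant: same size, byte order fixed across machines
packed_le = struct.pack('<Q', label)
(unpacked_le,) = struct.unpack('<Q', packed_le)

print(len(packed), unpacked, packed_le.hex(), unpacked_le)
```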

Next, let's train using our streaming version of the MNIST dataset!

[ ]:
import torch
import torch.utils.data

model, optimizer = get_model_and_optimizer()

# Redefine the dataset and dataloader to use the streaming format
train_dataset = MNISTStreamingDataset(
    remote=f's3://{s3_bucket_name}/{bucket_prefix}/train',
    local=os.path.join(data_dir, 'streaming_cache', 'train'),
    shuffle=True,
    decoders=decoders,
    batch_size=batch_size,
)

eval_dataset = MNISTStreamingDataset(
    remote=f's3://{s3_bucket_name}/{bucket_prefix}/eval',
    local=os.path.join(data_dir, 'streaming_cache', 'eval'),
    shuffle=True,
    decoders=decoders,
    batch_size=batch_size,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
)

# Create the trainer
trainer = Trainer(
    model=model,
    max_duration='2ep',
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    eval_subset_num_batches=5,
    loggers=[
        get_object_store_logger(),
        get_tensorboard_logger(),
    ],
    save_folder=checkpoint_dir,
    save_interval="1ba",
    save_num_checkpoints_to_keep=0,
    run_name='streaming_training_run',
)
[ ]:
# Train!
trainer.fit()

# Close the trainer, which will block until all files have been uploaded
trainer.close()

Let's take another look at our TensorBoard plots.

Note: the loss and accuracy will slightly differ from the local dataset run because the streaming dataset shuffling is nondeterministic; it depends on the order in which shards are downloaded.

[ ]:
import tensorboard.notebook

%tensorboard --logdir s3://{s3_bucket_name}/{bucket_prefix}/tensorboard_logs
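
The note above about nondeterministic shuffling can be illustrated with a toy model. This is not Composer's actual shuffling algorithm; `shuffled_stream` is a hypothetical function that shuffles within each shard and visits shards in the order they arrive, which is enough to show why two runs with the same seed can still see samples in a different order.

```python
import random


def shuffled_stream(shard_arrival_order, shards, seed=0):
    """Shuffle within each shard, visiting shards in arrival order."""
    rng = random.Random(seed)
    out = []
    for shard_id in shard_arrival_order:
        samples = list(shards[shard_id])
        rng.shuffle(samples)
        out.extend(samples)
    return out


shards = {0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8]}
run_a = shuffled_stream([0, 1, 2], shards)  # Shard 0 arrived first
run_b = shuffled_stream([2, 0, 1], shards)  # Shard 2 arrived first
print(run_a[:3], run_b[:3])
```

Both runs cover the same nine samples, but the first batch is drawn from a different shard in each, so the early loss and accuracy values need not match.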

Congratulations! You trained your first model without relying on a persistent local disk!

## 5. Putting it all together

In this tutorial, we walked through storing and loading checkpoints with the cloud, converting existing datasets into a streaming format, and training a model using both of these features.

Below is a complete example, showing everything we did in one cell. Feel free to use this as a reference for your own training workflows:

[ ]:
import io
import os
import struct
import shutil
import time
from typing import Dict, cast, Iterable, Tuple

import torch
from torch.optim import SGD
import torch.utils.data
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm.auto import tqdm

from composer import Trainer
from composer.models import mnist_model
from composer.utils.reproducibility import seed_all
from composer.utils.object_store import S3ObjectStore
from composer.loggers import ObjectStoreLogger, TensorboardLogger
from composer.datasets.streaming import StreamingDatasetWriter, StreamingDataset


###############
# Configuration
###############

# # The S3 bucket to use
# s3_bucket_name = 'my-bucket'

# # Prefix all objects in the bucket with this key,
# # allowing the bucket to be shared across training runs
# bucket_prefix = 'composer-diskless-training-tutorial'

# # If necessary, uncomment the following lines to set AWS credentials
# # Do NOT include quotes
# %env AWS_ACCESS_KEY_ID ***
# %env AWS_SECRET_ACCESS_KEY ***
# %env AWS_DEFAULT_REGION us-west-2

# # Also define local (temporary) folders that will be used for
# # staging datasets, checkpoints, and log files before uploading
# data_dir = '/tmp/data'
# tensorboard_log_dir = '/tmp/tb_logs'
# checkpoint_dir = '/tmp/checkpoints'

############################
# Writing Streaming Datasets
############################

print("Writing the streaming datasets")

batch_size = 2048

local_train_dataset = MNIST(
    root=os.path.join(data_dir, 'imagefolder'),
    train=True,
    download=True,
    transform=ToTensor(),
)
local_eval_dataset = MNIST(
    root=os.path.join(data_dir, 'imagefolder'),
    train=False,
    download=True,
    transform=ToTensor(),
)

train_out_folder = os.path.join(data_dir, 'streaming', 'train')
eval_out_folder = os.path.join(data_dir, 'streaming', 'eval')

train_remote = S3ObjectStore(
    bucket=s3_bucket_name,
    prefix=bucket_prefix + '/train',
)

eval_remote = S3ObjectStore(
    bucket=s3_bucket_name,
    prefix=bucket_prefix + '/eval',
)

shutil.rmtree(train_out_folder, ignore_errors=True)
shutil.rmtree(eval_out_folder, ignore_errors=True)
os.makedirs(train_out_folder, exist_ok=True)
os.makedirs(eval_out_folder, exist_ok=True)

fields = ['i', 'x', 'y']

def encode_sample(i: int, x: torch.Tensor, y: int) -> Dict[str, bytes]:
    """Encode a (x,y) sample into a dictionary."""
    x_buffer = io.BytesIO()
    torch.save(x, x_buffer)
    x_buffer.seek(0)
    # See https://docs.python.org/3/library/struct.html#format-characters for
    # struct format characters
    return {
        # The (optional) sample index, encoded as a uint64 (Q format code)
        'i': struct.pack('Q', i),
        # The sample input, in bytes
        'x': x_buffer.read(),
        # The class index, encoded as a uint64 (Q format code)
        'y': struct.pack('Q', y),
    }

print("Writing training dataset")
with StreamingDatasetWriter(train_out_folder, fields, remote=train_remote) as out:
    for i, (x, y) in tqdm(
        enumerate(cast(Iterable, local_train_dataset)),
        total=len(local_train_dataset),
    ):
        out.write_sample(encode_sample(i, x, y))
print(f"Uploaded training dataset to s3://{s3_bucket_name}/{bucket_prefix}/train")

print("Writing evaluation dataset")
with StreamingDatasetWriter(eval_out_folder, fields, remote=eval_remote) as out:
    for i, (x, y) in tqdm(
        enumerate(cast(Iterable, local_eval_dataset)),
        total=len(local_eval_dataset),
    ):
        out.write_sample(encode_sample(i, x, y))
print(f"Uploaded evaluation dataset to s3://{s3_bucket_name}/{bucket_prefix}/eval")


#################################
# Loading from Streaming Datasets
#################################

def decode_x(x: bytes):
    x_buffer = io.BytesIO(x)
    x = torch.load(x_buffer)
    return x

def decode_y(y: bytes) -> int:
    y, = struct.unpack('Q', y)
    assert isinstance(y, int)
    return y

decoders = {
    'x': decode_x,
    'y': decode_y,
}

class MNISTStreamingDataset(StreamingDataset):
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        # Overriding __getitem__ to unpack the dictionary from encode_sample
        # into an (x, y) tuple
        sample_dict = super().__getitem__(idx)
        return sample_dict['x'], sample_dict['y']


#######################
# Trainer Configuration
#######################

run_name = f'{int(time.time())}-final-training-run'

seed_all(42)  # Set the seed before creating the model
model = mnist_model(num_classes=10)
optimizer = SGD(model.parameters(), lr=0.01)

shutil.rmtree(tensorboard_log_dir, ignore_errors=True)
shutil.rmtree(checkpoint_dir, ignore_errors=True)


train_dataset = MNISTStreamingDataset(
    remote=f's3://{s3_bucket_name}/{bucket_prefix}/train',
    local=os.path.join(data_dir, 'streaming_cache', 'train'),
    shuffle=True,
    decoders=decoders,
    batch_size=batch_size,
)

eval_dataset = MNISTStreamingDataset(
    remote=f's3://{s3_bucket_name}/{bucket_prefix}/eval',
    local=os.path.join(data_dir, 'streaming_cache', 'eval'),
    shuffle=True,
    decoders=decoders,
    batch_size=batch_size,
)

train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
)

cloud_logger = ObjectStoreLogger(
    object_store_cls=S3ObjectStore,
    # Keyword arguments passed to the S3ObjectStore constructor
    object_store_kwargs={
        'bucket': s3_bucket_name,
        'prefix': bucket_prefix,
    },
    use_procs=False,
)

trainer = Trainer(
    model=model,
    max_duration='2ep',
    # Make training fast: Terminate each epoch after 5 batches
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    # Make evaluation fast: Evaluate only on five batches
    eval_subset_num_batches=5,
    loggers=[
        cloud_logger,
        TensorboardLogger(log_dir=tensorboard_log_dir, flush_interval=1),
    ],
    save_folder=checkpoint_dir,
    save_interval="1ba",
    # Because we are uploading checkpoints to the cloud,
    # delete them locally to save disk space!
    save_num_checkpoints_to_keep=0,
    run_name=run_name,
)

########
# Train!
########

print("Training from the beginning")
trainer.fit()
trainer.close()


#################################
# Train again (from a checkpoint)
#################################

trainer = Trainer(
    model=model,
    max_duration='4ep',
    load_path=f'{run_name}/checkpoints/latest-rank0',
    load_object_store=cloud_logger,
    train_subset_num_batches=5,
    optimizers=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    eval_subset_num_batches=5,
    loggers=[
        cloud_logger,
        TensorboardLogger(log_dir=tensorboard_log_dir, flush_interval=1),
    ],
    save_folder=checkpoint_dir,
    save_interval="1ba",
    save_num_checkpoints_to_keep=0,
    run_name=run_name,
)

print("Training from a checkpoint")
trainer.fit()
trainer.close()
[ ]:
###################
# Visualize Results
###################

import tensorboard.notebook
%load_ext tensorboard

%tensorboard --logdir s3://{s3_bucket_name}/{bucket_prefix}/tensorboard_logs

What Next?

Wow, that was a lot! But now you're an expert on using Composer without persistent local storage (or, you're well on your way to becoming one).

If you haven't already, please continue to explore our tutorials! Here are a couple suggestions:

Come get involved with MosaicML!

We'd love for you to get involved with the MosaicML community in any of these ways:

Star Composer on GitHub

Help make others aware of our work by starring Composer on GitHub.

Join the MosaicML Slack

Head on over to the MosaicML Slack to join other ML efficiency enthusiasts. Come for the paper discussions, stay for the memes!

Contribute to Composer

Is there a bug you noticed or a feature you'd like? File an issue or make a pull request!