✅ Checkpointing

Composer can be configured to automatically save training checkpoints by passing the argument save_folder when creating the Trainer.

To customize the filenames of checkpoints inside save_folder, you can set the save_filename argument. By default, checkpoints will be named like 'ep{epoch}-ba{batch}-rank{rank}' within the save_folder.

In addition, the trainer creates a symlink called 'latest-rank{rank}', which points to the latest saved checkpoint file. You can customize this symlink name by setting the save_latest_filename argument.

The save_folder, save_filename, and save_latest_filename arguments are Python format strings, so you can customize the folder structure to include information such as the rank of the Python process or the current training progress. See the CheckpointSaver documentation for the full list of available format variables.

For example:

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="2ep",
    save_folder="./path/to/checkpoints",
    save_filename="ep{epoch}",
    save_latest_filename="latest",
    save_overwrite=True,
)

trainer.fit()
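To check which placeholders a filename template uses and what a filled-in name looks like, the sketch below relies only on Python's standard string formatting. The helper name template_fields and the example values (epoch 3, batch 1500, ranks 0 and 1) are illustrative assumptions; during training, Composer supplies the actual values itself.

```python
from string import Formatter


def template_fields(template):
    """Return the placeholder names used in a format-string template, in order."""
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion) tuples;
    # field_name is None for trailing literal text, so those entries are skipped.
    return [field for _, field, _, _ in Formatter().parse(template) if field]


save_filename = "ep{epoch}-ba{batch}-rank{rank}"
print(template_fields(save_filename))  # ['epoch', 'batch', 'rank']

# Hypothetical values: each rank of a two-process run gets its own file name.
for rank in range(2):
    print(save_filename.format(epoch=3, batch=1500, rank=rank))
# ep3-ba1500-rank0
# ep3-ba1500-rank1
```

With plain str.format, a placeholder that is not supplied raises a KeyError, so listing the fields is a quick way to compare a custom template against the variables documented for CheckpointSaver.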

Save Interval

By default, checkpoints are saved every epoch, but this interval can be configured using the save_interval argument. The save_interval can be an integer (interpreted as a number of epochs), a time string (see the Time Guide for more information), or a function that takes (State, Event) and returns whether a checkpoint should be saved.

For example:

  • save_interval=1 to save every epoch (the default).

  • save_interval="10ep" to save every 10 epochs.

  • save_interval="500ba" to save every 500 batches/steps.

  • save_interval=lambda state, event: state.timestamp.epoch > 50 and event == Event.EPOCH_CHECKPOINT to save every epoch, starting after the 50th epoch.
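The integer and time-string forms above can be modeled with a small pure-Python sketch. parse_interval and should_save are hypothetical helpers written only to show when each interval would trigger a save; they are not Composer's implementation, which parses time strings with its own, more general utilities and also accepts the callable form.

```python
import re


def parse_interval(interval):
    """Split an interval such as 10, "10ep", or "500ba" into (count, unit)."""
    if isinstance(interval, int):
        count, unit = interval, "ep"  # bare integers are interpreted as epochs
    else:
        match = re.fullmatch(r"(\d+)(ep|ba)", interval)
        if match is None:
            raise ValueError(f"unsupported interval: {interval!r}")
        count, unit = int(match.group(1)), match.group(2)
    if count <= 0:
        raise ValueError("interval must be positive")
    return count, unit


def should_save(interval, epochs_done, batches_done, epoch_just_ended):
    """Return True if a checkpoint would be due at this point in training."""
    count, unit = parse_interval(interval)
    if unit == "ep":
        # Epoch intervals are only checked at the end of an epoch.
        return epoch_just_ended and epochs_done > 0 and epochs_done % count == 0
    # Batch intervals are checked after every batch.
    return batches_done > 0 and batches_done % count == 0


print([e for e in range(1, 31) if should_save("10ep", e, 0, True)])  # [10, 20, 30]
print(should_save("500ba", 0, 1000, False))  # True
```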

Putting this together, here's how to save checkpoints:

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="1ep",
    save_filename="ep{epoch}.pt",
    save_folder="./path/to/checkpoints",
    save_overwrite=True,
    save_interval="1ep",  # Save checkpoints every epoch
)
trainer.fit()

The code above trains a model for one epoch and then saves a checkpoint.

Anatomy of a Checkpoint

When run, the code above produces a single checkpoint, which you can inspect as follows:

>>> import torch
>>> trainer.saved_checkpoints
['./path/to/checkpoints/ep1.pt']
>>> latest_checkpoint = trainer.saved_checkpoints[-1]
>>> state_dict = torch.load(latest_checkpoint)
>>> list(state_dict)
['state', 'rng']
>>> list(state_dict['state'].keys())
['model', 'optimizers', 'schedulers', 'algorithms', 'callbacks', 'scaler', 'timestamp', 'rank_zero_seed', 'train_metrics', 'eval_metrics', 'run_name']
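To explore a checkpoint's nested layout beyond the top-level keys, a small recursive helper can list key paths. key_paths is a hypothetical helper, and the dictionary in the example is a stand-in with made-up contents; in practice you would pass it the object returned by torch.load(latest_checkpoint).

```python
def key_paths(obj, prefix="", max_depth=2):
    """Yield slash-separated paths of the keys in a nested dict, up to max_depth levels deep."""
    if not isinstance(obj, dict) or max_depth <= 0:
        return
    for key, value in obj.items():
        path = f"{prefix}/{key}" if prefix else str(key)
        yield path
        yield from key_paths(value, path, max_depth - 1)


# Stand-in for torch.load(...) output; real checkpoints hold tensors and many more entries.
fake_checkpoint = {
    "state": {"model": {"fc.weight": None}, "timestamp": {}},
    "rng": [{}],
}
print(list(key_paths(fake_checkpoint)))
# ['state', 'state/model', 'state/timestamp', 'rng']
```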

Resume training

To resume training from a previous checkpoint, set the load_path argument of the Trainer to the checkpoint filepath. When the Trainer is initialized, the checkpoint state is restored, and Trainer.fit() continues training from where the checkpoint left off.

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="90ep",
    save_overwrite=True,
    load_path="./path/to/checkpoints/ep25.pt",
)
trainer.fit()

The above code will load the checkpoint from epoch 25 and continue training for another 65 epochs (to reach 90 epochs total).
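Composer's own 'latest-rank{rank}' symlink, described earlier, already points at the most recent checkpoint. If you instead want to pick the newest file yourself, the sketch below assumes files named with save_filename="ep{epoch}.pt", as in the examples on this page; latest_checkpoint is a hypothetical helper, not part of Composer.

```python
import re
from pathlib import Path


def latest_checkpoint(folder, pattern=r"ep(\d+)\.pt"):
    """Return the checkpoint path with the highest epoch number in folder, or None if none match."""
    best_epoch, best_path = -1, None
    for path in Path(folder).iterdir():
        match = re.fullmatch(pattern, path.name)
        # Compare epochs numerically: a plain string sort would rank "ep3.pt" above "ep25.pt".
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path
```

The result could then be passed to the Trainer, for example as load_path=str(latest_checkpoint("./path/to/checkpoints")).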

When resuming, the trainer uses whatever model and optimizer objects you pass in and loads the checkpoint state into them. However, an error will be raised if the weights or state in the checkpoint are not compatible with these new objects.

Note

Only the attributes in State.serialized_attributes are serialized and loaded. By default, they are:

  • model: The model under training.

  • optimizers: The optimizers being used to train the model.

  • schedulers: The learning rate schedulers.

  • algorithms: The algorithms used for training.

  • callbacks: The callbacks used for training.

  • scaler: The gradient scaler in use for mixed precision training.

  • timestamp: The timestamp that tracks training loop progress.

  • rank_zero_seed: The seed of the rank zero process.

  • train_metrics: The current training metrics.

  • eval_metrics: The current validation metrics.

  • run_name: The run name for training.

All other trainer arguments (e.g. max_duration or precision) will use either the defaults or what is passed in when reconstructing the trainer.

Saving for Inference

By default, the Trainer stores the entire training state in each checkpoint. If you would like to store only the model weights in a checkpoint, set save_weights_only=True.

from composer.trainer import Trainer

trainer = Trainer(
    ...,
    save_folder="checkpoints",
    save_weights_only=True,
    save_overwrite=True,
)

trainer.fit()

Saving Multiple Checkpoint Types

To save multiple checkpoint types, such as full checkpoints and weights-only checkpoints, the CheckpointSaver can be passed directly into the callbacks argument of the trainer. Each CheckpointSaver can have its own save folder, interval, and other parameters.

When configuring checkpoints via the callbacks, it is not necessary to specify the save_folder or other checkpoint saving parameters directly on the trainer.

from composer.trainer import Trainer
from composer.callbacks import CheckpointSaver

trainer = Trainer(
    ...,
    callbacks=[
        CheckpointSaver(
            folder='full_checkpoints',
            save_interval='5ep',
            overwrite=True,
            num_checkpoints_to_keep=1,  # keep only the most recent full checkpoint
        ),
        CheckpointSaver(
            folder='weights_only_checkpoints',
            weights_only=True,
            overwrite=True,
        ),
    ],
)

trainer.fit()

Fine-tuning

The Trainer loads only the model weights from a checkpoint if load_weights_only=True or if the checkpoint was saved with save_weights_only=True. This is especially useful for fine-tuning, where the rest of the trainer's state no longer applies.

If the fine-tuned model has different parameter names than the model in the checkpoint, set load_strict_model_weights=False to ignore mismatches in parameter names between the serialized model state and the new model object. Parameters with the same name are expected to have the same shape and will have their state restored. Parameters with different names will be ignored.

ft_trainer = Trainer(
    model=finetune_model,
    train_dataloader=finetune_dataloader,
    max_duration="10ep",
    load_path="./path/to/checkpoints/ep50.pt",
    load_weights_only=True,
    load_strict_model_weights=False,
)

ft_trainer.fit()

This example loads only the model weights from the epoch 50 checkpoint and then trains on the fine-tuning dataloader for 10 epochs.
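The name-matching rule described above can be mimicked in plain Python. match_parameters is a hypothetical helper that compares dictionaries of parameter shapes (tuples standing in for tensors); it sketches the rule and is not the code Composer or PyTorch actually runs. The parameter names and shapes in the example are made up.

```python
def match_parameters(checkpoint_shapes, model_shapes):
    """Sketch of non-strict loading over {name: shape} dicts.

    Returns (restored, missing, unexpected) name lists and raises if a shared
    name has a different shape.
    """
    restored = sorted(name for name in model_shapes if name in checkpoint_shapes)
    missing = sorted(name for name in model_shapes if name not in checkpoint_shapes)
    unexpected = sorted(name for name in checkpoint_shapes if name not in model_shapes)
    mismatched = [name for name in restored if checkpoint_shapes[name] != model_shapes[name]]
    if mismatched:
        # Matching names must still have matching shapes, even when loading non-strictly.
        raise ValueError(f"shape mismatch for: {mismatched}")
    return restored, missing, unexpected


pretrained = {"backbone.weight": (64, 3), "head.weight": (10, 64)}
finetune = {"backbone.weight": (64, 3), "new_head.weight": (5, 64)}
print(match_parameters(pretrained, finetune))
# (['backbone.weight'], ['new_head.weight'], ['head.weight'])
```

In this example the backbone would be restored, the new head would keep its fresh initialization, and the old head in the checkpoint would be ignored.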

Loading Weights Externally

The model weights are located at state_dict["state"]["model"] within the stored checkpoint. To load them into a model outside of a Trainer, use torch.load():

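import torch

model = Model(num_channels, num_classes)  # your model class, constructed the same way as for training
# map_location="cpu" lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load("./path/to/checkpoints/ep1.pt", map_location="cpu")
model.load_state_dict(state_dict["state"]["model"])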
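If model.load_state_dict() reports missing and unexpected keys that differ only by a leading prefix (for example "module.", which can appear when the network was stored as an attribute of a wrapper model), the keys can be renamed before loading. Whether your checkpoint has such a prefix is an assumption to verify by printing a few keys; strip_prefix below is a hypothetical helper operating on a plain dict. PyTorch also provides torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, which performs a similar rename in place.

```python
def strip_prefix(state_dict, prefix):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


weights = {"module.fc.weight": "w", "module.fc.bias": "b"}  # stand-in values instead of tensors
print(strip_prefix(weights, "module."))  # {'fc.weight': 'w', 'fc.bias': 'b'}
```

With a real checkpoint, this would be used as model.load_state_dict(strip_prefix(state_dict["state"]["model"], "module.")).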

Uploading Checkpoints to Object Store

Checkpoints can also be saved to and loaded from your object store of choice (e.g. AWS S3 or Google Cloud Storage). Writing checkpoints to an object store is a two-step process. The checkpoints are first written to the local filesystem, and then the RemoteUploaderDownloader logger will upload checkpoints to the specified object store.

Behind the scenes, the RemoteUploaderDownloader uses Apache Libcloud.

The easiest way to upload checkpoints to S3 is to prefix your save_folder with 's3://'. All other checkpoint arguments remain the same. For example, save_filename will be the name of the checkpoint file that gets uploaded to the S3 URI that you specified.

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='s3://my_bucket/checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
)

trainer.fit()

This will train your model, save the checkpoints locally, upload them to the S3 bucket my_bucket, and then delete them from the local disk. On S3, the checkpoints will be stored inside your bucket under checkpoints/; for example, the third epoch's checkpoint will be checkpoints/ep3.pt, with the full URI s3://my_bucket/checkpoints/ep3.pt.
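The mapping from save_folder and save_filename to the final S3 location can be sketched with the standard library's urllib.parse. remote_location is a hypothetical helper that reproduces the example above; it does not upload anything and is not how Composer computes the path internally.

```python
from urllib.parse import urlparse


def remote_location(save_folder, save_filename, **fmt):
    """Split an s3:// save_folder plus a formatted file name into (bucket, object key)."""
    parsed = urlparse(save_folder)
    if parsed.scheme != "s3":
        raise ValueError(f"expected an s3:// URI, got {save_folder!r}")
    prefix = parsed.path.strip("/")  # folder inside the bucket, e.g. "checkpoints"
    filename = save_filename.format(**fmt)
    key = f"{prefix}/{filename}" if prefix else filename
    return parsed.netloc, key


bucket, key = remote_location("s3://my_bucket/checkpoints", "ep{epoch}.pt", epoch=3)
print(bucket, key)  # my_bucket checkpoints/ep3.pt
print(f"s3://{bucket}/{key}")  # s3://my_bucket/checkpoints/ep3.pt
```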

There are a few additional trainer arguments that can be helpful to configure:

  • save_num_checkpoints_to_keep: Set this parameter to remove checkpoints from the local disk after they have been uploaded. For example, setting this parameter to 1 will only keep the latest checkpoint locally; setting it to 0 will remove each checkpoint after it has been uploaded. Checkpoints are never deleted from object stores.

  • save_remote_file_name: To customize how checkpoints are named in the cloud bucket, modify this parameter. By default, they will be named as '{run_name}/checkpoints/ep{epoch}-ba{batch}-rank{rank}'. See the CheckpointSaver documentation for the available format variables.
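The local retention rule for save_num_checkpoints_to_keep can be sketched as a pure function over the list of saved files. checkpoints_to_delete is a hypothetical helper that computes the end state after all checkpoints have been written (Composer applies the policy as each checkpoint is saved and uploaded), and treating a negative value as "keep everything" is an assumption of this sketch. It only concerns local copies; uploaded objects are untouched.

```python
def checkpoints_to_delete(saved, num_to_keep):
    """Return the local checkpoints a keep-the-newest-N policy would remove.

    `saved` lists checkpoint paths in the order they were written (oldest first).
    A negative num_to_keep is treated as "keep everything" (an assumption of this sketch).
    """
    if num_to_keep < 0:
        return []
    return list(saved[: max(len(saved) - num_to_keep, 0)])


saved = ["ep1.pt", "ep2.pt", "ep3.pt"]
print(checkpoints_to_delete(saved, 1))  # ['ep1.pt', 'ep2.pt']
print(checkpoints_to_delete(saved, 0))  # ['ep1.pt', 'ep2.pt', 'ep3.pt']
```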

This is equivalent to creating a RemoteUploaderDownloader object and adding it to loggers. This is a more involved approach, but it is necessary for uploading checkpoints to other cloud object stores, such as GCS.

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://checkpoint-debugging",
    backend_kwargs={
        "provider": "s3",  # The Apache Libcloud provider name
        "container": "checkpoint-debugging",  # The name of the cloud container (i.e. bucket) to use.
        "provider_kwargs": {  # The Apache Libcloud provider driver initialization arguments
            'key': 'provider_key',  # The cloud provider key.
            'secret': '*******',  # The cloud provider secret.
            # Any additional arguments required for the cloud provider.
        },
    },
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
    loggers=[remote_uploader_downloader],
)

trainer.fit()


Saving Checkpoints to Google Cloud Storage (GCS)

To save checkpoints to GCS, make sure to create a RemoteUploaderDownloader instance and pass it to the loggers argument of Trainer.

Pass your HMAC access ID and secret as 'key' and 'secret', respectively, to the RemoteUploaderDownloader constructor, like so:

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://my-gcs-bucket",
    backend_kwargs={
        "provider": "google_storage",  # The Apache Libcloud provider name
        "container": "my-gcs-bucket",  # The name of the cloud container (i.e. bucket) to use.
        "provider_kwargs": {  # The Apache Libcloud provider driver initialization arguments
            'key': 'your-HMAC-access-id',  # The cloud provider key.
            'secret': 'your-HMAC-secret',  # The cloud provider secret.
            # Any additional arguments required for the cloud provider.
        },
    },
)

If you don't want to enter your access ID and secret directly in plaintext, you can instead store them in environment variables and pass the names of those variables, as shown below:

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://my-gcs-bucket",
    backend_kwargs={
        "provider": "google_storage",
        "container": "my-gcs-bucket",
        "key_environ": "MY_HMAC_ACCESS_ID", # Name of env variable for HMAC access id.
        "secret_environ": "MY_HMAC_SECRET", # Name of env variable for HMAC secret.
    },
)
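A missing environment variable may not be noticed until Composer tries to reach the bucket, so you may want to check the variables before constructing the RemoteUploaderDownloader. require_env is a hypothetical helper; the variable names are the ones used in the example above.

```python
import os


def require_env(*names):
    """Return the values of the given environment variables, failing fast if any is unset or empty."""
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"missing environment variables: {', '.join(missing)}")
    return {name: os.environ[name] for name in names}
```

Calling require_env("MY_HMAC_ACCESS_ID", "MY_HMAC_SECRET") at the top of a training script surfaces a misconfigured environment immediately, without printing the secret values.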

Putting it all together:

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://my-gcs-bucket",
    backend_kwargs={
        "provider": "google_storage",
        "container": "my-gcs-bucket",
        "key_environ": "MY_HMAC_ACCESS_ID", # Name of env variable for HMAC access id.
        "secret_environ": "MY_HMAC_SECRET", # Name of env variable for HMAC secret.
    },
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='10ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_overwrite=True,
    save_filename='ep{epoch}.pt',
    save_num_checkpoints_to_keep=0,  # delete all checkpoints locally
    loggers=[remote_uploader_downloader],
)

trainer.fit()

This will save checkpoints every epoch to the gs URI: gs://my-gcs-bucket/checkpoints. Each checkpoint will then be at the gs URI: gs://my-gcs-bucket/checkpoints/ep{epoch}.pt, where {epoch} will be filled in with the epoch of that checkpoint.

Loading Checkpoints from Object Store

Checkpoints saved to an object store can be loaded in the same way as files saved on disk. Pass the LibcloudObjectStore to the trainer's load_object_store argument (you can also pass the full RemoteUploaderDownloader object). The load_path argument should be the path to the checkpoint file within the container/bucket.

from composer.utils import LibcloudObjectStore
from composer.trainer import Trainer

object_store = LibcloudObjectStore(
    provider="s3",  # The Apache Libcloud provider name
    container="checkpoint-debugging",  # The name of the cloud container (i.e. bucket) to use.
    provider_kwargs={  # The Apache Libcloud provider driver initialization arguments
        'key': 'provider_key',  # The cloud provider key.
        'secret': '*******',  # The cloud provider secret.
        # Any additional arguments required for the cloud provider.
    },
)

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="checkpoints/ep1.pt",
    load_object_store=object_store,
)

new_trainer.fit()

A simpler way to load checkpoints from S3 specifically is to use a URI starting with s3://. With an S3 URI, you do not need to specify load_object_store. Note that for other object stores, such as WandB or LibCloud, you must still specify load_object_store.

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="s3://checkpoint-debugging/checkpoints/ep1.pt",
)

new_trainer.fit()

This will load the first epoch's checkpoint from S3 and resume training in the second epoch.

Loading Checkpoints from Google Cloud Storage (GCS)

To load checkpoints from GCS, you need to once again create a RemoteUploaderDownloader instance, but this time make sure to pass the instance to the load_object_store argument of the Trainer. The load_path argument should be the path to the checkpoint file within the container/bucket.

Here is an example for loading the third epoch's checkpoint from a GCS bucket:

from composer.loggers import RemoteUploaderDownloader
from composer.trainer import Trainer

remote_uploader_downloader = RemoteUploaderDownloader(
    bucket_uri="libcloud://my-gcs-bucket",
    backend_kwargs={
        "provider": "google_storage",
        "container": "my-gcs-bucket",
        "key_environ": "MY_HMAC_ACCESS_ID", # Name of env variable for HMAC access id.
        "secret_environ": "MY_HMAC_SECRET", # Name of env variable for HMAC secret.
    },
)

new_trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="10ep",
    load_path="checkpoints/ep3.pt",
    load_object_store=remote_uploader_downloader,
)

new_trainer.fit()

This code will load the third epoch's checkpoint from your GCS bucket and resume the training run on the fourth epoch.

API Reference

  • RemoteUploaderDownloader for saving checkpoints to cloud storage.

  • Trainer for the trainer checkpoint arguments.

  • CheckpointSaver for the CheckpointSaver arguments.

  • LibcloudObjectStore for setting up libcloud-supported object stores.

  • composer.utils.checkpoint for the underlying utilities to manually save and load checkpoints.