✅ Checkpointing

Composer can be configured to save training checkpoints automatically by passing the save_folder argument when creating the Trainer. save_folder can be a relative path, in which case checkpoints are stored in CWD/runs/<timestamp>/<rank>/<save_folder>. Absolute paths are used as-is.
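The path rule above can be sketched as a small function. resolve_save_folder is a hypothetical helper, not part of Composer, and the rank directory name is passed in unchanged because its exact naming is an assumption here.

```python
import os


def resolve_save_folder(save_folder: str, cwd: str, timestamp: str, rank_dir: str) -> str:
    """Hypothetical helper illustrating where checkpoints land (not Composer's code)."""
    if os.path.isabs(save_folder):
        # Absolute paths are used as-is.
        return save_folder
    # Relative paths are nested under CWD/runs/<timestamp>/<rank>/.
    return os.path.join(cwd, "runs", timestamp, rank_dir, save_folder)


print(resolve_save_folder("checkpoints", "/home/me", "2022-01-01-12-00", "rank_0"))
print(resolve_save_folder("/path/to/checkpoints", "/home/me", "2022-01-01-12-00", "rank_0"))
```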

By default, checkpoints are saved every epoch; this can be changed with the save_interval argument. For example, specify save_interval="10ep" to save every 10 epochs, or save_interval="500ba" to save every 500 batches (steps).

from composer import Trainer

trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  max_duration="1ep",
                  save_folder="/path/to/checkpoints",
                  save_interval="1ep")  # Save checkpoints every epoch
trainer.fit()

The above code will train a model for 1 epoch, and then save the checkpoint.
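The interval strings can be read as "<count><unit>", where ep means epochs and ba means batches. The sketch below is a hypothetical illustration of that rule (parse_interval and should_save are not Composer functions); it also treats a bare integer as a number of epochs, as described for save_interval in the API reference.

```python
import re
from typing import Tuple, Union


def parse_interval(interval: Union[str, int]) -> Tuple[int, str]:
    """Split '10ep' into (10, 'ep') or '500ba' into (500, 'ba')."""
    if isinstance(interval, int):
        return interval, "ep"  # a bare integer counts epochs
    match = re.fullmatch(r"(\d+)(ep|ba)", interval)
    if match is None:
        raise ValueError(f"unrecognized interval: {interval!r}")
    return int(match.group(1)), match.group(2)


def should_save(interval: Union[str, int], epochs_done: int, batches_done: int,
                at_epoch_end: bool) -> bool:
    """Decide whether a checkpoint is due after the current step (hypothetical)."""
    count, unit = parse_interval(interval)
    if count <= 0:
        raise ValueError("interval must be positive")
    if unit == "ep":
        # Epoch intervals are only checked when an epoch finishes.
        return at_epoch_end and epochs_done > 0 and epochs_done % count == 0
    # Batch intervals are checked after every batch.
    return batches_done > 0 and batches_done % count == 0
```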

Anatomy of a checkpoint

The above code, when run, will produce the checkpoints below:

import os

os.listdir(trainer.checkpoint_saver.checkpoint_folder)

['ep1.pt']
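Because the file names encode the epoch, the most recent epoch checkpoint in a folder can be found by parsing the names. latest_epoch_checkpoint is a hypothetical helper that assumes only the ep<N>.pt naming shown above; checkpoints saved on a batch interval may be named differently, so anything that does not match is ignored.

```python
import os
import re
from typing import Optional


def latest_epoch_checkpoint(folder: str) -> Optional[str]:
    """Return the absolute path of the ep<N>.pt file with the largest N, or None."""
    best_epoch, best_name = -1, None
    for name in os.listdir(folder):
        match = re.fullmatch(r"ep(\d+)\.pt", name)
        # Compare numerically so that ep10.pt sorts after ep2.pt.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_name = int(match.group(1)), name
    if best_name is None:
        return None
    return os.path.abspath(os.path.join(folder, best_name))
```

The returned path is absolute, which suits the load_path_format argument used when resuming.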

Opening one of those checkpoints, you’ll see:

import torch

state_dict = torch.load(
    os.path.join(trainer.checkpoint_saver.checkpoint_folder, "ep1.pt")
)
print(f"Top level keys: {list(state_dict.keys())}")
print(f"state keys: {list(state_dict['state'].keys())}")

Top level keys: ['rng', 'state']
state keys: ['model', 'timer', 'optimizers', 'schedulers', 'scaler', 'algorithms', 'callbacks', 'rng', 'rank_zero_seed', 'is_model_ddp']

At the top level, we see the current RNG state ("rng") and the serialized trainer.state ("state").

Under the "state" key, we see:

  1. "model": Model weights

  2. "optimizers": Optimizer state

  3. "schedulers": Scheduler state

  4. "algorithms": Any algorithm state

These are the most important keys to be aware of. The remaining keys are required so that training can resume exactly where it left off.
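To check a checkpoint for the keys listed above, the loaded object can be inspected as a plain dictionary. This sketch assumes torch.load returns a dict shaped like the output shown earlier; summarize_checkpoint is a hypothetical helper, and the fake checkpoint stands in for a real file so the example runs without PyTorch.

```python
from typing import Any, Dict, List, Tuple

IMPORTANT_STATE_KEYS = ("model", "optimizers", "schedulers", "algorithms")


def summarize_checkpoint(checkpoint: Dict[str, Any]) -> Tuple[List[str], List[str], List[str]]:
    """Return (top-level keys, 'state' keys, important keys missing from 'state')."""
    state = checkpoint.get("state", {})
    missing = [key for key in IMPORTANT_STATE_KEYS if key not in state]
    return list(checkpoint), list(state), missing


# A stand-in for the dictionary returned by torch.load(...).
fake_checkpoint = {
    "rng": [],
    "state": {"model": {}, "optimizers": {}, "schedulers": {}, "algorithms": []},
}
print(summarize_checkpoint(fake_checkpoint))
```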

Resume training

To resume training from a previous checkpoint, pass the checkpoint file path to the Trainer with the load_path_format argument. This should be an absolute path.

When the Trainer is initialized, all the state information will be restored from the checkpoint and trainer.fit() will continue training from where the checkpoint left off.

trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  eval_dataloader=None,
                  max_duration="90ep",
                  load_path_format="/path/to/checkpoint/ep25.pt")
trainer.fit()

The above code will load the checkpoint from epoch 25, and continue training for another 65 epochs (to reach 90 epochs total).
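Note that max_duration is the total training length, not an additional amount. The arithmetic can be sketched with a hypothetical helper that assumes the ep<N>.pt naming used above and an epoch-based max_duration.

```python
import re


def remaining_epochs(checkpoint_path: str, max_duration: str) -> int:
    """Epochs left to train after resuming from an ep<N>.pt checkpoint."""
    saved = re.search(r"ep(\d+)\.pt$", checkpoint_path)
    total = re.fullmatch(r"(\d+)ep", max_duration)
    if saved is None or total is None:
        raise ValueError("expected an ep<N>.pt path and an '<N>ep' duration")
    # max_duration counts from the start of training, so subtract what is done.
    return max(int(total.group(1)) - int(saved.group(1)), 0)


print(remaining_epochs("/path/to/checkpoint/ep25.pt", "90ep"))  # 65
```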

Model or optimizer objects passed to the Trainer when resuming are respected, even if they differ from the ones used to create the checkpoint. However, an error will be raised if the weights or state in the checkpoint are incompatible with these new objects.
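A simplified picture of that compatibility check compares parameter names and shapes. In this hypothetical sketch, dictionaries of shapes stand in for real tensors, and the strict flag loosely mirrors the load_strict option described in the API reference: a shape mismatch is always an error, while missing or unexpected names are errors only in strict mode. It is not Composer's or PyTorch's actual implementation.

```python
from typing import Dict, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def check_compatible(checkpoint_shapes: Shapes, model_shapes: Shapes, strict: bool = False) -> None:
    """Raise ValueError if the checkpoint weights cannot be loaded into the model."""
    for name, shape in checkpoint_shapes.items():
        # A parameter present in both places must have the same shape.
        if name in model_shapes and model_shapes[name] != shape:
            raise ValueError(
                f"shape mismatch for {name!r}: checkpoint {shape}, model {model_shapes[name]}"
            )
    if strict:
        # In strict mode the two sets of names must match exactly.
        missing = sorted(set(model_shapes) - set(checkpoint_shapes))
        unexpected = sorted(set(checkpoint_shapes) - set(model_shapes))
        if missing or unexpected:
            raise ValueError(f"missing keys: {missing}, unexpected keys: {unexpected}")
```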

Note

Only the following attributes from State will be serialized and loaded:

serialized_attributes = [
    "model",
    "optimizers",
    "schedulers",
    "algorithms",
    "callbacks",
    "scaler",
    "timer",
]

All other trainer arguments (e.g. max_duration or precision) will use the defaults or whatever is passed in when the Trainer is created.

Fine-tuning

If load_weights_only=True, the Trainer loads only the model weights from the checkpoint. This is especially useful for fine-tuning, since the rest of the trainer’s state no longer applies.

ft_trainer = Trainer(model=model,
                     train_dataloader=finetune_dataloader,
                     eval_dataloader=None,
                     max_duration="10ep",
                     load_path_format="/path/to/checkpoint/ep50.pt",
                     load_weights_only=True)
ft_trainer.fit()

This example loads only the model weights from the epoch 50 checkpoint, and then trains on finetune_dataloader for 10 epochs.
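The difference between a full resume and a weights-only load can be pictured as selecting which entries of the checkpoint's "state" dictionary are restored. This is a hypothetical sketch over a plain dictionary (select_restored_state is not a Composer function), not Composer's loading code.

```python
from typing import Any, Dict, List, Tuple


def select_restored_state(checkpoint: Dict[str, Any],
                          load_weights_only: bool) -> Tuple[Dict[str, Any], List[str]]:
    """Return (entries to restore, names of entries that are skipped)."""
    state = checkpoint["state"]
    if not load_weights_only:
        return dict(state), []  # full resume: restore everything that was saved
    # Weights only: keep the model and skip optimizer, scheduler, timer, etc.
    skipped = [key for key in state if key != "model"]
    return {"model": state["model"]}, skipped
```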

Loading weights externally

The model weights are located at state_dict["state"]["model"] within the stored checkpoint. To load them into a model outside of a Trainer, use torch.load():

import torch

model = MyModel()  # your own model class
state_dict = torch.load("/path/to/checkpoint/ep15.pt")
model.load_state_dict(state_dict["state"]["model"])

Uploading to Object Store

Checkpoints can also be saved to and loaded from your object store of choice (e.g. AWS S3 or Google Cloud Storage). Writing checkpoints to an object store is a two-step process: the checkpoints are first written to the local filesystem, and then the RunDirectoryUploader callback uploads them to the object store.

Note

We use libcloud to connect to the remote object stores, so be sure to have the Python package apache-libcloud installed.

For this, the object store provider needs to be configured (through ObjectStoreProviderHparams) with the following arguments:

  • provider: The name of the object store provider, as recognized by libcloud. See the libcloud documentation for the list of available providers.

  • container: The name of the container (i.e. “bucket”) to use.

To prevent accidental leakage of API keys, your secrets must be provided indirectly through environment variables. Set these in your environment and provide the following environment variable names:

  • key_environ: The name of the environment variable where your API key or username is stored. For example, the GCS access key.

  • secret_environ: The name of the environment variable where your secret is stored. For example, the GCS secret that is paired with the above access key for requests.
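The indirection works by keeping only the variable names in your configuration and reading the values from the environment at runtime. read_credentials is a hypothetical helper that illustrates this lookup; it is not a Composer function, and GCS_KEY / GCS_SECRET are placeholder names.

```python
import os
from typing import Optional, Tuple


def read_credentials(key_environ: Optional[str] = None,
                     secret_environ: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
    """Resolve environment variable *names* into the key and secret they hold."""

    def lookup(var_name: Optional[str]) -> Optional[str]:
        if var_name is None:
            return None  # the provider does not need this value
        value = os.environ.get(var_name)
        if value is None:
            raise KeyError(f"environment variable {var_name!r} is not set")
        return value

    return lookup(key_environ), lookup(secret_environ)
```

For example, after `export GCS_KEY=...` and `export GCS_SECRET=...` in your shell, `read_credentials("GCS_KEY", "GCS_SECRET")` returns the two values, while your configuration files only ever contain the names.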

The object store also accepts these common optional arguments:

  • host: The specific hostname for the cloud provider, letting you override the default value provided by libcloud.

  • port: The port for the cloud provider.

  • region: The region to use for the cloud provider.

If your cloud provider requires additional parameters, pass them as a dictionary under the key extra_init_kwargs.

Once you’ve configured your object store as described above, all that’s left is to add the RunDirectoryUploader as a callback.

Let’s put it all together:

import uuid

from composer import Trainer
from composer.callbacks import RunDirectoryUploader
from composer.utils.object_store import ObjectStoreProviderHparams

credentials = {"provider": "GOOGLE_STORAGE",
               "container": "checkpoints-debugging",
               "key_environ": "GCS_KEY",  # name of the env var holding the access key
               "secret_environ": "GCS_SECRET"}  # name of the env var holding the secret
hp = ObjectStoreProviderHparams(**credentials)

# A short random suffix keeps object names from different runs apart.
prefix = f"my-model-{str(uuid.uuid4())[:6]}"
store_uploader = RunDirectoryUploader(hp, object_name_prefix=prefix)

trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  eval_dataloader=None,
                  max_duration="90ep",
                  save_folder="checkpoints",
                  callbacks=[store_uploader])
trainer.fit()

This will train your model, save the checkpoints locally, and upload them to the checkpoints-debugging Google Cloud Storage bucket, using the access key stored in the GCS_KEY environment variable and the secret stored in GCS_SECRET.

Loading from Object Store

Checkpoints saved to an object store can be loaded the same way as files saved on disk. Create an ObjectStoreProvider from ObjectStoreProviderHparams (via initialize_object()) and pass it to the trainer’s load_object_store argument. The load_path_format argument should then be the path to the checkpoint file within the container (bucket).

from composer.utils.object_store import ObjectStoreProviderHparams

hp = ObjectStoreProviderHparams(
    provider="GOOGLE_STORAGE",
    container="checkpoints-debugging",
    key_environ="GCS_KEY",  # env var holding the access key
    secret_environ="GCS_SECRET",  # env var holding the secret
)
# load_object_store expects an ObjectStoreProvider, not the hparams object.
object_store = hp.initialize_object()

From there we can fine-tune with:

new_trainer = Trainer(model=model,
                      train_dataloader=finetune_dataloader,
                      eval_dataloader=None,
                      max_duration="10ep",
                      load_path_format="simple/rank_0/checkpoints/ep1.tar",
                      load_object_store=object_store,
                      load_weights_only=True)
new_trainer.fit()

Trainer checkpoint API

The Trainer has many arguments, and below we provide the API reference for the arguments that are specific to checkpoint loading and saving:

Loading

  • load_path_format (str, optional): Path to a specific checkpoint to load. If not set (the default), then no checkpoint will be loaded. (default: None)

  • load_object_store (ObjectStoreProvider, optional): For loading from object stores (e.g. S3), this will be used to download the checkpoint. Ignored if load_path_format is not specified. (default: None)

  • load_weights_only (bool): Only load the model weights. Ignored if load_path_format is not specified. (default: False)

  • load_strict (bool): Require the set of weights in the checkpoint to exactly match the model’s weights. Ignored if load_path_format is not specified. (default: False)

  • load_chunk_size (int): Chunk size (in bytes) to use when downloading checkpoints. Ignored if load_path_format is not specified or is a local file path. (default: 1,048,576)

  • load_progress_bar (bool): Display the progress bar for downloading the checkpoint. Ignored if load_path_format is not specified or if it is a local file path. (default: True)
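For remote checkpoints, load_chunk_size controls how much data is read at a time. The loop below is a generic sketch of chunked copying between binary streams; copy_in_chunks is a hypothetical helper that illustrates the idea, not Composer's downloader, and its 1 MiB default is an assumption for the example.

```python
from typing import BinaryIO


def copy_in_chunks(source: BinaryIO, destination: BinaryIO, chunk_size: int = 1024 * 1024) -> int:
    """Copy source to destination chunk_size bytes at a time; return bytes copied."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    copied = 0
    while True:
        chunk = source.read(chunk_size)
        if not chunk:  # an empty read means end of stream
            break
        destination.write(chunk)
        copied += len(chunk)
    return copied
```

A progress bar such as the one controlled by load_progress_bar would naturally be advanced once per chunk inside this loop.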

Saving

  • save_folder (str, optional): Folder path to save checkpoints, relative to the run directory. Set to None to not save checkpoints. (default: None)

  • save_interval (str or int): How often to save checkpoints. For example, set to “1ep” to save checkpoints every epoch, or “10ba” to save checkpoints every 10 batches. An integer will be assumed to be epochs. (default: 1ep)

  • save_compression (str): Compression algorithm to run on checkpoints. Can be gzip, bzip2, lzma, or left blank for no compression. (default: "" for no compression).
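The three named algorithms all have counterparts in the Python standard library. The sketch below maps each save_compression value to a stdlib module and compresses raw bytes; compress_bytes and decompress_bytes are hypothetical helpers, and the sketch does not reproduce the on-disk file format Composer writes.

```python
import bz2
import gzip
import lzma
from types import ModuleType
from typing import Dict, Optional

# "" means no compression, matching the documented default.
COMPRESSORS: Dict[str, Optional[ModuleType]] = {"": None, "gzip": gzip, "bzip2": bz2, "lzma": lzma}


def _module_for(save_compression: str) -> Optional[ModuleType]:
    try:
        return COMPRESSORS[save_compression]
    except KeyError:
        raise ValueError(f"unsupported compression: {save_compression!r}") from None


def compress_bytes(data: bytes, save_compression: str = "") -> bytes:
    module = _module_for(save_compression)
    return data if module is None else module.compress(data)


def decompress_bytes(data: bytes, save_compression: str = "") -> bytes:
    module = _module_for(save_compression)
    return data if module is None else module.decompress(data)
```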

Object Store API

class composer.utils.ObjectStoreProviderHparams(provider, container, key_environ=None, secret_environ=None, region=None, host=None, port=None, extra_init_kwargs=<factory>)

ObjectStoreProvider hyperparameters.

Example

Here’s an example of how to connect to an Amazon S3 bucket. This example assumes:

  • The container is named MY_CONTAINER.

  • The AWS Access Key ID is stored in an environment variable named AWS_ACCESS_KEY_ID.

  • The Secret Access Key is stored in an environment variable named AWS_SECRET_ACCESS_KEY.

>>> provider_hparams = ObjectStoreProviderHparams(
...     provider="s3",
...     container="MY_CONTAINER",
...     key_environ="AWS_ACCESS_KEY_ID",
...     secret_environ="AWS_SECRET_ACCESS_KEY",
... )
>>> provider = provider_hparams.initialize_object()
>>> provider
<composer.utils.object_store.ObjectStoreProvider object at ...>
Parameters
  • provider (str) –

    Cloud provider to use.

    See ObjectStoreProvider for documentation.

  • container (str) โ€“ The name of the container (i.e. bucket) to use.

  • key_environ (str, optional) –

    The name of an environment variable containing the API key or username to use to connect to the provider. If no key is required, then set this field to None. (default: None)

    For security reasons, composer requires that the key be specified via an environment variable. For example, if your key is an environment variable called OBJECT_STORE_KEY that is set to MY_KEY, then you should set this parameter equal to OBJECT_STORE_KEY. Composer will read the key like this:

    >>> import os
    >>> params = ObjectStoreProviderHparams(key_environ="OBJECT_STORE_KEY")
    >>> key = os.environ[params.key_environ]
    >>> key
    'MY_KEY'
    

  • secret_environ (str, optional) –

    The name of an environment variable containing the API secret or password to use for the provider. If no secret is required, then set this field to None. (default: None)

    For security reasons, composer requires that the secret be specified via an environment variable. For example, if your secret is an environment variable called OBJECT_STORE_SECRET that is set to MY_SECRET, then you should set this parameter equal to OBJECT_STORE_SECRET. Composer will read the secret like this:

    >>> import os
    >>> params = ObjectStoreProviderHparams(secret_environ="OBJECT_STORE_SECRET")
    >>> secret = os.environ[params.secret_environ]
    >>> secret
    'MY_SECRET'
    

  • region (str, optional) – Cloud region to use for the cloud provider. Most providers do not require the region to be specified. (default: None)

  • host (str, optional) – Override the hostname for the cloud provider. (default: None)

  • port (int, optional) – Override the port for the cloud provider. (default: None)

  • extra_init_kwargs (Dict[str, Any], optional) –

    Extra keyword arguments to pass into the constructor for the specified provider. (default: None, which is equivalent to an empty dictionary)

RunDirectoryUploader API

class composer.callbacks.RunDirectoryUploader(object_store_provider_hparams, object_name_prefix=None, num_concurrent_uploads=4, upload_staging_folder=None, use_procs=True, upload_every_n_batches=100)

Callback to upload the run directory to a blob store.

This callback checks the run directory for new or modified files at the end of every epoch and after every upload_every_n_batches batches. It detects new or modified files by their modification timestamps: only files whose last-modified timestamp is newer than the previous upload are uploaded.
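The timestamp rule can be sketched with os.walk and file modification times. modified_since is a hypothetical helper that illustrates the idea; it is not the callback's implementation.

```python
import os
from typing import List


def modified_since(run_directory: str, last_scan_time: float) -> List[str]:
    """Return paths (relative to run_directory) of files modified after last_scan_time."""
    changed = []
    for root, _dirs, files in os.walk(run_directory):
        for name in files:
            path = os.path.join(root, name)
            # getmtime returns the last-modification time in seconds since the epoch.
            if os.path.getmtime(path) > last_scan_time:
                changed.append(os.path.relpath(path, run_directory))
    return sorted(changed)
```

A caller would record the time of each scan and pass it as last_scan_time on the next one, so unchanged files are skipped.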

Example
>>> osphparams = ObjectStoreProviderHparams(
...     provider="s3",
...     container="run-dir-test",
...     key_environ="OBJECT_STORE_KEY",
...     secret_environ="OBJECT_STORE_SECRET",
...     region="us-west-2",
...     )
>>> # construct trainer object with this callback
>>> run_directory_uploader = RunDirectoryUploader(osphparams)
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
...     optimizers=optimizer,
...     max_duration="1ep",
...     callbacks=[run_directory_uploader],
... )
>>> # trainer will run this callback whenever the EPOCH_END
>>> # is triggered, like this:
>>> _ = trainer.engine.run_event(Event.EPOCH_END)

Note

This callback blocks the training loop while it copies files from the run directory to the upload_staging_folder and adds them to the workers’ upload queues. The uploads themselves happen in the background. Here are some additional tips for minimizing the performance impact:

  • Set upload_every_n_batches high enough to limit how often the blocking scans of the run directory and copies of modified files occur. However, do not set it so high that too much data is lost if the training process dies unexpectedly, since anything written after the last upload may be lost.

  • Set use_procs=True (the default) to perform the file uploads in background processes instead of threads. Processes are recommended so that the GIL does not block the training loop while CPU-bound operations run on uploaded files (e.g. computing and comparing checksums). Network I/O always occurs in the background.

  • Provide a RAM disk path for the upload_staging_folder parameter. Copying files to stage on RAM will be faster than writing to disk. However, you must have sufficient excess RAM on your system, or you may experience OutOfMemory errors.
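The split between the blocking and background parts described above can be sketched with concurrent.futures. stage_and_upload and its upload_fn callback are hypothetical. The sketch accepts any executor: a ThreadPoolExecutor is convenient for a quick test, while a ProcessPoolExecutor corresponds to the use_procs=True recommendation (upload_fn must then be a picklable, module-level function).

```python
import os
import shutil
from concurrent.futures import Executor, Future
from typing import Callable, Iterable, List


def stage_and_upload(run_directory: str, relative_paths: Iterable[str], staging_folder: str,
                     upload_fn: Callable[[str, str], object], executor: Executor) -> List[Future]:
    """Copy files into a staging folder (blocking), then upload them in the background."""
    futures = []
    for rel_path in relative_paths:
        staged_path = os.path.join(staging_folder, rel_path)
        os.makedirs(os.path.dirname(staged_path), exist_ok=True)
        # Blocking step: snapshot the file so training can keep writing to the original.
        shutil.copy2(os.path.join(run_directory, rel_path), staged_path)
        # Non-blocking step: queue the upload on a background worker.
        futures.append(executor.submit(upload_fn, staged_path, rel_path))
    return futures
```

Pointing staging_folder at a RAM disk, as suggested in the last tip, speeds up only the blocking copy step; the upload step is unaffected.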

Parameters
  • object_store_provider_hparams (ObjectStoreProviderHparams) –

    ObjectStoreProvider hyperparameters object

    See ObjectStoreProviderHparams for documentation.

  • object_name_prefix (str, optional) –

    A prefix to prepend to all object keys. An object’s key is this prefix combined with its path relative to the run directory. If the prefix is non-empty, a trailing slash ('/') will be added if necessary. If not specified, the prefix defaults to the run directory. To disable prefixing, set it to the empty string.

    For example, if object_name_prefix = 'foo' and there is a file in the run directory named bar, then that file would be uploaded to foo/bar in the container.

  • num_concurrent_uploads (int, optional) โ€“ Maximum number of concurrent uploads. Defaults to 4.

  • upload_staging_folder (str, optional) โ€“ A folder to use for staging uploads. If not specified, defaults to using a TemporaryDirectory().

  • use_procs (bool, optional) โ€“ Whether to perform file uploads in background processes (as opposed to threads). Defaults to True.

  • upload_every_n_batches (int, optional) โ€“ Interval at which to scan the run directory for changes and to queue uploads of files. In addition, uploads are always queued at the end of the epoch. Defaults to every 100 batches.
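The key-naming rule for object_name_prefix can be expressed as a short function. object_key is a hypothetical helper that follows the behavior described above for an explicit prefix (including the empty string); it does not model the default, where the prefix falls back to the run directory.

```python
import os


def object_key(object_name_prefix: str, run_directory: str, file_path: str) -> str:
    """Build an object key from a prefix and a file's path relative to the run directory."""
    relative = os.path.relpath(file_path, run_directory)
    relative = relative.replace(os.sep, "/")  # object keys always use forward slashes
    if object_name_prefix and not object_name_prefix.endswith("/"):
        object_name_prefix += "/"  # add the trailing slash only when needed
    return object_name_prefix + relative
```

With object_name_prefix="foo", a file named bar in the run directory maps to the key foo/bar, matching the example above.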