Checkpointing#
Composer can be configured to automatically save training checkpoints by passing the save_folder argument when creating the Trainer. The save_folder can be a relative path, in which case checkpoints will be stored in CWD/runs/<timestamp>/<rank>/<save_folder>. Absolute paths will be used as-is.
By default, checkpoints are saved every epoch, but this can be configured with the save_interval argument. Specify save_interval="10ep" to save every 10 epochs or save_interval="500ba" to save every 500 batches/steps.
from composer import Trainer
trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  max_duration="1ep",
                  save_folder="/path/to/checkpoints",
                  save_interval="1ep")  # Save checkpoints every epoch
trainer.fit()
The above code will train a model for 1 epoch, and then save the checkpoint.
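The interval strings are a count followed by a unit ("ep" for epochs, "ba" for batches). As a rough sketch of that format (parse_interval is a hypothetical helper for illustration, not Composer's actual parser, and it only handles the two units mentioned above):

```python
import re

def parse_interval(interval):
    """Split an interval string such as "10ep" or "500ba" into
    (count, unit). Illustrative only -- not Composer's implementation."""
    match = re.fullmatch(r"(\d+)(ep|ba)", interval)
    if match is None:
        raise ValueError(f"unrecognized interval: {interval!r}")
    return int(match.group(1)), match.group(2)

print(parse_interval("10ep"))   # (10, 'ep')
print(parse_interval("500ba"))  # (500, 'ba')
```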
Anatomy of a checkpoint#
The above code, when run, will produce the checkpoints below:
os.listdir(trainer.checkpoint_saver.checkpoint_folder)
['ep1.pt']
Opening one of those checkpoints, you'll see:
state_dict = torch.load(
os.path.join(trainer.checkpoint_saver.checkpoint_folder, "ep1.pt")
)
print(f"Top level keys: {list(state_dict.keys())}")
print(f"state keys: {list(state_dict['state'].keys())}")
>>> Top level keys: ['rng', 'state']
>>> state keys: ['model', 'timer', 'optimizers', 'schedulers', 'scaler', 'algorithms', 'callbacks', 'rng', 'rank_zero_seed', 'is_model_ddp']
At the top level, we see details on the current RNG state and the trainer.state.
Under the "state" key, we see:

- "model": Model weights
- "optimizers": Optimizer state
- "schedulers": Scheduler state
- "algorithms": Any algorithm state
These are the most important keys to be aware of. There are several others that are required to ensure that you can pick back up where you left off.
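To make the layout concrete, here is a toy stand-in for a checkpoint with the same top-level structure, written with pickle instead of torch.save. The key names follow the output shown above; the values are made up:

```python
import os
import pickle
import tempfile

# A toy stand-in for a Composer checkpoint: same top-level layout,
# but the values here are placeholders, not real training state.
checkpoint = {
    "rng": {"python": "..."},
    "state": {"model": {"fc.weight": [0.1]}, "timer": {}, "optimizers": {}},
}

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "ep1.pt")
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

print(sorted(loaded.keys()))  # ['rng', 'state']
```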
Resume training#
To resume training from a previous checkpoint, pass the checkpoint file path to the Trainer with the load_path_format argument. This should be an absolute path. When the Trainer is initialized, all the state information will be restored from the checkpoint, and trainer.fit() will continue training from where the checkpoint left off.
trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  eval_dataloader=None,
                  max_duration="90ep",
                  load_path_format="/path/to/checkpoint/ep25.pt")
trainer.fit()
The above code will load the checkpoint from epoch 25, and continue training for another 65 epochs (to reach 90 epochs total).
Different model or optimizer objects passed into the trainer when resuming will be respected. However, an error will be raised if the weights or state from the checkpoint are not compatible with these new objects.
Note
Only the following attributes from State will be serialized and loaded:
serialized_attributes = [
"model",
"optimizers",
"schedulers",
"algorithms",
"callbacks",
"scaler",
"timer",
]
All other trainer arguments (e.g. ``max_duration`` or ``precision``) will use
the defaults or what is passed in during the trainer creation.
Fine-tuning#
The Trainer will only load the model weights from the checkpoint if load_weights_only=True. This is especially useful for model fine-tuning, since the rest of the trainer's state no longer applies.
ft_trainer = Trainer(model=model,
                     train_dataloader=finetune_dataloader,
                     eval_dataloader=None,
                     max_duration="10ep",
                     load_path_format="/path/to/checkpoint/ep50.pt",
                     load_weights_only=True)
This example will load only the model weights from epoch 50, and then continue training on the fine-tuning dataloader for 10 epochs.
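Conceptually, load_weights_only=True keeps only the model weights and discards the rest of the serialized state. A minimal sketch of that filtering on a checkpoint-shaped dict (extract_model_weights is a hypothetical helper, not Composer's API):

```python
def extract_model_weights(checkpoint):
    """Mimic the effect of load_weights_only=True: keep only the model
    weights, dropping optimizer/scheduler/etc. state. Illustrative
    stand-in for Composer's internal logic."""
    return {"state": {"model": checkpoint["state"]["model"]}}

# A toy checkpoint with the same top-level layout as the real ones.
checkpoint = {
    "rng": {"python": "..."},
    "state": {
        "model": {"fc.weight": [0.1, 0.2]},
        "optimizers": {"lr": 0.001},
        "schedulers": {},
    },
}
weights_only = extract_model_weights(checkpoint)
print(sorted(weights_only["state"].keys()))  # ['model']
```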
Loading weights externally#
The model weights are located at state_dict["state"]["model"] within the stored checkpoint. To load them into a model outside of a Trainer, use torch.load():
model = MyModel()
state_dict = torch.load("/path/to/checkpoint/ep15.pt")
model.load_state_dict(state_dict["state"]["model"])
Uploading to Object Store#
Checkpoints can also be saved to and loaded from your object store of choice (e.g. AWS S3 or Google Cloud Storage). Writing checkpoints to an object store is a two-step process. The checkpoints are first written to the local filesystem, and then the RunDirectoryUploader callback will upload them to the object store.
Note
We use libcloud to connect to the remote object stores, so be sure to have the Python package apache-libcloud installed.
For this, the ObjectStoreProvider needs to be configured with the following arguments:

- provider: The name of the object store provider, as recognized by libcloud. See available providers here.
- container: The name of the container (i.e. "bucket") to use.
To prevent accidental leakage of API keys, your secrets must be provided indirectly through environment variables. Set these in your environment and provide the following environment variable names:

- key_environ: The environment variable where your username is stored, e.g. the GCS access key.
- secret_environ: The environment variable where your secret is stored, e.g. the GCS secret that is paired with the above access key for requests.
The object store also accepts these common optional arguments:

- host: The specific hostname for the cloud provider, letting you override the default value provided by libcloud.
- port: The port for the cloud provider.
- region: The region to use for the cloud provider.
If your cloud provider requires additional parameters, pass them as a
dictionary under the key extra_init_kwargs
.
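The forwarding mechanism is ordinary Python keyword-argument unpacking. A sketch of the idea, with DummyDriver and build_driver as hypothetical stand-ins for a libcloud driver class and the hparams plumbing:

```python
class DummyDriver:
    """Stand-in for a libcloud storage driver; real drivers accept
    provider-specific keyword arguments in their constructors."""
    def __init__(self, key, secret, **extra):
        self.key = key
        self.secret = secret
        self.extra = extra

def build_driver(key, secret, extra_init_kwargs=None):
    # Forward any provider-specific options, mirroring how
    # extra_init_kwargs is passed through to the driver constructor.
    return DummyDriver(key, secret, **(extra_init_kwargs or {}))

driver = build_driver("me", "s3cret",
                      extra_init_kwargs={"signature_version": "4"})
print(driver.extra)  # {'signature_version': '4'}
```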
Once you've configured your object store properly per above, all that's left is to add the RunDirectoryUploader as a callback. Let's put all this together below:
import uuid
from composer.callbacks import RunDirectoryUploader
from composer.utils.object_store import ObjectStoreProviderHparams
credentials = {"provider": "GOOGLE_STORAGE",
               "container": "checkpoints-debugging",
               "key_environ": "GCS_KEY",
               "secret_environ": "GCS_SECRET"}
hp = ObjectStoreProviderHparams(**credentials)
prefix = f"my-model-{str(uuid.uuid4())[:6]}"
store_uploader = RunDirectoryUploader(hp, object_name_prefix=prefix)
trainer = Trainer(model=model,
                  train_dataloader=dataloader,
                  eval_dataloader=None,
                  max_duration="90ep",
                  save_folder="checkpoints",
                  callbacks=[store_uploader])
This will train your model, saving the checkpoints locally, and also upload them to your Google Storage bucket using the access key from GCS_KEY and the secret from GCS_SECRET in your environment variables.
Loading from Object Store#
Checkpoints saved to an object store can be loaded in the same way as files saved on disk. Initialize an ObjectStoreProvider from the ObjectStoreProviderHparams and pass it to the trainer's load_object_store argument. The load_path_format argument should be the path to the checkpoint file within the container/bucket.
from composer.utils.object_store import ObjectStoreProviderHparams

hp = ObjectStoreProviderHparams(
    provider="GOOGLE_STORAGE",
    container="checkpoints-debugging",
    key_environ="GCS_KEY",
    secret_environ="GCS_SECRET",
)
object_store = hp.initialize_object()
From there we can fine-tune with:
new_trainer = Trainer(model=model,
                      train_dataloader=finetune_dataloader,
                      eval_dataloader=None,
                      max_duration="10ep",
                      load_path_format="simple/rank_0/checkpoints/ep1.tar",
                      load_object_store=object_store,
                      load_weights_only=True)
new_trainer.fit()
Trainer checkpoint API#
The Trainer has many arguments, and below we provide the API reference for the arguments that are specific to checkpoint loading and saving:
Loading#
- load_path_format (str, optional): Path to a specific checkpoint to load. If not set (the default), then no checkpoint will be loaded. (default: None)
- load_object_store (ObjectStoreProvider, optional): For loading from object stores (e.g. S3), this will be used to download the checkpoint. Ignored if load_path_format is not specified. (default: None)
- load_weights_only (bool): Only load the model weights. Ignored if load_path_format is not specified. (default: False)
- load_strict (bool): Ensure that the set of weights in the checkpoint exactly matches the set of weights in the model. Ignored if load_path_format is not specified. (default: False)
- load_chunk_size (int): Chunk size (in bytes) to use when downloading checkpoints. Ignored if load_path_format is not specified or is a local file path. (default: 1,048,576)
- load_progress_bar (bool): Display the progress bar for downloading the checkpoint. Ignored if load_path_format is not specified or is a local file path. (default: True)
Saving#
- save_folder (str, optional): Folder path to save checkpoints, relative to the run directory. Set to None to not save checkpoints. (default: None)
- save_interval (str or int): How often to save checkpoints. For example, set to "1ep" to save checkpoints every epoch, or "10ba" to save checkpoints every 10 batches. An integer will be assumed to be epochs. (default: 1ep)
- save_compression (str): Compression algorithm to run on checkpoints. Can be gzip, bzip2, lzma, or left blank for no compression. (default: "" for no compression)
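For a sense of what save_compression="gzip" does, here is the same compression applied to a toy payload with the standard library. Real checkpoints are compressed by the Trainer when they are written; this is only an illustration:

```python
import gzip

# Simulate compressing a checkpoint payload. Repetitive data such as
# serialized tensors with many similar bytes compresses well.
payload = b"model-weights " * 1000
compressed = gzip.compress(payload)

print(len(payload), len(compressed))
assert gzip.decompress(compressed) == payload  # lossless round trip
assert len(compressed) < len(payload)          # smaller on disk
```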
Object Store API#
- class composer.utils.ObjectStoreProviderHparams(provider, container, key_environ=None, secret_environ=None, region=None, host=None, port=None, extra_init_kwargs=<factory>)[source]
ObjectStoreProvider hyperparameters.
Example
Here's an example of how to connect to an Amazon S3 bucket. This example assumes:

- The container is named MY_CONTAINER.
- The AWS Access Key ID is stored in an environment variable named AWS_ACCESS_KEY_ID.
- The AWS Secret Access Key is stored in an environment variable named AWS_SECRET_ACCESS_KEY.
>>> provider_hparams = ObjectStoreProviderHparams(
...     provider="s3",
...     container="MY_CONTAINER",
...     key_environ="AWS_ACCESS_KEY_ID",
...     secret_environ="AWS_SECRET_ACCESS_KEY",
... )
>>> provider = provider_hparams.initialize_object()
>>> provider
<composer.utils.object_store.ObjectStoreProvider object at ...>
- Parameters
provider (str) – Cloud provider to use. See ObjectStoreProvider for documentation.
container (str) – The name of the container (i.e. bucket) to use.
key_environ (str, optional) – The name of an environment variable containing the API key or username to use to connect to the provider. If no key is required, then set this field to None. (default: None)
For security reasons, composer requires that the key be specified via an environment variable. For example, if your key is an environment variable called OBJECT_STORE_KEY that is set to MY_KEY, then you should set this parameter equal to OBJECT_STORE_KEY. Composer will read the key like this:
>>> import os
>>> params = ObjectStoreProviderHparams(key_environ="OBJECT_STORE_KEY")
>>> key = os.environ[params.key_environ]
>>> key
'MY_KEY'
secret_environ (str, optional) – The name of an environment variable containing the API secret or password to use for the provider. If no secret is required, then set this field to None. (default: None)
For security reasons, composer requires that the secret be specified via an environment variable. For example, if your secret is an environment variable called OBJECT_STORE_SECRET that is set to MY_SECRET, then you should set this parameter equal to OBJECT_STORE_SECRET. Composer will read the secret like this:
>>> import os
>>> params = ObjectStoreProviderHparams(secret_environ="OBJECT_STORE_SECRET")
>>> secret = os.environ[params.secret_environ]
>>> secret
'MY_SECRET'
region (str, optional) – Cloud region to use for the cloud provider. Most providers do not require the region to be specified. (default: None)
host (str, optional) – Override the hostname for the cloud provider. (default: None)
port (int, optional) – Override the port for the cloud provider. (default: None)
extra_init_kwargs (Dict[str, Any], optional) – Extra keyword arguments to pass into the constructor for the specified provider. (default: None, which is equivalent to an empty dictionary)
RunDirectoryUploader API#
- class composer.callbacks.RunDirectoryUploader(object_store_provider_hparams, object_name_prefix=None, num_concurrent_uploads=4, upload_staging_folder=None, use_procs=True, upload_every_n_batches=100)[source]
Callback to upload the run directory to a blob store.
This callback checks the run directory for new or modified files at the end of every epoch, and after every upload_every_n_batches batches. New or modified files are detected based on the file modification timestamp: only files with a last-modified timestamp newer than the previous upload will be uploaded.
- Example
>>> osphparams = ObjectStoreProviderHparams(
...     provider="s3",
...     container="run-dir-test",
...     key_environ="OBJECT_STORE_KEY",
...     secret_environ="OBJECT_STORE_SECRET",
...     region="us-west-2",
... )
>>> # construct trainer object with this callback
>>> run_directory_uploader = RunDirectoryUploader(osphparams)
>>> trainer = Trainer(
...     model=model,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
...     optimizers=optimizer,
...     max_duration="1ep",
...     callbacks=[run_directory_uploader],
... )
>>> # trainer will run this callback whenever the EPOCH_END
>>> # is triggered, like this:
>>> _ = trainer.engine.run_event(Event.EPOCH_END)
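The timestamp-based change detection can be sketched with the standard library (modified_since is a hypothetical helper for illustration, not the callback's actual code):

```python
import os
import tempfile

def modified_since(path, last_upload_time):
    """Return True if the file at `path` was modified after
    last_upload_time -- the same mtime signal the callback uses to
    decide which files to re-upload. Hypothetical helper."""
    return os.stat(path).st_mtime > last_upload_time

with tempfile.TemporaryDirectory() as run_dir:
    path = os.path.join(run_dir, "ep1.pt")
    with open(path, "wb") as f:
        f.write(b"checkpoint bytes")
    os.utime(path, (1000.0, 1000.0))  # pretend the file was written at t=1000
    changed = modified_since(path, 999.0)     # last upload before the write
    unchanged = modified_since(path, 1000.5)  # last upload after the write

print(changed, unchanged)  # True False
```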
Note
This callback blocks the training loop to copy files from the run_directory to the upload_staging_folder and to queue these files to the upload queues of the workers. The actual upload happens in the background. While all uploads happen in the background, here are some additional tips for minimizing the performance impact:

- Ensure that upload_every_n_batches is sufficiently infrequent to limit the blocking scans of the run directory and copies of modified files. However, do not make it too infrequent in case the training process unexpectedly dies, since data written after the last upload may be lost.
- Set use_procs=True (the default) to use background processes, instead of threads, to perform the file uploads. Processes are recommended to ensure that the GIL does not block the training loop when performing CPU operations on uploaded files (e.g. computing and comparing checksums). Network I/O always occurs in the background.
- Provide a RAM disk path for the upload_staging_folder parameter. Copying files to a staging area in RAM will be faster than writing to disk. However, you must have sufficient excess RAM on your system, or you may experience OutOfMemory errors.
- Parameters
object_store_provider_hparams (ObjectStoreProviderHparams) – ObjectStoreProvider hyperparameters object. See ObjectStoreProviderHparams for documentation.
object_name_prefix (str, optional) – A prefix to prepend to all object keys. An object's key is this prefix combined with its path relative to the run directory. If the container prefix is non-empty, a trailing slash ("/") will be added if necessary. If not specified, then the prefix defaults to the run directory. To disable prefixing, set to the empty string.
For example, if object_name_prefix = 'foo' and there is a file in the run directory named bar, then that file would be uploaded to foo/bar in the container.
num_concurrent_uploads (int, optional) – Maximum number of concurrent uploads. Defaults to 4.
upload_staging_folder (str, optional) – A folder to use for staging uploads. If not specified, defaults to using a TemporaryDirectory().
use_procs (bool, optional) – Whether to perform file uploads in background processes (as opposed to threads). Defaults to True.
upload_every_n_batches (int, optional) – Interval at which to scan the run directory for changes and to queue uploads of files. In addition, uploads are always queued at the end of the epoch. Defaults to every 100 batches.
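The prefix behavior described for object_name_prefix can be sketched as follows (object_key is a hypothetical helper illustrating the rule, not the callback's actual code):

```python
def object_key(prefix, relative_path):
    """Combine object_name_prefix with a file's run-directory-relative
    path: a trailing '/' is added to a non-empty prefix if missing,
    and an empty prefix disables prefixing. Hypothetical helper."""
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    return prefix + relative_path

print(object_key("foo", "bar"))   # foo/bar
print(object_key("foo/", "bar"))  # foo/bar
print(object_key("", "bar"))      # bar
```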