load_checkpoint#
- composer.utils.load_checkpoint(path, state, object_store=None, load_weights_only=False, strict_model_weights=False, progress_bar=True, ignore_keys=None)[source]#
Load a checkpoint from a local file, URI, or cloud object store into state.

- Parameters
  - path (str) – The path format string to an existing checkpoint file. It can be a path to a file on the local disk, a URL, or, if object_store is set, the object name for a checkpoint in a cloud bucket.

    When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables (see the first sketch after this parameter list):

    {rank}: The global rank, as returned by get_global_rank().
    {local_rank}: The local rank of the process, as returned by get_local_rank().
    {node_rank}: The node rank, as returned by get_node_rank().

    For example, suppose that checkpoints are stored in the following structure:

        my_model/ep1-rank0.tar
        my_model/ep1-rank1.tar
        my_model/ep1-rank2.tar
        ...

    Then, path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.
  - state (State) – The State to load the checkpoint into.
  - object_store (Union[ObjectStore, LoggerDestination], optional) – If the path is in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. (default: None)
  - load_weights_only (bool, optional) – Whether or not to only restore the model weights from the checkpoint, without restoring the associated state. (default: False)
  - strict_model_weights (bool, optional) – Whether or not to force the checkpointed weights to exactly match the model weights. (default: False)
  - progress_bar (bool, optional) – Whether or not to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)
  - ignore_keys (List[str] | (Dict) -> None, optional) – A list of paths into the checkpoint's state_dict which, when provided, are removed from the state_dict before the checkpoint is loaded. Each path is a list of keys to index into the state_dict, joined together with / as a separator (since PyTorch uses . in parameter names). If a prefix is provided, all of its children are also ignored (see Example 2). See composer.core.state for the structure of the state_dict.

    Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore the layer 1 weights and bias.

    Example 2: ignore_keys = ["state/model/*"] would ignore the entire model, which has the same effect as the previous example if there is only one layer.

    Example 3: ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.

    Example 4: ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.

    If a callable, it should take one argument, the state_dict; the callable is free to arbitrarily modify the state_dict before it is loaded (see the second sketch at the end of this entry). (default: None)
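A minimal sketch of loading the rank-sharded checkpoints from the layout above. It assumes state is an existing State (for example, the state attribute of a constructed Trainer); it is an illustration, not the only valid invocation.

    # Minimal sketch: load rank-sharded DeepSpeed checkpoints via the {rank}
    # format variable. `state` is an assumed, pre-built composer.core.State.
    from composer.utils import load_checkpoint

    rng_state_dicts = load_checkpoint(
        path='my_model/ep1-rank{rank}.tar',  # {rank} is substituted per process
        state=state,
        load_weights_only=False,    # also restore the associated training state
        strict_model_weights=True,  # fail loudly if weights do not match the model
    )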
- Returns
  Optional[List[Dict[str, Any]]] – The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
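The two forms that ignore_keys accepts, sketched under the same assumptions as above (state is an existing State; the 'schedulers' key in the callable is illustrative, so consult composer.core.state for the actual state_dict structure):

    from composer.utils import load_checkpoint

    # Form 1: a list of '/'-separated key paths. As in Example 4 above,
    # this resets all randomness when the checkpoint is loaded.
    load_checkpoint(
        path='my_model/ep1-rank{rank}.tar',
        state=state,
        ignore_keys=['state/rank_zero_seed', 'rng'],
    )

    # Form 2: a callable that mutates the state_dict in place before it is
    # loaded. 'schedulers' is a hypothetical key used for illustration.
    def drop_schedulers(state_dict):
        state_dict['state'].pop('schedulers', None)

    load_checkpoint(
        path='my_model/ep1-rank{rank}.tar',
        state=state,
        ignore_keys=drop_schedulers,
    )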