load_checkpoint

composer.utils.load_checkpoint(path, state, logger, object_store=None, load_weights_only=False, strict_model_weights=True, progress_bar=True, ignore_keys=None, exclude_algorithms=None, algorithm_passes=None)

Load a checkpoint from a local file, URI, or cloud object store into state.

Parameters
  • path (str) –

    The path format string to an existing checkpoint file.

    It can be a path to a file on the local disk, a URL, or if object_store is set, the object name for a checkpoint in a cloud bucket.

    When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in the path, use the following format variables:

    Variable        Description
    {rank}          The global rank, as returned by get_global_rank().
    {local_rank}    The local rank of the process, as returned by get_local_rank().
    {node_rank}     The node rank, as returned by get_node_rank().

    For example, suppose that checkpoints are stored in the following structure:

    my_model/ep1-rank0.tar
    my_model/ep1-rank1.tar
    my_model/ep1-rank2.tar
    ...
    

    Then, path should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.
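A minimal sketch of how the placeholders resolve for each process, assuming they are filled in like ordinary Python str.format fields. The helper name resolve_rank_path and the 2-node, 2-processes-per-node layout are hypothetical and only illustrate the mapping from ranks to file names.

```python
def resolve_rank_path(path_format: str, rank: int, local_rank: int, node_rank: int) -> str:
    """Hypothetical helper: fill the {rank}, {local_rank}, and {node_rank} placeholders."""
    # str.format ignores keyword arguments that the format string does not use.
    return path_format.format(rank=rank, local_rank=local_rank, node_rank=node_rank)


# Simulate 2 nodes with 2 processes each (4 processes in total).
PROCS_PER_NODE = 2
for global_rank in range(4):
    node_rank, local_rank = divmod(global_rank, PROCS_PER_NODE)
    # Prints my_model/ep1-rank0.tar through my_model/ep1-rank3.tar, one per process.
    print(resolve_rank_path("my_model/ep1-rank{rank}.tar", global_rank, local_rank, node_rank))
```

Because every process evaluates the same format string with its own ranks, a single path argument is enough for all ranks.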

  • state (State) – The State to load the checkpoint into.

  • logger (Logger) – The Logger to log any information.

  • object_store (Union[ObjectStore, LoggerDestination], optional) – If the path is in an object store (e.g., AWS S3 or Google Cloud Storage), an instance of ObjectStore or LoggerDestination that will be used to retrieve the checkpoint. If the checkpoint is a local file path, set to None. (default: None)
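The three ways path can be interpreted can be sketched as a small classifier. This is a hypothetical helper, not Composer's own dispatch logic; it assumes that any multi-letter URL scheme (s3://, gs://, https://) marks a URI and that an object_store argument, when given, means path is an object name inside the bucket.

```python
from typing import Optional
from urllib.parse import urlparse


def describe_checkpoint_source(path: str, object_store: Optional[object] = None) -> str:
    """Hypothetical helper: classify how `path` would be interpreted."""
    if object_store is not None:
        # With an object store, `path` is the object name inside the bucket.
        return "object-store"
    scheme = urlparse(path).scheme
    # A multi-letter scheme (s3://, gs://, https://) indicates a URI; a one-letter
    # "scheme" is more likely a Windows drive letter such as C:.
    if len(scheme) > 1:
        return "uri"
    return "local"


print(describe_checkpoint_source("checkpoints/ep1.pt"))              # local
print(describe_checkpoint_source("s3://my-bucket/ep1.pt"))           # uri
print(describe_checkpoint_source("ep1.pt", object_store=object()))   # object-store
```

Per the progress_bar description, a progress bar would only be relevant in the two non-local cases, where the checkpoint has to be downloaded.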

  • load_weights_only (bool, optional) – Whether or not to restore only the model weights from the checkpoint, without restoring the rest of the associated state. (default: False)

  • strict_model_weights (bool, optional) – Whether or not to require that the checkpointed weights exactly match the model weights. (default: True)
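A minimal sketch of what "exactly match" can mean, assuming strict_model_weights behaves like the strict argument of PyTorch's torch.nn.Module.load_state_dict: a strict load fails on any missing or unexpected key, while a non-strict load only reports them. The helper compare_weight_keys is hypothetical and works on plain key names rather than real tensors.

```python
from typing import Iterable, List, Tuple


def compare_weight_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str],
                        strict: bool = True) -> Tuple[List[str], List[str]]:
    """Hypothetical helper mirroring the missing/unexpected-key check of a strict load."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)      # in the model but not the checkpoint
    unexpected = sorted(ckpt_set - model_set)   # in the checkpoint but not the model
    if strict and (missing or unexpected):
        raise RuntimeError(f"Missing keys: {missing}; unexpected keys: {unexpected}")
    return missing, unexpected


# Non-strict: mismatches are reported instead of raising.
print(compare_weight_keys(["fc.weight", "fc.bias"], ["fc.weight"], strict=False))
# (['fc.bias'], [])
```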

  • progress_bar (bool, optional) – Whether or not to show a progress bar when downloading checkpoints. Ignored if the checkpoint is a local file path. (default: True)

  • ignore_keys (list[str] | (dict) -> None, optional) –

    A list of paths into the checkpoint's state_dict; matching entries are removed from the state_dict before the checkpoint is loaded. Each path is a single string of state_dict keys joined with / as the separator (/ is used because PyTorch uses . in parameter names). As the examples below show, paths may contain * wildcards. If a prefix is given, all of its children are ignored as well (see Example 2). See composer.core.state for the structure of state_dict.

    Example 1: ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"] would ignore layer 1 weights and bias.

    Example 2: ignore_keys = ["state/model/*"] would ignore the entire model, which would have the same effect as the previous example if the model had only one layer.

    Example 3: ignore_keys = ["state/model/layer*.weights"] would ignore all weights in the model.

    Example 4: ignore_keys = ["state/rank_zero_seed", "rng"] would reset all randomness when loading the checkpoint.

    If a callable, it should take a single argument, the state_dict, and may modify it arbitrarily before it is loaded.

    (default: None)
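The path semantics in the examples above can be sketched with glob matching from Python's fnmatch module. This is an illustrative sketch, not Composer's implementation: remove_matching_paths, apply_ignore_keys, drop_rng, and the toy state_dict layout are all hypothetical. Under this reading, a prefix pattern such as state/model/* empties the model entry, and deleting a matched node drops its children with it.

```python
import fnmatch
from typing import Any, Callable, Dict, List, Union

StateDict = Dict[str, Any]


def remove_matching_paths(state_dict: StateDict, patterns: List[str], prefix: str = "") -> None:
    """Delete, in place, every entry whose '/'-joined key path matches a glob pattern."""
    for key in list(state_dict):
        path = f"{prefix}{key}"
        if any(fnmatch.fnmatchcase(path, pattern) for pattern in patterns):
            del state_dict[key]  # deleting a node also drops all of its children
        elif isinstance(state_dict[key], dict):
            remove_matching_paths(state_dict[key], patterns, prefix=f"{path}/")


def apply_ignore_keys(state_dict: StateDict,
                      ignore_keys: Union[List[str], Callable[[StateDict], None]]) -> None:
    """Hypothetical dispatcher: glob patterns, or an arbitrary in-place callable."""
    if callable(ignore_keys):
        ignore_keys(state_dict)
    else:
        remove_matching_paths(state_dict, ignore_keys)


def drop_rng(state_dict: StateDict) -> None:
    """Hypothetical callable form: remove the top-level RNG entry in place."""
    state_dict.pop("rng", None)


# Toy state_dict; the nested layout is illustrative only.
sd = {
    "state": {"model": {"layer1.weights": 1, "layer1.bias": 2, "layer2.weights": 3},
              "rank_zero_seed": 42},
    "rng": [{}],
}
apply_ignore_keys(sd, ["state/model/layer*.weights"])  # Example 3
print(sd["state"]["model"])                            # {'layer1.bias': 2}
apply_ignore_keys(sd, drop_rng)                        # callable form
print("rng" in sd)                                     # False
```

fnmatchcase is used instead of fnmatch so that matching is case-sensitive on every platform, as state_dict keys are.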

  • exclude_algorithms (list[str], optional) –

    A list of algorithm names to exclude from loading. By default, algorithms with required_on_load=True that were enabled when the checkpoint was trained are applied automatically, unless they conflict with a user-specified algorithm. These algorithms often modify the model, and not applying them could leave some layers without loaded weights.

    Example 1: exclude_algorithms = ["BlurPool"] would exclude BlurPool from loading.

    Example 2: exclude_algorithms = ["FusedLayerNorm", "Alibi"] would exclude FusedLayerNorm and Alibi from loading.

    (default: None)
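The default autoloading rule and the effect of exclude_algorithms can be sketched as a filter over algorithm names. This is a hypothetical simplification: algorithms_to_autoload is not a Composer function, algorithms are represented by name only, the required_on_load set below is illustrative rather than a statement about those algorithms' actual flags, and "conflicts with a user-specified algorithm" is reduced to "the user already specified an algorithm with the same name".

```python
from typing import Iterable, List, Optional, Set


def algorithms_to_autoload(checkpoint_algorithms: List[str], required_on_load: Set[str],
                           user_algorithms: Iterable[str],
                           exclude_algorithms: Optional[List[str]] = None) -> List[str]:
    """Hypothetical filter: which algorithms from the checkpoint would be re-applied."""
    excluded = set(exclude_algorithms or [])
    user = set(user_algorithms)
    return [
        name for name in checkpoint_algorithms
        if name in required_on_load   # only algorithms flagged required_on_load=True
        and name not in excluded      # explicitly excluded by the caller
        and name not in user          # simplification: a "conflict" means the same name
    ]


print(algorithms_to_autoload(
    checkpoint_algorithms=["BlurPool", "FusedLayerNorm", "Alibi"],
    required_on_load={"BlurPool", "FusedLayerNorm", "Alibi"},
    user_algorithms=["Alibi"],
    exclude_algorithms=["BlurPool"],
))  # ['FusedLayerNorm']
```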

  • algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to apply to autoloaded algorithms to sort them into the correct order. (default: None)
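The idea of passes that reorder autoloaded algorithms can be sketched as functions applied one after another. Composer's actual AlgorithmPass type operates on Algorithm objects and may take additional arguments; here a pass is simplified to a function over a list of algorithm names, and move_to_front and apply_passes are hypothetical helpers.

```python
from typing import Callable, List, Sequence

# Simplified stand-in for a pass: takes algorithm names, returns them reordered.
NamePass = Callable[[Sequence[str]], List[str]]


def move_to_front(name: str) -> NamePass:
    """Hypothetical pass factory: put `name` first, keeping everything else in order."""
    def _pass(algorithms: Sequence[str]) -> List[str]:
        # sorted() is stable, and False sorts before True, so `name` moves first.
        return sorted(algorithms, key=lambda algo: algo != name)
    return _pass


def apply_passes(algorithms: Sequence[str], passes: Sequence[NamePass]) -> List[str]:
    result = list(algorithms)
    for algorithm_pass in passes:  # passes run in the order given
        result = algorithm_pass(result)
    return result


print(apply_passes(["Alibi", "BlurPool", "FusedLayerNorm"], [move_to_front("FusedLayerNorm")]))
# ['FusedLayerNorm', 'Alibi', 'BlurPool']
```

Because each pass sees the output of the previous one, the order of the passes list matters when two passes disagree.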

Returns

Optional[list[dict[str, Any]]] – The RNG state dicts, indexed by global rank, if load_weights_only is False. Otherwise, None.
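A minimal sketch of how a caller might consume the return value, assuming RNG states are returned only when the full state is restored (load_weights_only=False). The top-level "rng" key matches ignore_keys Example 4; the helper rng_states_to_return and the contents of each RNG dict are hypothetical.

```python
from typing import Any, Dict, List, Optional


def rng_states_to_return(checkpoint: Dict[str, Any],
                         load_weights_only: bool) -> Optional[List[Dict[str, Any]]]:
    """Hypothetical sketch of the documented return value."""
    if load_weights_only:
        # Only the model weights are restored, so no RNG state is handed back.
        return None
    # One RNG state dict per process, indexed by global rank.
    return checkpoint["rng"]


checkpoint = {"state": {"model": {}}, "rng": [{"seed": 0}, {"seed": 1}]}
rng_states = rng_states_to_return(checkpoint, load_weights_only=False)
global_rank = 1
if rng_states is not None:
    print(rng_states[global_rank])  # {'seed': 1}
print(rng_states_to_return(checkpoint, load_weights_only=True))  # None
```

Callers should check for None before indexing, since a weights-only load carries no RNG state.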