composer.trainer.trainer

Train models!

The trainer supports models that are ComposerModel instances. The Trainer is highly customizable and can support a wide variety of workloads.

Example

Train a model and save a checkpoint:

import os
from composer import Trainer

# Create a trainer
trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  max_duration="1ep",
                  eval_dataloader=eval_dataloader,
                  optimizers=optimizer,
                  schedulers=scheduler,
                  device="cpu",
                  validate_every_n_epochs=1,
                  save_folder="checkpoints",
                  save_interval="1ep")

# Fit and run evaluation for 1 epoch.
# Save a checkpoint after 1 epoch as specified during trainer creation.
trainer.fit()

Load the checkpoint and resume training:

# Get the saved checkpoint folder.
# By default, the checkpoint folder is of the form runs/<timestamp>/rank_0/checkpoints.
# Alternatively, if you set the run directory environment variable as follows:
# os.environ["COMPOSER_RUN_DIRECTORY"] = "my_run_directory", then the checkpoint path
# will be of the form my_run_directory/rank_0/checkpoints.
checkpoint_folder = trainer.checkpoint_folder

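# With the default save_name_format ("ep{epoch}-ba{batch}-rank{rank}"), checkpoint
# names include the batch count. Instead of computing it, use the symlink named by
# save_latest_format ("latest-rank{rank}" by default), which points to the most
# recently saved checkpoint.
checkpoint_path = os.path.join(checkpoint_folder, "latest-rank0")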

# Create a new trainer with the load_path_format argument set to the checkpoint path.
# This will automatically load the checkpoint on trainer creation.
trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  max_duration="2ep",
                  eval_dataloader=eval_dataloader,
                  optimizers=optimizer,
                  schedulers=scheduler,
                  device="cpu",
                  validate_every_n_epochs=1,
                  load_path_format=checkpoint_path)

# Continue training and running evaluation where the previous trainer left off
# until the new max_duration is reached.
# In this case, that is one additional epoch, for 2 epochs in total.
trainer.fit()
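The durations above are time strings such as "1ep" and "2ep"; the max_duration parameter also accepts strings like "10ba" and bare integers, which are interpreted as epochs. The parse_duration helper below is hypothetical (it is not Composer's parser) and only sketches that interpretation for the two unit suffixes mentioned on this page.

```python
import re
from typing import Tuple, Union

# Hypothetical mapping covering only the two unit suffixes mentioned on this page.
_UNITS = {"ep": "epoch", "ba": "batch"}


def parse_duration(value: Union[int, str]) -> Tuple[int, str]:
    """Split a duration such as "2ep" or "10ba" into (amount, unit)."""
    if isinstance(value, int):
        # Bare integers are interpreted as epochs.
        return value, "epoch"
    match = re.fullmatch(r"(\d+)(ep|ba)", value.strip())
    if match is None:
        raise ValueError(f"Unrecognized duration string: {value!r}")
    return int(match.group(1)), _UNITS[match.group(2)]


print(parse_duration("2ep"))   # (2, 'epoch')
print(parse_duration("10ba"))  # (10, 'batch')
print(parse_duration(3))       # (3, 'epoch')
```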

Classes

Trainer

Trainer for training models with Composer algorithms.

class composer.trainer.trainer.Trainer(*, model, train_dataloader, max_duration, eval_dataloader=None, algorithms=None, optimizers=None, schedulers=None, device=None, grad_accum=1, grad_clip_norm=None, validate_every_n_batches=-1, validate_every_n_epochs=1, compute_training_metrics=False, precision=Precision.FP32, scale_schedule_ratio=1.0, step_schedulers_every_batch=None, dist_timeout=300.0, ddp_sync_strategy=None, seed=None, deterministic_mode=False, loggers=None, callbacks=(), load_path_format=None, load_object_store=None, load_weights_only=False, load_strict=False, load_chunk_size=1048576, load_progress_bar=True, save_folder=None, save_name_format='ep{epoch}-ba{batch}-rank{rank}', save_latest_format='latest-rank{rank}', save_overwrite=False, save_interval='1ep', save_weights_only=False, train_subset_num_batches=None, eval_subset_num_batches=None, deepspeed_config=False, profiler_trace_file=None, prof_event_handlers=(), prof_skip_first=0, prof_wait=0, prof_warmup=1, prof_active=4, prof_repeat=1, sys_prof_cpu=True, sys_prof_memory=False, sys_prof_disk=False, sys_prof_net=False, sys_prof_stats_thread_interval_seconds=0.5, torch_profiler_trace_dir=None, torch_prof_use_gzip=False, torch_prof_record_shapes=False, torch_prof_profile_memory=True, torch_prof_with_stack=False, torch_prof_with_flops=True)

Trainer for training models with Composer algorithms. See the Trainer guide for more information.

Parameters
  • model (ComposerModel) –

    The model to train. Can be user-defined or one of the models included with Composer.

    See also

    composer.models for models built into Composer.

  • train_dataloader (DataLoader, DataSpec, or dict) –

    The DataLoader, DataSpec, or dict of DataSpec kwargs for the training data. In order to specify custom preprocessing steps on each data batch, specify a DataSpec instead of a DataLoader.

    Note

    The train_dataloader should yield per-rank batches. Each per-rank batch will then be further divided based on the grad_accum parameter. For example, if the desired optimization batch size is 2048 and training is happening across 8 GPUs, then each train_dataloader should yield a batch of size 2048 / 8 = 256. If grad_accum = 2, then the per-rank batch will be divided into microbatches of size 256 / 2 = 128.

  • max_duration (int, str, or Time) – The maximum duration to train. Can be an integer, which will be interpreted as epochs, a str (e.g. 1ep or 10ba), or a Time object.

  • eval_dataloader (DataLoader, DataSpec, or Evaluators, optional) – The DataLoader, DataSpec, or Evaluators for the evaluation data. In order to evaluate one or more specific metrics across one or more datasets, pass in an Evaluator. If a DataSpec or DataLoader is passed in, then all metrics returned by model.metrics() will be used during evaluation. None results in no evaluation. (default: None)

  • algorithms (List[Algorithm], optional) –

    The algorithms to use during training. If None, then no algorithms will be used. (default: None)

    See also

    composer.algorithms for the different algorithms built into Composer.

  • optimizers (Optimizers, optional) –

    The optimizer. If None, will be set to DecoupledSGDW(model.parameters(), lr=0.1). (default: None)

    See also

    composer.optim for the different optimizers built into Composer.

  • schedulers (Schedulers, optional) –

    The learning rate schedulers. If [] or None, will be set to [constant_scheduler]. (default: None).

    See also

    composer.optim.scheduler for the different schedulers built into Composer.

  • device (str or Device, optional) – The device to use for training. Either cpu or gpu. (default: cpu)

  • grad_accum (int, optional) –

    The number of microbatches to split a per-device batch into. Gradients are summed over the microbatches per device. (default: 1)

    Note

    This is implemented by taking the batch yielded by the train_dataloader and splitting it into grad_accum sections. Each section is of size batch_size // grad_accum. If the batch size of the dataloader is not divisible by grad_accum, then the last section will be of size batch_size % grad_accum.

  • grad_clip_norm (float, optional) – The norm to clip gradient magnitudes to. Set to None for no gradient clipping. (default: None)

  • validate_every_n_batches (int, optional) – Compute metrics on evaluation data every N batches. Set to -1 to never validate on a batchwise frequency. (default: -1)

  • validate_every_n_epochs (int, optional) – Compute metrics on evaluation data every N epochs. Set to -1 to never validate on an epochwise frequency. (default: 1)

  • compute_training_metrics (bool, optional) – Whether to compute metrics on the training data. (default: False)

  • precision (str or Precision, optional) –

    Numerical precision to use for training. One of fp32, fp16 or amp (recommended). (default: Precision.FP32)

    Note

    fp16 only works if deepspeed_config is also provided.

  • scale_schedule_ratio (float, optional) –

    Ratio by which to scale the training duration and learning rate schedules. E.g., 0.5 makes the schedule take half as many epochs and 2.0 makes it take twice as many epochs. 1.0 means no change. (default: 1.0)

    Note

    Training for less time, while rescaling the learning rate schedule, is a strong baseline approach to speeding up training. E.g., training for half duration often yields minor accuracy degradation, provided that the learning rate schedule is also rescaled to take half as long.

    To see the difference, consider training for half as long using a cosine annealing learning rate schedule. If the schedule is not rescaled, training ends while the learning rate is still ~0.5 of the initial LR. If the schedule is rescaled with scale_schedule_ratio, the LR schedule would finish the entire cosine curve, ending with a learning rate near zero.

  • step_schedulers_every_batch (bool, optional) – By default, native PyTorch schedulers are updated every epoch, while Composer schedulers are updated every step. Setting this to True forces all schedulers to be stepped every batch, while False means schedulers are stepped every epoch. None indicates the default behavior. (default: None)

  • dist_timeout (float, optional) – Timeout, in seconds, for initializing the distributed process group. (default: 300.0)

  • ddp_sync_strategy (str or DDPSyncStrategy, optional) – The strategy to use for synchronizing gradients. Leave unset to let the trainer auto-configure this. See DDPSyncStrategy for more details.

  • seed (int, optional) –

    The seed used in randomization. If None, then a random seed will be created. (default: None)

    Note

    In order to get reproducible results, call the seed_all() function at the start of your script with the seed passed to the trainer. This will ensure any initialization done before the trainer init (e.g., model weight initialization) also uses the provided seed.

    See also

    composer.utils.reproducibility for more details on reproducibility.

  • deterministic_mode (bool, optional) –

    Run the model deterministically. (default: False)

    Note

    This is an experimental feature. Expect degraded performance. Certain Torch modules may not have deterministic implementations, which will result in a crash.

    Note

    In order to get reproducible results, call the configure_deterministic_mode() function at the start of your script. This will ensure any initialization done before the trainer init also runs deterministically.

    See also

    composer.utils.reproducibility for more details on reproducibility.

  • loggers (Sequence[LoggerCallback], optional) –

    The destinations to log training information to. If None, will be set to [TQDMLogger()]. (default: None)

    See also

    composer.loggers for the different loggers built into Composer.

  • callbacks (Sequence[Callback], optional) –

    The callbacks to run during training. If None or empty, then no callbacks will be run. (default: ())

    See also

    composer.callbacks for the different callbacks built into Composer.

  • load_path_format (str, optional) –

    The path format string to an existing checkpoint file.

    It can be a path to a file on the local disk, a URL, or if load_object_store is set, the object name for a checkpoint in a cloud bucket.

    When using DeepSpeed ZeRO, checkpoints are sharded by rank. Instead of hard-coding the rank in load_path_format, use the following format variables:

    Variable        Description
    {rank}          The global rank, as returned by get_global_rank().
    {local_rank}    The local rank of the process, as returned by get_local_rank().
    {node_rank}     The node rank, as returned by get_node_rank().

    For example, suppose that checkpoints are stored in the following structure:

    my_model/ep1-rank0.tar
    my_model/ep1-rank1.tar
    my_model/ep1-rank2.tar
    ...
    

    Then, load_path_format should be set to my_model/ep1-rank{rank}.tar, and all ranks will load the correct state.

    If None then no checkpoint will be loaded. (default: None)

  • load_object_store (ObjectStoreProvider, optional) –

    If load_path_format refers to an object in an object store (e.g. AWS S3 or Google Cloud Storage), an instance of ObjectStoreProvider which will be used to retrieve the checkpoint. Otherwise, if the checkpoint is a local filepath, set to None. Ignored if load_path_format is None. (default: None)

    Example:

    from composer import Trainer
    from composer.utils import ObjectStoreProvider
    
    # Create the object store provider with the specified credentials
    creds = {"key": "object_store_key",
             "secret": "object_store_secret"}
    store = ObjectStoreProvider(provider="s3",
                                container="my_container",
                                provider_init_kwargs=creds)
    
    checkpoint_path = "/path_to_the_checkpoint_in_object_store"
    
    # Create a trainer which will load a checkpoint from the specified object store
    trainer = Trainer(model=model,
                      train_dataloader=train_dataloader,
                      max_duration="10ep",
                      eval_dataloader=eval_dataloader,
                      optimizers=optimizer,
                      schedulers=scheduler,
                      device="cpu",
                      validate_every_n_epochs=1,
                      load_path_format=checkpoint_path,
                      load_object_store=store)
    

  • load_weights_only (bool, optional) – Whether to restore only the model weights from the checkpoint, without restoring the associated training state. Ignored if load_path_format is None. (default: False)

  • load_strict (bool, optional) – Whether to require that the set of weights in the checkpoint exactly matches the set of weights in the model. Ignored if load_path_format is None. (default: False)

  • load_chunk_size (int, optional) – Chunk size (in bytes) to use when downloading checkpoints. Ignored if load_path_format is either None or a local file path. (default: 1,048,576)

  • load_progress_bar (bool, optional) – Display the progress bar for downloading the checkpoint. Ignored if load_path_format is either None or a local file path. (default: True)

  • save_folder (str, optional) –

    Folder where checkpoints are saved. If None, checkpoints will not be saved by default.

    See also

    CheckpointSaver

    Note

    For fine-grained control on checkpoint saving (e.g. to save different types of checkpoints at different intervals), leave this parameter as None, and instead pass instance(s) of CheckpointSaver directly as callbacks.

    (default: None)

  • save_name_format (str, optional) –

    A format string describing how to name checkpoints. This parameter has no effect if save_folder is None. (default: "ep{epoch}-ba{batch}-rank{rank}")

    See also

    CheckpointSaver

  • save_latest_format (str, optional) –

    A format string for the name of a symlink (relative to checkpoint_folder) that points to the last saved checkpoint. This parameter has no effect if save_folder is None. To disable symlinking, set to None. (default: "latest-rank{rank}")

    See also

    CheckpointSaver

  • save_overwrite (bool, optional) –

    Whether existing checkpoints should be overwritten. This parameter has no effect if save_folder is None. (default: False)

    See also

    CheckpointSaver

  • save_interval (Time | str | int | (State, Event) -> bool) –

    A Time, time-string, integer (in epochs), or a function that takes (state, event) and returns a boolean indicating whether a checkpoint should be saved. This parameter has no effect if save_folder is None. (default: '1ep')

    See also

    CheckpointSaver

  • save_weights_only (bool, optional) –

    Whether to save only the model weights instead of the entire training state. This parameter has no effect if save_folder is None. (default: False)

    See also

    CheckpointSaver

  • train_subset_num_batches (int, optional) – If specified, finish every epoch early after training on this many batches. This parameter has no effect if it is greater than len(train_dataloader). If None, then the entire dataloader will be iterated over. (default: None)

  • eval_subset_num_batches (int, optional) – If specified, evaluate on this many batches. This parameter has no effect if it is greater than len(eval_dataloader). If None, then the entire dataloader will be iterated over. (default: None)

  • deepspeed_config (bool or Dict[str, Any], optional) – Configuration for DeepSpeed, given as a dictionary that follows the JSON configuration format described in DeepSpeed's documentation. If True is provided, the trainer will initialize the DeepSpeed engine with an empty config {}. If False is provided, DeepSpeed will not be used. (default: False)

  • profiler_trace_file (str, optional) –

    Name of the trace file, relative to the run directory. Setting this parameter activates the profiler. (default: None).

    See also

    composer.profiler for more details on profiling with the trainer.

  • prof_event_handlers (List[ProfilerEventHandler], optional) – Trace event handlers. Ignored if profiler_trace_file is not specified. (default: [JSONTraceHandler()]).

  • prof_skip_first (int, optional) – Number of batches to skip at the start of each epoch. Ignored if profiler_trace_file is not specified. (default: 0).

  • prof_wait (int, optional) – Number of batches to skip at the beginning of each cycle. Ignored if profiler_trace_file is not specified. (default: 0).

  • prof_warmup (int, optional) – Number of warmup batches in a cycle. Ignored if profiler_trace_file is not specified. (default: 1).

  • prof_active (int, optional) – Number of batches to profile in a cycle. Ignored if profiler_trace_file is not specified. (default: 4).

  • prof_repeat (int, optional) – Maximum number of profiling cycle repetitions per epoch (0 for no maximum). Ignored if profiler_trace_file is not specified. (default: 1).

  • sys_prof_cpu (bool, optional) – Whether to record CPU statistics. Ignored if profiler_trace_file is not specified. (default: True).

  • sys_prof_memory (bool, optional) – Whether to record memory statistics. Ignored if profiler_trace_file is not specified. (default: False).

  • sys_prof_disk (bool, optional) – Whether to record disk statistics. Ignored if profiler_trace_file is not specified. (default: False).

  • sys_prof_net (bool, optional) – Whether to record network statistics. Ignored if profiler_trace_file is not specified. (default: False).

  • sys_prof_stats_thread_interval_seconds (float, optional) – Interval, in seconds, at which to record statistics. Ignored if profiler_trace_file is not specified. (default: 0.5).

  • torch_profiler_trace_dir (str, optional) – Directory, relative to the run directory, in which to store trace results. Must be specified to activate the Torch profiler. Ignored if profiler_trace_file is not specified. See profiler. (default: None).

  • torch_prof_use_gzip (bool, optional) – Whether to compress traces with gzip. Ignored unless both torch_profiler_trace_dir and profiler_trace_file are specified. (default: False).

  • torch_prof_record_shapes (bool, optional) – Whether to record tensor shapes. Ignored unless both torch_profiler_trace_dir and profiler_trace_file are specified. (default: False).

  • torch_prof_profile_memory (bool, optional) – Whether to track tensor memory allocations and frees. Ignored unless both torch_profiler_trace_dir and profiler_trace_file are specified. (default: True).

  • torch_prof_with_stack (bool, optional) – Whether to record stack information. Ignored unless both torch_profiler_trace_dir and profiler_trace_file are specified. (default: False).

  • torch_prof_with_flops (bool, optional) – Whether to estimate FLOPs for operators. Ignored unless both torch_profiler_trace_dir and profiler_trace_file are specified. (default: True).
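The prof_skip_first, prof_wait, prof_warmup, prof_active, and prof_repeat parameters above describe a cyclic profiling schedule within each epoch. Assuming it behaves like PyTorch's torch.profiler.schedule (skip the first batches, then repeat cycles of wait, warmup, and active batches), the hypothetical profiler_action helper below labels each batch of an epoch. It is a sketch of those semantics, not Composer's code.

```python
def profiler_action(batch_in_epoch: int, skip_first: int = 0, wait: int = 0,
                    warmup: int = 1, active: int = 4, repeat: int = 1) -> str:
    """Return "skip", "wait", "warmup", or "active" for a batch index within an epoch."""
    if batch_in_epoch < skip_first:
        return "skip"
    offset = batch_in_epoch - skip_first
    cycle_len = wait + warmup + active
    # repeat == 0 means cycles continue for the rest of the epoch.
    if repeat > 0 and offset // cycle_len >= repeat:
        return "skip"
    position = offset % cycle_len
    if position < wait:
        return "wait"
    if position < wait + warmup:
        return "warmup"
    return "active"


# With the trainer defaults (skip_first=0, wait=0, warmup=1, active=4, repeat=1):
print([profiler_action(b) for b in range(7)])
# ['warmup', 'active', 'active', 'active', 'active', 'skip', 'skip']
```

With these defaults, batch 0 warms up, batches 1 to 4 are recorded, and no further cycle runs in that epoch; passing repeat=0 would instead restart the cycle at batch 5.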

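The scale_schedule_ratio note in the parameter list above states that, without rescaling, a cosine schedule cut off at half duration ends at about 0.5 of the initial learning rate. The sketch below checks that arithmetic with the standard cosine annealing formula lr(t) = lr0 * (1 + cos(pi * t / T)) / 2, which is an assumption about the schedule shape; cosine_lr is an illustrative helper, not Composer's scheduler API.

```python
import math


def cosine_lr(step: int, schedule_steps: int, initial_lr: float) -> float:
    """Standard cosine annealing from initial_lr down to 0 over schedule_steps."""
    progress = min(step / schedule_steps, 1.0)
    return initial_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


initial_lr = 0.1
full_steps = 1000
ratio = 0.5  # plays the role of scale_schedule_ratio
train_steps = int(full_steps * ratio)

# Training for half as long without rescaling: the schedule is cut off midway.
print(cosine_lr(train_steps, full_steps, initial_lr) / initial_lr)
# Rescaling the schedule length by the same ratio completes the whole cosine curve.
print(cosine_lr(train_steps, int(full_steps * ratio), initial_lr) / initial_lr)
```

The first ratio is about 0.5 and the second is about 0.0, matching the behavior described for scale_schedule_ratio.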
state

The State object used to store training state.

Type

State

evaluators

The Evaluator objects to use for validation during training.

Type

List[Evaluator]

logger

The Logger used for logging.

Type

Logger

engine

The Engine used for running callbacks and algorithms.

Type

Engine

property checkpoint_folder

The folder in which checkpoints are stored.

Returns
Optional[str] – The checkpoint folder, or None if checkpoints were not saved.

If an absolute path was specified for save_folder upon trainer instantiation, then that path will be used. Otherwise, this folder is relative to the composer.utils.run_directory of the training run (e.g. {run_directory}/{save_folder}). If no run directory is provided, then by default, it is of the form runs/<timestamp>/rank_<GLOBAL_RANK>/<save_folder>, where timestamp is the start time of the run in ISO format, GLOBAL_RANK is the global rank of the process, and save_folder is the save_folder argument provided upon construction.
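The folder layout described above reduces to a small path computation. The resolve_checkpoint_folder helper below is hypothetical (it is not Composer's implementation). It assumes, consistent with the module-level example, that a custom base run directory such as the one set through COMPOSER_RUN_DIRECTORY also receives a rank_<GLOBAL_RANK> subfolder, and that the timestamp is formatted with datetime.isoformat().

```python
import os
from datetime import datetime
from typing import Optional


def resolve_checkpoint_folder(save_folder: str, global_rank: int,
                              base_run_directory: Optional[str] = None,
                              start_time: Optional[datetime] = None) -> str:
    """Sketch of where checkpoints land for a given save_folder and rank."""
    if os.path.isabs(save_folder):
        # Absolute save_folder paths are used as-is.
        return save_folder
    if base_run_directory is None:
        # Default base: runs/<timestamp>, with the run start time in ISO format.
        start_time = start_time or datetime.now()
        base_run_directory = os.path.join("runs", start_time.isoformat())
    # Relative save_folder paths live under the per-rank run directory.
    return os.path.join(base_run_directory, f"rank_{global_rank}", save_folder)


print(resolve_checkpoint_folder("checkpoints", 0, base_run_directory="my_run_directory"))
print(resolve_checkpoint_folder("checkpoints", 1, start_time=datetime(2022, 3, 1, 12, 0, 0)))
```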

property deepspeed_enabled

True if DeepSpeed is being used for training and False otherwise.

eval(is_batch)

Evaluate the model on the provided evaluation data and log appropriate metrics.

Parameters

is_batch (bool) – True to log metrics with LogLevel.BATCH and False to log metrics with LogLevel.EPOCH.

fit()

Train and evaluate the model on the provided data.

save_checkpoint(name_format='ep{epoch}-ba{batch}-rank{rank}', *, weights_only=False)

Checkpoint the training State.

Parameters
Returns

List[pathlib.Path] – See save_checkpoint().
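The default name_format ("ep{epoch}-ba{batch}-rank{rank}") and the rank variables accepted by load_path_format are placeholder strings. Assuming they are filled with Python's str.format semantics (this page does not state the mechanism), the hypothetical expand_name_format helper below previews the resulting names; the batch count 100 is made up for illustration.

```python
def expand_name_format(name_format: str, *, epoch: int = 0, batch: int = 0,
                       rank: int = 0, local_rank: int = 0, node_rank: int = 0) -> str:
    """Fill checkpoint format placeholders; str.format ignores unused keywords."""
    return name_format.format(epoch=epoch, batch=batch, rank=rank,
                              local_rank=local_rank, node_rank=node_rank)


# Default checkpoint name on two ranks after epoch 1 (batch count is illustrative).
for rank in range(2):
    print(expand_name_format("ep{epoch}-ba{batch}-rank{rank}", epoch=1, batch=100, rank=rank))

# A rank-sharded load path like the one in the load_path_format description.
print(expand_name_format("my_model/ep1-rank{rank}.tar", rank=2))
```

Because str.format ignores keyword arguments that a format string does not use, one helper covers save_name_format, save_latest_format, and load_path_format alike.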

property saved_checkpoints

The times and paths to checkpoint files saved across all ranks during training.

Returns

Dict[Timestamp, List[str]] – A dictionary mapping a save Timestamp to a list of filepaths, indexed by global rank, corresponding to the checkpoints saved at that time.

Note

When using DeepSpeed, the index of a filepath corresponds to the global rank of the process that wrote that file. These filepaths are valid only on the global rank's node. Otherwise, when not using DeepSpeed, this list will contain only one filepath since only rank zero saves checkpoints.