composer.algorithms.functional.selective_backprop
- composer.algorithms.functional.selective_backprop(X: torch.Tensor, y: torch.Tensor, model: torch.nn.modules.module.Module, loss_fun: Callable, keep: float, scale_factor: float = 1) → Tuple[torch.Tensor, torch.Tensor] [source]
Select a subset of the batch on which to learn, as per Jiang et al. (2019).
Selective Backprop (SB) prunes minibatches according to the difficulty of the individual training examples and only computes weight gradients over the selected subset. This reduces iteration time and speeds up training. The fraction of the minibatch that is kept for gradient computation is specified by the argument keep, where 0 <= keep <= 1. To speed up SB's selection forward pass, the argument scale_factor can be used to spatially downsample input tensors. The full-sized inputs will still be used for the weight gradient computation. A conceptual sketch of the selection step follows the parameter list below.
- Parameters
X – Input tensor to prune
y – Output tensor to prune
model – Model with which to predict outputs
loss_fun – Loss function of the form loss(outputs, targets, reduction='none'). The function must take the keyword argument reduction='none' to ensure that per-sample losses are returned.
keep – Fraction of examples in the batch to keep
scale_factor – Multiplier between 0 and 1 for spatial size. Downsampling requires the input tensor to be at least 3D.
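Conceptually, the selection step scores each example with a cheap (optionally downsampled) forward pass and keeps the hardest fraction of the batch. The following is a minimal sketch of that idea, not Composer's exact implementation; in particular, the paper samples examples probabilistically by loss percentile, whereas this sketch keeps the top-loss fraction deterministically, and the helper name select_hard_examples is hypothetical:

import torch
import torch.nn.functional as F

def select_hard_examples(X, y, model, loss_fun, keep, scale_factor=1.0):
    # Illustrative sketch of SB's selection step (not Composer's exact code).
    with torch.no_grad():
        X_scored = X
        if scale_factor != 1.0 and X.dim() >= 3:
            # Downsample spatial dims to make the scoring forward pass cheaper.
            X_scored = F.interpolate(X, scale_factor=scale_factor, mode='nearest')
        # Per-sample losses require reduction='none', as documented above.
        per_sample_loss = loss_fun(model(X_scored), y, reduction='none')
        n_keep = int(keep * X.shape[0])
        idx = torch.argsort(per_sample_loss, descending=True)[:n_keep]
    # The full-resolution inputs are returned for the real gradient computation.
    return X[idx], y[idx]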
- Returns
(torch.Tensor, torch.Tensor) – The pruned batch of inputs and targets
- Raises
ValueError – If scale_factor > 1
TypeError – If loss_fun has the wrong signature or is not callable
Note: This function runs an extra forward pass through the model on the batch of data. If you are using a non-default precision, ensure that this forward pass runs in your desired precision. For example:
with torch.cuda.amp.autocast(True):
    X_new, y_new = selective_backprop(X, y, model, loss_fun, keep, scale_factor)
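For context, here is a minimal training-loop sketch showing where the call fits; model, optimizer, and dataloader are hypothetical placeholders for a standard PyTorch classification setup, and the keep and scale_factor values are assumptions chosen for illustration:

import torch
import torch.nn.functional as F
from composer.algorithms.functional import selective_backprop

keep, scale_factor = 0.5, 0.5  # assumed hyperparameters for illustration

for X, y in dataloader:
    # Prune the batch to its hardest examples before the gradient pass.
    # F.cross_entropy accepts reduction='none', so it satisfies the
    # loss_fun requirement described above.
    X, y = selective_backprop(X, y, model, F.cross_entropy, keep, scale_factor)
    optimizer.zero_grad()
    loss = F.cross_entropy(model(X), y)
    loss.backward()
    optimizer.step()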