🫀 Squeeze-and-Excitation#

Tags: ConvNets, Decreased GPU Throughput, Increased Accuracy, Method, Capacity

TL;DR#

Adds a channel-wise attention operator to CNNs. Attention coefficients are produced by a small, trainable MLP that uses the channels' globally pooled activations as input.

[Figure: diagram of the Squeeze-and-Excitation block]

After an activation tensor \(\mathbf{X}\) is passed through a Conv2d \(\mathbf{F}_{tr}\) to yield a new tensor \(\mathbf{U}\), a Squeeze-and-Excitation (SE) module scales the channels of \(\mathbf{U}\) in a data-dependent manner. The scales are produced by a single-hidden-layer fully-connected network whose input is the globally average-pooled \(\mathbf{U}\). This can be seen as a channel-wise attention mechanism.
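Concretely, for \(\mathbf{U} \in \mathbb{R}^{C \times H \times W}\) with channels \(\mathbf{u}_c\), reduction ratio \(r\), ReLU \(\delta\), and sigmoid \(\sigma\), the squeeze, excitation, and rescaling steps from the paper can be summarized as:

\[
z_c = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j), \qquad
\mathbf{s} = \sigma\left(\mathbf{W}_2\,\delta(\mathbf{W}_1 \mathbf{z})\right), \qquad
\tilde{\mathbf{u}}_c = s_c\,\mathbf{u}_c,
\]

where \(\mathbf{W}_1 \in \mathbb{R}^{\frac{C}{r} \times C}\) and \(\mathbf{W}_2 \in \mathbb{R}^{C \times \frac{C}{r}}\) are the MLP's weights. In this implementation, the hidden width \(\tfrac{C}{r}\) is set directly by the latent_channels hyperparameter rather than via \(r\).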

Attribution#

Squeeze-and-Excitation Networks by Jie Hu, Li Shen, and Gang Sun (2018).

Code and Hyperparameters#

  • latent_channels - Number of channels to use in the hidden layer of the MLP that computes channel attention coefficients.

  • min_channels - The minimum number of output channels in a Conv2d for an SE module to be added afterward.

Applicable Settings#

Applicable to convolutional neural networks. Currently only implemented for CNNs with 2d inputs (e.g., images).

Example Effects#

A 0.5–1.5% gain in accuracy, at the cost of a roughly 25% reduction in throughput. E.g., we've seen accuracy improve from 76.1% to 77.2% on ImageNet with ResNet-50, in exchange for a training throughput decrease from 4500 samples/sec to 3500 samples/sec on eight RTX 3080 GPUs.

Implementation Details#

Squeeze-Excitation blocks apply channel-wise attention to an activation tensor \(\mathbf{X}\). The attention coefficients are produced by a single-hidden-layer MLP (i.e., fully-connected network). This network takes the result of globally average-pooling \(\mathbf{X}\) as its input vector. In short, the average activations within each channel are used to produce scalar multipliers for the channels.

In order to be architecture-agnostic, our implementation applies the SE attention mechanism after individual Conv2d modules, rather than at particular points in particular networks. This results in more SE modules being present than in the original paper.

Our implementation also allows applying the SE module after only certain Conv2d modules, based on their channel count (see the hyperparameter discussion below).
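As a rough sketch of what this surgery looks like (illustrative only, not Composer's actual implementation; the names SEBlock and add_se_after_convs are hypothetical), one could append an SE block to the output of each qualifying Conv2d:

import torch
import torch.nn as nn

class SEBlock(nn.Module):
    """Minimal Squeeze-and-Excitation block: global pool -> MLP -> sigmoid -> rescale."""

    def __init__(self, num_channels, latent_channels=64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: global average pooling per channel
        self.mlp = nn.Sequential(            # excitation: single-hidden-layer MLP
            nn.Linear(num_channels, latent_channels),
            nn.ReLU(inplace=True),
            nn.Linear(latent_channels, num_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        n, c, _, _ = x.shape
        scales = self.mlp(self.pool(x).view(n, c))  # one coefficient per channel
        return x * scales.view(n, c, 1, 1)          # channel-wise rescaling

def add_se_after_convs(module, latent_channels=64, min_channels=128):
    """Recursively wrap each qualifying Conv2d so an SE block runs on its output."""
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and min(child.in_channels, child.out_channels) >= min_channels:
            setattr(module, name, nn.Sequential(child, SEBlock(child.out_channels, latent_channels)))
        else:
            add_se_after_convs(child, latent_channels, min_channels)
    return module

In practice, the SqueezeExcite algorithm and apply_squeeze_excite function documented below perform this replacement for you, including keeping any existing optimizers bound to the correct parameters.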

Suggested Hyperparameters#

  • latent_channels - 64 yielded the best speed-accuracy tradeoffs in our ResNet experiments. The original paper expressed this as a "reduction ratio" \(r\) that makes the MLP latent channel count a fraction of the SE block's input channel count. We also support specifying latent_channels as a fraction of the input channel count, although we've found that it tends to yield a worse speed vs. accuracy tradeoff.

  • min_channels - For typical CNNs, which have lower channel counts at higher resolutions, this can be used to control where in the network SE blocks begin to be applied. Because ops with higher channel counts take longer to compute relative to the SE block itself, SE adds proportionally less overhead later in the network. An appropriate value is architecture-dependent, but we weakly suggest setting this to 128 if the architecture in question has modules with at least this many channels. A usage sketch with both hyperparameters follows this list.
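As a usage sketch with these suggested values (assuming a standard Composer Trainer setup; model and train_dataloader are placeholders for your own ComposerModel and dataloader):

from composer import Trainer
from composer.algorithms import SqueezeExcite

trainer = Trainer(
    model=model,                        # placeholder: your ComposerModel
    train_dataloader=train_dataloader,  # placeholder: your training dataloader
    max_duration='90ep',
    algorithms=[SqueezeExcite(latent_channels=64, min_channels=128)],
)
trainer.fit()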

Considerations#

This method tends to consistently improve the accuracy of CNNs both in absolute terms and when controlling for training and inference time. This may come at the cost of a roughly 20% increase in inference latency, depending on the architecture and inference hardware.

Composability#

Because SE modules slow down the model, they compose well with methods that make the data loader slower (e.g., RandAugment) or that speed up each training step (e.g., Selective Backprop). In the former case, the slower model allows more time for the data loader to run. In the latter case, the initial slowdown allows techniques that accelerate the forward and backward passes to have a greater effect before they become limited by the data loader's speed.


Code#

class composer.algorithms.squeeze_excite.SqueezeExcite(latent_channels=64, min_channels=128)[source]

Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after the Conv2d modules in a neural network.

Runs on INIT. See SqueezeExcite2d for more information.

Parameters
  • latent_channels – Dimensionality of the hidden layer within the added MLP. If less than 1, interpreted as a fraction of the number of output channels in the Conv2d immediately preceding each Squeeze-and-Excitation block.

  • min_channels – An SE block is added after a Conv2d module conv only if min(conv.in_channels, conv.out_channels) >= min_channels. For models that reduce spatial size and increase channel count deeper in the network, this parameter can be used to only add SE blocks deeper in the network. This may be desirable because SE blocks add less overhead when their inputs have smaller spatial size. (A short snippet demonstrating this criterion follows.)
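For example, the criterion above can be inspected directly before applying the algorithm (a small sketch; resnet50 is just an arbitrary torchvision model here):

import torch.nn as nn
from torchvision import models

model = models.resnet50()
qualifying = [name for name, m in model.named_modules()
              if isinstance(m, nn.Conv2d) and min(m.in_channels, m.out_channels) >= 128]
print(f"{len(qualifying)} Conv2d modules would get an SE block with min_channels=128")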

apply(event, state, logger)[source]

Apply the Squeeze-and-Excitation layer replacement.

Parameters
  • event (Event) – The current event.

  • state (State) – The current trainer state.

  • logger (Logger) – The training logger.

match(event, state)[source]

Runs on INIT.

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.

class composer.algorithms.squeeze_excite.SqueezeExcite2d(num_features, latent_channels=0.125)[source]

Squeeze-and-Excitation block from Hu et al., 2019.

This block applies global average pooling to the input, feeds the resulting vector to a single-hidden-layer fully-connected network (MLP), and uses the output of this MLP as attention coefficients to rescale the input. This allows the network to take into account global information about each input, as opposed to only local receptive fields like in a convolutional layer.

Parameters
  • num_features (int) – Number of features or channels in the input.

  • latent_channels (float, optional) – Dimensionality of the hidden layer within the added MLP. If less than 1, interpreted as a fraction of num_features.
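For example (a minimal sketch; the tensor shape is arbitrary), the block can be applied to an activation tensor on its own:

import torch
from composer.algorithms.squeeze_excite import SqueezeExcite2d

se = SqueezeExcite2d(num_features=256, latent_channels=0.125)  # hidden width of 256 * 0.125 = 32
x = torch.randn(8, 256, 14, 14)  # (batch, channels, height, width)
y = se(x)                        # same shape as x; channels rescaled by attention coefficients
assert y.shape == x.shape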

class composer.algorithms.squeeze_excite.SqueezeExciteConv2d(*args, latent_channels=0.125, conv=None, **kwargs)[source]

Helper class used to add a SqueezeExcite2d module after a Conv2d.

composer.algorithms.squeeze_excite.apply_squeeze_excite(model, latent_channels=64, min_channels=128, optimizers=None)[source]

Adds Squeeze-and-Excitation blocks (Hu et al., 2019) after Conv2d layers.

A Squeeze-and-Excitation block applies global average pooling to the input, feeds the resulting vector to a single-hidden-layer fully-connected network (MLP), and uses the output of this MLP as attention coefficients to rescale the input. This allows the network to take into account global information about each input, as opposed to only local receptive fields like in a convolutional layer.

Parameters
  • model (torch.nn.Module) – The model to apply SE blocks to.

  • latent_channels (float, optional) – Dimensionality of the hidden layer within the added MLP. If less than 1, interpreted as a fraction of the number of output channels in the Conv2d immediately preceding each Squeeze-and-Excitation block.

  • min_channels (int, optional) – An SE block is added after a Conv2d module conv only if min(conv.in_channels, conv.out_channels) >= min_channels.

  • optimizers (Optimizers, optional) –

    Existing optimizers bound to model.parameters(). All optimizers that have already been constructed with model.parameters() must be specified here so they will optimize the correct parameters.

    If the optimizer(s) are constructed after calling this function, then it is safe to omit this parameter. These optimizers will see the correct model parameters.

Returns

The modified model

Example

import composer.functional as cf
from torchvision import models

model = models.resnet50()
# Add SE blocks after qualifying Conv2d modules.
cf.apply_squeeze_excite(model, latent_channels=64, min_channels=128)