BlurPool

BlurPool Antialiasing

How various original ops (top row) are replaced with corresponding BlurPool ops (bottom row) in the original paper. In each case, a low-pass filter is applied before the spatial downsampling to avoid aliasing.

Tags: Vision, Decreased GPU Throughput, Increased Accuracy, Method, Regularization

TL;DR

Increases accuracy at nearly the same speed by applying a spatial low-pass filter before the downsampling step in max pooling layers and strided convolutions.

Attribution

Making Convolutional Networks Shift-Invariant Again by Richard Zhang (2019).

Implementation by Richard Zhang (GitHub)

Project Website by Richard Zhang

Code and Hyperparameters

Link to code in Mosaic

  • replace_convs - replace strided torch.nn.Conv2d modules within the model with anti-aliased versions

  • replace_maxpools - replace torch.nn.MaxPool2d modules with anti-aliased versions

  • blur_first - when replace_convs is True, blur input before the associated convolution. When set to False, the convolution is applied with a stride of 1 before the blurring, resulting in significant overhead (though more closely matching the paper).

Applicable Settings

Applicable whenever using a strided convolution or a local max pooling layer, which mainly occur in vision settings. It is currently implemented for the PyTorch operators MaxPool2d and Conv2d.

Example Effects

The original paper showed accuracy gains of around 0.5-1% on ImageNet for various networks. A subsequent paper demonstrated similar gains on ImageNet, as well as significant improvements in instance segmentation on MS COCO. The latter paper also showed improvements in semantic segmentation metrics on PASCAL VOC2012 and Cityscapes. Lee et al. have also reproduced the ImageNet accuracy gains, especially when applying BlurPool only to strided convolutions.

Implementation Details

For max pooling, we replace torch.nn.MaxPool2d instances with instances of a custom nn.Module subclass that decouples the computation of the max within each spatial window from the downsampling and adds a spatial low-pass filter in between. This change roughly doubles the data movement required for the op, although it shouldn’t add significant overhead unless there are many maxpools in the network.
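A minimal sketch of this max-pooling replacement, assuming the default 3x3 binomial filter and a zero padding of 1 (the defaults documented for blur_2d). MaxBlurPoolSketch is a hypothetical name; it is not Composer's BlurMaxPool2d, and it ignores padding, dilation, and ceil_mode for brevity.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MaxBlurPoolSketch(nn.Module):
    """Hypothetical anti-aliased max pool: dense max, low-pass filter, then downsample."""

    def __init__(self, kernel_size: int = 2, stride: int = 2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        # 3x3 binomial filter [1 2 1]^T [1 2 1] / 16, shape (1, 1, 3, 3)
        taps = torch.tensor([1.0, 2.0, 1.0])
        self.register_buffer("filt", (torch.outer(taps, taps) / 16.0)[None, None])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channels = x.shape[1]
        # 1) max over each window, evaluated densely (stride 1, no downsampling yet)
        maxes = F.max_pool2d(x, self.kernel_size, stride=1)
        # 2) low-pass filter each channel independently, 3) subsample by `stride`
        weight = self.filt.repeat(channels, 1, 1, 1)
        return F.conv2d(maxes, weight, stride=self.stride, padding=1, groups=channels)
```

On an 8x8 input with a 2x2 window and stride 2, the output has the same shape as nn.MaxPool2d(2, 2) would produce.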

For convolutions, we replace strided torch.nn.Conv2d instances with a custom module class that 1) applies a low-pass filter to the input, and 2) applies a copy of the original convolution operation. Depending on the value of the blur_first parameter, the low-pass filtering can happen either before or after the convolution. The former keeps the number of multiply-add operations in the convolution itself constant and only adds the overhead of the low-pass filtering. The latter runs the convolution with a stride of 1, which increases its number of multiply-add operations by a factor of np.prod(conv.stride) (e.g., 4 for a stride of (2, 2)); this more closely matches the approach used in the paper. Anecdotally, we’ve observed this version yielding a roughly 0.1% accuracy gain on ResNet-50 + ImageNet in exchange for a ~10% slowdown. The blur_first=False setting is less thoroughly characterized in our experiments than blur_first=True.
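The two orderings can be sketched as follows. This sketch assumes equal strides along H and W, no dilation or groups, and the default 3x3 binomial filter. BlurConvSketch and _blur are hypothetical names, not Composer's BlurConv2d. In particular, blurring at stride 1 before the original strided convolution is an assumption about how the blur_first=True path is arranged.

```python
import torch
import torch.nn.functional as F
from torch import nn


def _blur(x: torch.Tensor, stride: int) -> torch.Tensor:
    """Hypothetical helper: depthwise 3x3 binomial blur with zero padding of 1."""
    taps = torch.tensor([1.0, 2.0, 1.0], dtype=x.dtype, device=x.device)
    filt = (torch.outer(taps, taps) / 16.0).repeat(x.shape[1], 1, 1, 1)  # (C, 1, 3, 3)
    return F.conv2d(x, filt, stride=stride, padding=1, groups=x.shape[1])


class BlurConvSketch(nn.Module):
    """Hypothetical blurred replacement for a strided Conv2d (assumes equal H/W strides)."""

    def __init__(self, conv: nn.Conv2d, blur_first: bool = True):
        super().__init__()
        self.blur_first = blur_first
        self.stride = conv.stride[0]
        # copy of the original conv; it keeps its stride only when blurring first
        self.conv = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                              stride=conv.stride if blur_first else 1,
                              padding=conv.padding, bias=conv.bias is not None)
        self.conv.load_state_dict(conv.state_dict())  # reuse the original weights

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.blur_first:
            # blur at full resolution, then run the original strided conv
            return self.conv(_blur(x, stride=1))
        # dense (stride-1) conv, then blur and subsample: ~stride**2 more conv MACs
        return _blur(self.conv(x), stride=self.stride)
```

Both variants produce the same output shape as the original strided convolution; only the blur_first=False path evaluates the convolution densely.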

Our implementation deviates from the original paper in that we apply the low-pass filter and pooling before the nonlinearity instead of after. This is because we do not (yet) have a reliable, architecture-agnostic way of inserting them after the nonlinearity.
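To make this deviation concrete, the sketch below contrasts the paper's ordering for a strided convolution (dense convolution, ReLU, then blur-and-subsample) with the ordering described here (blur-and-subsample before the ReLU). blur_pool is a hypothetical helper that uses the default 3x3 binomial filter; the layer sizes are illustrative.

```python
import torch
import torch.nn.functional as F
from torch import nn


def blur_pool(x: torch.Tensor, stride: int = 2) -> torch.Tensor:
    """Hypothetical helper: depthwise [1 2 1]^T [1 2 1] / 16 blur, then subsample."""
    taps = torch.tensor([1.0, 2.0, 1.0])
    filt = (torch.outer(taps, taps) / 16.0).repeat(x.shape[1], 1, 1, 1)
    return F.conv2d(x, filt, stride=stride, padding=1, groups=x.shape[1])


torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # dense (stride-1) convolution
x = torch.randn(1, 3, 16, 16)

# Paper ordering: conv -> ReLU -> blur-and-subsample
paper_style = blur_pool(F.relu(conv(x)))
# Ordering used here: conv -> blur-and-subsample -> ReLU
ours_style = F.relu(blur_pool(conv(x)))
```

Because ReLU(a) is at least both a and 0, and the blur weights are nonnegative, blurring after the ReLU never yields a smaller value than blurring before it; wherever a filter window mixes positive and negative activations, the two orderings give different results.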

Suggested Hyperparameters

We weakly suggest setting replace_maxpools=True to match the configuration in the paper, since we haven’t observed a large benefit either way.

We suggest setting blur_first=True to avoid increased computational cost.

Considerations

This method can be understood in several ways:

  1. It improves the network’s invariance to small spatial shifts

  2. It reduces aliasing in the downsampling operations

  3. It adds a structural bias towards preserving low-spatial-frequency components of neural network activations

Consequently, it is likely to be useful on natural images, or other inputs that change slowly over their spatial/time dimension(s).
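A one-dimensional toy example illustrates points 1 and 2. It assumes a [1 2 1]/4 filter (the 1D counterpart of the default 2D filter) and wrap-around padding for simplicity; it is an illustration, not the library's implementation.

```python
def blur_1d(x):
    """[1 2 1] / 4 low-pass filter with circular (wrap-around) padding."""
    n = len(x)
    return [(x[(i - 1) % n] + 2 * x[i] + x[(i + 1) % n]) / 4 for i in range(n)]


def subsample(x, stride=2):
    """Naive downsampling: keep every `stride`-th sample."""
    return x[::stride]


signal = [0, 1] * 4                # highest-frequency pattern: 0 1 0 1 ...
shifted = signal[1:] + signal[:1]  # the same signal shifted by one sample

naive = (subsample(signal), subsample(shifted))
blurred = (subsample(blur_1d(signal)), subsample(blur_1d(shifted)))
print(naive)    # ([0, 0, 0, 0], [1, 1, 1, 1])
print(blurred)  # ([0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5, 0.5])
```

Naive subsampling of the alternating pattern flips from all zeros to all ones after a one-sample shift, while blurring first gives the same output for both. Real activations are rarely this extreme, but the mechanism is the same.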

Composability

BlurPool tends to compose well with other methods. We are not aware of an example of its effects changing significantly as a result of other methods being present.

Acknowledgments

We thank Richard Zhang for helpful discussion.

Code

class composer.algorithms.blurpool.BlurPool(replace_convs: bool, replace_maxpools: bool, blur_first: bool)

BlurPool adds anti-aliasing filters to convolutional layers to increase accuracy and invariance to small shifts in the input.

Runs on Event.INIT and should be applied both before the model has been moved to accelerators and before the model’s parameters have been passed to an optimizer.
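The optimizer half of this requirement can be illustrated with plain PyTorch. An optimizer keeps references to the Parameter objects it was given, so if a module swap creates new parameters after the optimizer is built, those parameters are never updated. The plain Conv2d below is only a stand-in for an anti-aliased replacement.

```python
import torch
from torch import nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, stride=2), nn.ReLU())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Swap in a replacement module *after* the optimizer was built.
# (A plain Conv2d stands in for an anti-aliased replacement here.)
model[0] = nn.Conv2d(3, 8, kernel_size=3, stride=2)

tracked = {id(p) for group in optimizer.param_groups for p in group["params"]}
current = {id(p) for p in model.parameters()}
print(tracked & current)  # set(): the optimizer would never update the new weights
```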

Parameters
  • replace_convs – replace strided torch.nn.Conv2d modules with BlurConv2d modules

  • replace_maxpools – replace eligible torch.nn.MaxPool2d modules with BlurMaxPool2d modules.

  • blur_first – when replace_convs is True, blur input before the associated convolution. When set to False, the convolution is applied with a stride of 1 before the blurring, resulting in significant overhead (though more closely matching the paper). See BlurConv2d for further discussion.

apply(event: composer.core.event.Event, state: composer.core.state.State, logger: composer.core.logging.logger.Logger) -> Optional[int]

Adds anti-aliasing filters to the maxpools and/or convolutions

Parameters
  • event (Event) – the current event

  • state (State) – the current trainer state

  • logger (Logger) – the training logger

match(event: composer.core.event.Event, state: composer.core.state.State) -> bool

Runs on Event.INIT

Parameters
  • event (Event) – The current event.

  • state (State) – The current state.

Returns

bool – True if this algorithm should run now.

class composer.algorithms.blurpool.BlurPoolHparams(replace_convs: bool = True, replace_maxpools: bool = True, blur_first: bool = True)

See BlurPool

class composer.algorithms.blurpool.BlurConv2d(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, blur_first: bool = True)

This module is a drop-in replacement for PyTorch’s Conv2d, but with an anti-aliasing filter applied.

The one new parameter is blur_first. When set to True, the anti-aliasing filter is applied before the underlying convolution; when set to False, it is applied after. This mostly makes a difference when the stride is greater than one. In the former case, the only overhead is the cost of the anti-aliasing operation. In the latter case, the Conv2d is applied to the input with a stride of one, and the anti-aliasing filter is then applied to the result with the provided stride. Setting the stride of the convolution to 1 can greatly increase the computational cost: for example, replacing a stride of (2, 2) with a stride of 1 increases the number of operations by a factor of (2/1) * (2/1) = 4. However, this approach most closely matches the behavior specified in the paper.
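The factor of 4 follows from counting multiply-accumulates (MACs). conv2d_macs is a hypothetical helper that assumes a square input and kernel, groups=1, and dilation=1; the layer sizes are illustrative.

```python
def conv2d_macs(in_ch, out_ch, kernel, in_hw, stride, padding):
    """MACs of a dense Conv2d with a square kernel and square input (no groups/dilation)."""
    out_hw = (in_hw + 2 * padding - kernel) // stride + 1
    # each output position of each output channel sums in_ch * kernel * kernel products
    return out_hw * out_hw * out_ch * in_ch * kernel * kernel


# e.g. a 3x3, 64 -> 128 channel conv on a 56x56 input with padding 1
strided = conv2d_macs(64, 128, 3, 56, stride=2, padding=1)  # 28x28 output
dense = conv2d_macs(64, 128, 3, 56, stride=1, padding=1)    # 56x56 output
print(dense / strided)  # 4.0
```

The ratio equals the product of the strides exactly only when the output size divides evenly; otherwise it is approximate.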

This module should only be used to replace strided convolutions.

See the associated paper for more details, experimental results, etc.

See also: blur_2d().

class composer.algorithms.blurpool.BlurMaxPool2d(kernel_size: Union[int, Tuple[int, int]], stride: Optional[Union[int, Tuple[int, int]]] = None, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, ceil_mode: bool = False)

This module is a (nearly) drop-in replacement for PyTorch’s MaxPool2d, but with an anti-aliasing filter applied.

The only API difference is that the parameter return_indices is not available, because it is ill-defined when using anti-aliasing.

See the associated paper for more details, experimental results, etc.

See also: blur_2d().

class composer.algorithms.blurpool.BlurPool2d(stride: Union[int, Tuple[int, int]] = 2, padding: Union[int, Tuple[int, int]] = 1)

This module just calls blur_2d() in forward using the provided arguments.

composer.algorithms.blurpool.blur_2d(input: torch.Tensor, stride: Union[int, Tuple[int, int]] = 1, filter: Optional[torch.Tensor] = None) -> torch.Tensor

Apply a spatial low-pass filter.

Parameters
  • input – a 4d tensor of shape NCHW

  • stride – stride(s) along H and W axes. If a single value is passed, this value is used for both dimensions.

  • padding – implicit zero-padding to use. For the default 3x3 low-pass filter, padding=1 (the default) returns output of the same size as the input.

  • filter – a 2d or 4d tensor to be cross-correlated with the input tensor at each spatial position, within each channel. If 4d, the structure is required to be (C, 1, kH, kW) where C is the number of channels in the input tensor and kH and kW are the spatial sizes of the filter.

By default, the filter used is:

$$\frac{1}{16}\begin{bmatrix} 1 & 2 & 1 \\ 2 & 4 & 2 \\ 1 & 2 & 1 \end{bmatrix}$$

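A rough re-implementation of the documented behavior, for illustration only. blur_2d_sketch and default_filter are hypothetical names; the sketch derives "same" padding from the filter size instead of taking a padding argument, and it accepts the 2d or (C, 1, kH, kW) filter shapes described for the filter parameter. The default filter is the outer product of [1 2 1]/4 with itself.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def default_filter() -> torch.Tensor:
    """3x3 binomial low-pass filter: outer([1, 2, 1], [1, 2, 1]) / 16."""
    taps = torch.tensor([1.0, 2.0, 1.0])
    return torch.outer(taps, taps) / 16.0


def blur_2d_sketch(x: torch.Tensor, stride: int = 1,
                   filt: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Depthwise low-pass filter of an NCHW tensor (hypothetical re-implementation)."""
    channels = x.shape[1]
    if filt is None:
        filt = default_filter()
    if filt.dim() == 2:
        # give every channel its own copy of the (kH, kW) filter: (C, 1, kH, kW)
        filt = filt.repeat(channels, 1, 1, 1)
    filt = filt.to(dtype=x.dtype, device=x.device)
    pad = (filt.shape[-2] // 2, filt.shape[-1] // 2)  # "same" padding for odd kernels
    # groups=C: each channel is cross-correlated only with its own filter
    return F.conv2d(x, filt, stride=stride, padding=pad, groups=channels)
```

On an all-ones input, interior pixels stay at 1 because the filter sums to 1, while a corner pixel drops to 9/16 because of the zero padding.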
composer.algorithms.blurpool.apply_blurpool(model: torch.nn.modules.module.Module, replace_convs: bool = True, replace_maxpools: bool = True, blur_first: bool = True) -> None

Add anti-aliasing filters to the strided torch.nn.Conv2d and/or torch.nn.MaxPool2d modules within model.

Must be run before the model has been moved to accelerators and before the model’s parameters have been passed to an optimizer.

Parameters
  • model – model to modify

  • replace_convs – replace strided torch.nn.Conv2d modules with BlurConv2d modules

  • replace_maxpools – replace eligible torch.nn.MaxPool2d modules with BlurMaxPool2d modules.

  • blur_first – for replace_convs, blur input before the associated convolution. When set to False, the convolution is applied with a stride of 1 before the blurring, resulting in significant overhead (though more closely matching the paper). See BlurConv2d for further discussion.
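As a rough illustration of what this kind of model surgery involves, the sketch below recursively wraps every strided Conv2d in a module that blurs at full resolution before running the original convolution, loosely mirroring blur_first=True. BlurThenConv and replace_strided_convs are hypothetical names and do not reproduce Composer's code; max pools are left untouched.

```python
import torch
import torch.nn.functional as F
from torch import nn


class BlurThenConv(nn.Module):
    """Hypothetical wrapper: 3x3 binomial blur at full resolution, then the original conv."""

    def __init__(self, conv: nn.Conv2d):
        super().__init__()
        self.conv = conv  # reuses the existing module and its Parameter objects
        taps = torch.tensor([1.0, 2.0, 1.0])
        filt = (torch.outer(taps, taps) / 16.0).repeat(conv.in_channels, 1, 1, 1)
        self.register_buffer("filt", filt)  # created on the CPU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        blurred = F.conv2d(x, self.filt, padding=1, groups=x.shape[1])
        return self.conv(blurred)


def replace_strided_convs(module: nn.Module) -> int:
    """Recursively wrap every Conv2d with stride > 1; returns how many were wrapped."""
    count = 0
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Conv2d) and max(child.stride) > 1:
            setattr(module, name, BlurThenConv(child))
            count += 1
        else:
            count += replace_strided_convs(child)
    return count


model = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=2, padding=1),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(16, 16, 3, padding=1), nn.Conv2d(16, 32, 3, stride=2, padding=1)),
)
print(replace_strided_convs(model))  # 2
```

The wrapper keeps the original convolution's parameters, but its filter buffer is created on the CPU, so running this surgery after moving the model to a GPU would leave the filter on the wrong device; that is one practical reason to perform the replacement first.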