composer.algorithms.cutout.cutout
Core CutOut classes and functions.
Functions
cutout_batch | See CutOut.
Classes
CutOut | Cutout is a data augmentation technique that works by masking out one or more square regions of an input image.
- class composer.algorithms.cutout.cutout.CutOut(num_holes=1, length=0.5, uniform_sampling=False)
Bases: composer.core.algorithm.Algorithm
Cutout is a data augmentation technique that works by masking out one or more square regions of an input image.
This implementation cuts out the same square from all images in a batch.
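To make "the same square for all images in a batch" concrete, here is a minimal sketch (not Composer's implementation) that builds a single 2-D mask and broadcasts it across a batch of shape (N, C, H, W). It uses NumPy arrays in place of torch tensors for brevity; the helper name `shared_square_mask` and the center-based, border-clipped placement of the box are assumptions made for illustration.

```python
import numpy as np


def shared_square_mask(height: int, width: int, side: int, cy: int, cx: int) -> np.ndarray:
    """Hypothetical helper: an (H, W) mask of ones with a side x side block of zeros
    centered at (cy, cx), clipped to the image bounds."""
    mask = np.ones((height, width), dtype=np.float32)
    half = side // 2
    # Clip the box corners so holes near the border are truncated, not wrapped.
    y1, y2 = max(cy - half, 0), min(cy + half, height)
    x1, x2 = max(cx - half, 0), min(cx + half, width)
    mask[y1:y2, x1:x2] = 0.0
    return mask


# A batch of 4 RGB images, 32x32 pixels, laid out as (N, C, H, W).
rng = np.random.default_rng(0)
batch = rng.random((4, 3, 32, 32), dtype=np.float32) + 0.1  # strictly positive pixels

# One mask for the whole batch: broadcasting (H, W) against (N, C, H, W)
# zeroes the same square in every channel of every image.
mask = shared_square_mask(32, 32, side=8, cy=10, cx=20)
cut = batch * mask
```

Every image in `cut` is zero in exactly the same 8x8 block (rows 6 to 13, columns 16 to 23). Drawing an independent mask per image would be the alternative design; the shared mask is cheaper because it is generated once per batch.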
Example
```python
from composer.algorithms import CutOut
from composer.trainer import Trainer

# Assumes model, train_dataloader, eval_dataloader, and optimizer are defined elsewhere.
cutout_algorithm = CutOut(num_holes=1, length=0.25)
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration="1ep",
    algorithms=[cutout_algorithm],
    optimizers=[optimizer],
)
```
- Parameters
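num_holes (int, optional) - Integer number of holes to cut out. Default: 1.
length (float, optional) - Relative side length of the masked region, interpreted as a fraction of the image height H and width W. The resulting box is a square with side length length * min(H, W). Must be in the open interval (0, 1). Default: 0.5.
uniform_sampling (bool, optional) - If True, sample the bounding box such that each pixel has an equal probability of being masked. If False, use the sampling from the original paper's implementation. Default: False.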
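The side-length rule is simple arithmetic on the shorter image side. The sketch below assumes the fractional result is truncated to a whole number of pixels with `int()`; the function name `hole_side_length` is hypothetical, and Composer's exact rounding may differ.

```python
def hole_side_length(height: int, width: int, length: float) -> int:
    """Hypothetical helper: pixel side length of a CutOut square.

    Assumes the relative `length` is applied to the shorter image side and
    truncated toward zero with int().
    """
    if not 0.0 < length < 1.0:
        raise ValueError(f"length must be in the open interval (0, 1), got {length}")
    return int(length * min(height, width))


print(hole_side_length(32, 32, 0.25))   # 8: a CIFAR-sized image with length=0.25
print(hole_side_length(224, 160, 0.5))  # 80: min(224, 160) = 160, and half of 160 is 80
```

Scaling by min(H, W) rather than by each dimension separately keeps the hole square and guarantees its side never exceeds either image dimension.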
- composer.algorithms.cutout.cutout.cutout_batch(input, num_holes=1, length=0.5, uniform_sampling=False)
See CutOut.
- Parameters
input (Image or Tensor) - Image or batch of images. If a torch.Tensor, must be a single image of shape (C, H, W) or a batch of images of shape (N, C, H, W).
num_holes (int, optional) - Integer number of holes to cut out. Default: 1.
length (float, optional) - Relative side length of the masked region, interpreted as a fraction of the image height H and width W. The resulting box is a square with side length length * min(H, W). Must be in the open interval (0, 1). Default: 0.5.
uniform_sampling (bool, optional) - If True, sample the bounding box such that each pixel has an equal probability of being masked. If False, use the sampling from the original paper's implementation. Default: False.
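The effect of uniform_sampling is easiest to see in one dimension. This sketch is a model, not Composer's code, and rests on three assumptions: a hole of width `side` centered at c covers pixels [c - side // 2, c + side // 2), clipped to the row; paper-style sampling draws c uniformly from [0, W); and uniform sampling widens that range to [-side // 2, W + side // 2). The function `mask_probabilities` (a hypothetical name) enumerates every center and returns the exact probability that each pixel is masked by one hole.

```python
from fractions import Fraction
from typing import List


def mask_probabilities(width: int, side: int, uniform_sampling: bool) -> List[Fraction]:
    """Exact probability that each pixel of a 1-D row is masked by one hole.

    Hypothetical model: the hole centered at c covers [c - side // 2, c + side // 2),
    clipped to [0, width); centers are drawn uniformly from an integer range.
    """
    half = side // 2
    # Paper-style: centers inside the image. Uniform: centers may fall off the edges.
    centers = range(-half, width + half) if uniform_sampling else range(width)
    counts = [0] * width
    for c in centers:
        for p in range(max(c - half, 0), min(c + half, width)):
            counts[p] += 1
    return [Fraction(n, len(centers)) for n in counts]


paper = mask_probabilities(width=8, side=4, uniform_sampling=False)
uniform = mask_probabilities(width=8, side=4, uniform_sampling=True)
print([str(p) for p in paper])    # edge pixels are masked less often than central ones
print([str(p) for p in uniform])  # every pixel has the same probability, 4/12 = 1/3
```

Letting centers fall up to side // 2 pixels outside the image is what gives border pixels as many covering centers as interior pixels; the trade-off is that some holes are partly clipped and mask fewer pixels.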
- Returns
X_cutout - Batch of images in which num_holes square regions, each with side length determined by length, are replaced with zeros.
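The returned batch can be reproduced end to end with a small multi-hole sketch. This is an illustration under assumptions, not Composer's `cutout_batch`: the name `cutout_sketch`, the NumPy (N, C, H, W) input, `int()` truncation of the side length, and paper-style center sampling are all assumed. All holes are accumulated into one shared mask, so they may overlap or be clipped at the border, and the number of zeroed pixels per channel is therefore at most num_holes * side**2.

```python
from typing import Optional

import numpy as np


def cutout_sketch(x: np.ndarray, num_holes: int = 1, length: float = 0.5,
                  rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Hypothetical CutOut for an (N, C, H, W) array; returns a new, masked array."""
    rng = np.random.default_rng() if rng is None else rng
    height, width = x.shape[-2], x.shape[-1]
    side = int(length * min(height, width))
    half = side // 2
    mask = np.ones((height, width), dtype=x.dtype)
    for _ in range(num_holes):
        # Paper-style sampling (assumed): the hole center lies inside the image.
        cy, cx = rng.integers(height), rng.integers(width)
        mask[max(cy - half, 0):min(cy + half, height),
             max(cx - half, 0):min(cx + half, width)] = 0
    # One mask shared by every image and channel, matching the batch-wide behavior above.
    return x * mask


rng = np.random.default_rng(42)
images = np.ones((2, 3, 32, 32), dtype=np.float32)
out = cutout_sketch(images, num_holes=3, length=0.25, rng=rng)
zeros_per_channel = int((out[0, 0] == 0).sum())
print(zeros_per_channel)  # at most 3 * 8 * 8 = 192; overlap and clipping can make it smaller
```

Multiplying by the mask, rather than writing zeros into `x` in place, leaves the caller's input batch unchanged.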
Example
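```python
import torch
from composer.algorithms.cutout import cutout_batch

X_example = torch.rand(8, 3, 32, 32)  # placeholder batch of shape (N, C, H, W)
new_input_batch = cutout_batch(X_example, num_holes=1, length=0.25)
```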