🎰 Stochastic Depth (Sample)

[How to Use] - [Suggested Hyperparameters] - [Technical Details] - [Attribution]

Computer Vision

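Sample-wise stochastic depth is a regularization technique for networks with residual connections. During training, each residual block randomly drops individual samples from the output of its transformation function (the residual branch), so that for a dropped sample only the skip connection passes through that block. As a result, different samples in the same batch go through different combinations of blocks.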
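To make the per-sample behavior concrete, here is a small PyTorch sketch (not part of Composer's API) that draws one independent keep/drop decision for every (block, sample) pair and prints which blocks each sample actually uses. The function name `sample_keep_masks`, the block and batch sizes, and the use of one uniform drop rate for every block are illustrative assumptions.

```python
import torch


def sample_keep_masks(num_blocks: int, batch_size: int, drop_rate: float,
                      seed: int = 0) -> torch.Tensor:
    """Draw one keep/drop decision per (block, sample) pair.

    Entry [b, i] is 1.0 if sample i keeps block b's transformation output
    and 0.0 if that residual branch is dropped (only the skip path remains).
    """
    gen = torch.Generator().manual_seed(seed)
    keep_prob = torch.full((num_blocks, batch_size), 1.0 - drop_rate)
    return torch.bernoulli(keep_prob, generator=gen)


masks = sample_keep_masks(num_blocks=6, batch_size=4, drop_rate=0.2)
for i in range(masks.shape[1]):
    # Indices of the blocks whose transformation sample i actually uses.
    active = masks[:, i].nonzero().flatten().tolist()
    print(f"sample {i}: active blocks {active}")
```

Because each column is drawn independently, two samples in the same batch will usually end up with different sets of active blocks.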

How to Use

Functional Interface

# Run the Stochastic Depth algorithm directly on the model using the Composer functional API

import composer.functional as cf
import torch
import torch.nn.functional as F
from torchvision.models import resnet50

# Training

# Stochastic depth can only be run on ResNet-50/101/152
model = resnet50()

opt = torch.optim.Adam(model.parameters())

# Pass `optimizers` only if the optimizer was created before the surgery, so that it
# is updated to optimize the new stochastic blocks; otherwise passing the model is enough
cf.apply_stochastic_depth(model,
                          target_layer_name='ResNetBottleneck',
                          stochastic_method='sample',
                          drop_rate=0.2,
                          drop_distribution='linear',
                          optimizers=opt)

loss_fn = F.cross_entropy
model.train()

# `train_loader` is assumed to be a DataLoader yielding (inputs, targets) batches
for epoch in range(10):
    for X, y in train_loader:
        y_hat = model(X)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        opt.zero_grad()

Composer Trainer

# Instantiate the algorithm and pass it into the Trainer
# The trainer will automatically run it at the appropriate point in the training loop

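from composer.algorithms import StochasticDepth
from composer.models import ComposerClassifier
from composer.trainer import Trainer
from torchvision.models import resnet50

# Train model

# Stochastic depth can only be run on ResNet-50/101/152
# The Trainer expects a ComposerModel, so wrap the torchvision network
model = ComposerClassifier(module=resnet50(), num_classes=1000)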

stochastic_depth = StochasticDepth(target_layer_name='ResNetBottleneck',
                                   stochastic_method='sample',
                                   drop_rate=0.2,
                                   drop_distribution='linear')

# `train_dataloader` is assumed to be a DataLoader yielding (inputs, targets) batches
trainer = Trainer(model=model,
                  train_dataloader=train_dataloader,
                  max_duration='10ep',
                  algorithms=[stochastic_depth])
trainer.fit()

Implementation Details

The Composer implementation of Stochastic Depth uses model surgery to replace residual bottleneck blocks with analogous stochastic versions. During training, samples are dropped after the transformation function in a residual block by multiplying the transformation output by a binary vector with one entry per sample. Each entry is drawn independently from a Bernoulli distribution with keep probability (1 - drop_rate). After the samples are dropped, the skip connection is added as usual. During inference, no samples are dropped; instead, the transformation output is scaled by (1 - drop_rate) so that its expected value matches what the block produced during training.
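The training and inference behavior described above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not Composer's stochastic bottleneck: `sample_drop` and `ToyStochasticResidual` are hypothetical names, and the toy block uses a single convolution plus batch norm as its transformation function instead of a full ResNet bottleneck.

```python
import torch
from torch import nn


def sample_drop(branch_out: torch.Tensor, drop_rate: float, training: bool) -> torch.Tensor:
    """Apply sample-wise stochastic depth to a residual branch output."""
    keep_prob = 1.0 - drop_rate
    if not training:
        # Inference: keep every sample but scale by the keep probability.
        return branch_out * keep_prob
    # One Bernoulli(keep_prob) draw per sample, broadcast over C, H, W.
    mask_shape = (branch_out.shape[0],) + (1,) * (branch_out.dim() - 1)
    mask = torch.bernoulli(torch.full(mask_shape, keep_prob,
                                      dtype=branch_out.dtype, device=branch_out.device))
    return branch_out * mask


class ToyStochasticResidual(nn.Module):
    """Hypothetical residual block: out = relu(x + sample_drop(f(x)))."""

    def __init__(self, channels: int, drop_rate: float):
        super().__init__()
        self.drop_rate = drop_rate
        # Transformation function (residual branch); a stand-in for a bottleneck.
        self.f = nn.Sequential(nn.Conv2d(channels, channels, 3, padding=1),
                               nn.BatchNorm2d(channels))
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = sample_drop(self.f(x), self.drop_rate, self.training)
        return self.relu(out + x)  # skip connection is added as usual
```

Calling `block.train()` or `block.eval()` switches between the two regimes, because the block passes `self.training` to `sample_drop`.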

Suggested Hyperparameters

We observed that drop_rate=0.1 and drop_distribution=linear yielded the largest accuracy improvements on both ResNet-50 and ResNet-101.
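With drop_distribution=linear, the drop rate is assumed here to grow with depth so that shallow blocks are rarely dropped and only the deepest block reaches drop_rate. The sketch below follows the linear decay rule of Huang et al. (2016), in which block l of L has drop rate (l / L) * drop_rate; Composer's exact indexing may differ, and `linear_drop_rates` is a hypothetical helper.

```python
from typing import List


def linear_drop_rates(num_blocks: int, drop_rate: float) -> List[float]:
    """Drop rate for block l (1-indexed) is (l / L) * drop_rate.

    Linear decay rule from Huang et al. (2016): the drop rate increases
    with depth and equals drop_rate at the last block.
    """
    return [drop_rate * l / num_blocks for l in range(1, num_blocks + 1)]


# ResNet-50 has 3 + 4 + 6 + 3 = 16 bottleneck blocks.
rates = linear_drop_rates(num_blocks=16, drop_rate=0.1)
# Expected number of residual branches skipped per sample in one forward pass.
print(round(sum(rates), 4))  # → 0.85
```

Summing the per-block rates gives the expected number of residual branches a sample skips per forward pass; for ResNet-101 (3 + 4 + 23 + 3 = 33 bottleneck blocks), pass num_blocks=33.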

Technical Details

For both ResNet-50 and ResNet-101 on ImageNet, we measured a +0.4% absolute accuracy improvement when using drop_rate=0.1 and drop_distribution=linear. Training wall-clock time is approximately 5% longer with sample-wise stochastic depth.

Attribution

Deep Networks with Stochastic Depth by Gao Huang, Yu Sun, Zhuang Liu, Daniel Sedra, and Kilian Weinberger. Published at ECCV 2016.

EfficientNet model in the TPU GitHub repository from Google.

EfficientNet model in the gen-efficientnet-pytorch GitHub repository by Ross Wightman.