Tip

This tutorial is available as a Jupyter notebook.

🥡 Exporting for Inference

Composer models are also torch.nn.Module, and thus can be exported like any other PyTorch module. In this tutorial, we walk through how to export your models into various common formats: ONNX, TorchScript, and torch.fx. For more detailed options and configuration settings, please consult the linked documentation.

Algorithm compatibility

Some of our algorithms alter the model architecture in ways that may render them incompatible with some of the export procedures below. For example, BlurPool replaces some instances of Conv2d with BlurConv2d layers, which are not compatible with torch.fx because they have data-dependent control flow.

The following table shows which algorithms are compatible with which export formats for inference.

| Algorithm              | torchscript | torch.fx | ONNX |
|------------------------|-------------|----------|------|
| apply_blurpool         | ✓           |          | ✓    |
| apply_factorization    |             | ✓        | ✓    |
| apply_ghost_batchnorm  | ✓           |          | ✓    |
| apply_squeeze_excite   | ✓           | ✓        | ✓    |
| apply_stochastic_depth | ✓           | ✓        | ✓    |
| apply_channels_last    | ✓           | ✓        | ✓    |

Prerequisites

First, we install Composer:

[ ]:
%pip install mosaicml

Create the model

Next, we create the model we'd like to export: a ResNet-50 from torchvision with our SqueezeExcite algorithm applied, which adds SqueezeExcite modules after certain Conv2d layers.

[ ]:
from torchvision.models import resnet
from composer.models import ComposerClassifier
import composer.functional as cf

model = ComposerClassifier(module=resnet.resnet50())
model = cf.apply_squeeze_excite(model)

# switch to eval mode
model.eval()

Printing the model shows the SqueezeExcite modules inserted into the Bottleneck layers:

[ ]:
print(model)
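Alternatively, rather than scanning the full printout, you can count the inserted modules directly. A small sketch; it assumes the SqueezeExcite2d class is importable from composer.algorithms.squeeze_excite:

[ ]:
from composer.algorithms.squeeze_excite import SqueezeExcite2d

# Count how many SqueezeExcite modules the surgery inserted.
num_se = sum(1 for m in model.modules() if isinstance(m, SqueezeExcite2d))
print(f"Model now contains {num_se} SqueezeExcite2d modules")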

ONNX

ONNX is a popular model format that can be consumed by many third-party tools (e.g. TensorRT, OpenVINO) to optimize the model for specific hardware devices.

Note: ONNX does not have a prebuilt wheel for Mac M1/M2 chips yet, so it is not pip-installable there. Skip this section if you are running on a Mac laptop.

[ ]:
%pip install onnx
%pip install onnxruntime

The ComposerClassifier's forward method takes as input a pair of tensors (input, label), so we create dummy tensors of the appropriate shapes for the ONNX export:

[ ]:
import torch

input = (torch.rand(4, 3, 112, 112), torch.Tensor())

Then we are ready to run the export with:

[ ]:
torch.onnx.export(
    model=model,
    args=(input,),
    f='model.onnx',
    input_names=['input'],
    output_names=['output'],
)
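The export above bakes in the batch size of the dummy input. If you expect variable batch sizes at inference time, torch.onnx.export accepts a dynamic_axes mapping; a sketch (written to a separate file name here for illustration):

[ ]:
# Mark axis 0 (the batch dimension) as dynamic so the exported graph
# accepts any batch size, not just the dummy input's batch size of 4.
torch.onnx.export(
    model=model,
    args=(input,),
    f='model_dynamic.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)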

Let's load the model and check that everything was exported properly.

[ ]:
import onnx

onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)

Lastly, we can run inference with the exported model and check that it indeed runs.

[ ]:
import onnxruntime as ort
import numpy as np

# run inference
ort_session = ort.InferenceSession('model.onnx')
outputs = ort_session.run(
    None,
    {'input': input[0].numpy()},
)

print(f"The predicted classes are {np.argmax(outputs[0], axis=1)}")

Note: As the model is randomly initialized, and the input tensor is random, the output classes in this example have no meaning.
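Beyond eyeballing the predicted classes, a common sanity check is to compare the ONNX Runtime logits numerically against the eager PyTorch logits. A sketch; the tolerances below are typical starting points and may need loosening:

[ ]:
import numpy as np

# Recompute the logits in eager PyTorch and compare them to the
# ONNX Runtime outputs element-wise.
with torch.no_grad():
    torch_logits = model(input)

np.testing.assert_allclose(torch_logits.numpy(), outputs[0], rtol=1e-3, atol=1e-5)
print("ONNX Runtime and PyTorch outputs match within tolerance")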

TorchScript

TorchScript creates serializable models from PyTorch code that can be saved and optimized for deployment, and the tooling is native to PyTorch. Below, we create a dummy input and compile the model with torch.jit.script:

[ ]:
import torch
import numpy as np

input = (torch.rand(4, 3, 112, 112), torch.Tensor())

scripted_model = torch.jit.script(model)
scripted_model.eval()

We can then run inference by passing in our input:

[ ]:
output = scripted_model(input)
print(f"The predicted classes are {torch.argmax(output, dim=1)}")

The compiled model can also be saved using torch.jit:

[ ]:
torch.jit.save(scripted_model, 'scripted_model.pt')
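The saved artifact can be loaded back with torch.jit.load, without needing the Composer or torchvision source at load time, and used for inference as before:

[ ]:
# Reload the compiled model and run the same inference as above.
loaded_model = torch.jit.load('scripted_model.pt')
loaded_model.eval()

output = loaded_model(input)
print(f"The predicted classes are {torch.argmax(output, dim=1)}")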

torch.fx

FX is a recent toolkit for transforming PyTorch modules that enables advanced graph manipulation and code generation. PyTorch is adding FX-based quantization (e.g. see FX Graph Mode Quantization) and other optimization passes, and Composer is also starting to add algorithms that use torch.fx for graph optimization, so look forward to more of these in the future!

Tracing a model with torch.fx is fairly straightforward:

[ ]:
traced_model = torch.fx.symbolic_trace(model)

Then, we can see all the nodes in the graph:

[ ]:
traced_model.graph.print_tabular()
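print_tabular gives a full dump; for programmatic inspection you can also walk the graph's nodes directly. For example, to list the first few operations:

[ ]:
# Iterate over the traced graph and print the first few nodes.
for node in list(traced_model.graph.nodes)[:8]:
    print(node.op, node.target)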

And also run inference:

[ ]:
output = traced_model(input)
print(f"The predicted classes are {torch.argmax(output, dim=1)}")

torch.fx is powerful, but one of its key limitations is that it does not support dynamic control flow (e.g. if statements or loops that are data-dependent). Therefore, some algorithms, such as BlurPool, are currently not supported. We have ongoing work to bring torch.fx support to all our algorithms.
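To see this limitation in isolation, here is a toy module whose forward branches on a tensor value: symbolic tracing fails, while scripting succeeds. A minimal sketch:

[ ]:
import torch

class DataDependent(torch.nn.Module):
    def forward(self, x):
        # This branch depends on the *value* of x, which symbolic
        # tracing cannot resolve at trace time.
        if x.sum() > 0:
            return x + 1
        return x - 1

try:
    torch.fx.symbolic_trace(DataDependent())
except Exception as exc:  # torch.fx raises a TraceError here
    print(f"symbolic_trace failed: {exc}")

# torch.jit.script compiles the control flow instead of tracing it.
scripted = torch.jit.script(DataDependent())
print(scripted(torch.ones(2)))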