🥡 Exporting for Inference#
Composer models are also torch.nn.Module instances, and thus can be exported like any other PyTorch module. In this tutorial, we walk through how to export your models into several common formats: ONNX, TorchScript, and torch.fx. For more detailed options and configuration settings, please consult the linked documentation.
Algorithm compatibility#
Some of our algorithms alter the model architecture in ways that may render them incompatible with some of the export procedures above. For example, BlurPool replaces some instances of Conv2d with BlurConv2d layers, which are not compatible with torch.fx because they have data-dependent control flow.
The following table shows which algorithms are compatible with which export formats for inference.
Algorithm | torchscript | torch.fx | ONNX |
---|---|---|---|
apply_blurpool | โ | โ | |
apply_factorization | โ | โ | |
apply_ghost_batchnorm | โ | โ | |
apply_squeeze_excite | โ | โ | โ |
apply_stochastic_depth | โ | โ | โ |
apply_channels_last | โ | โ | โ |
Prerequisites#
First, we install composer:
[ ]:
%pip install mosaicml
Create the model#
Next, we create the model we'd like to export. In this case it is based on a ResNet-50 from torchvision with our SqueezeExcite algorithm applied, which adds SqueezeExcite modules after certain Conv2d layers.
[ ]:
from torchvision.models import resnet
from composer.models import ComposerClassifier
import composer.functional as cf
model = ComposerClassifier(module=resnet.resnet50())
model = cf.apply_squeeze_excite(model)
# switch to eval mode
model.eval()
Printing the model shows the new Bottleneck
layers:
[ ]:
print(model)
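The printed ResNet-50 is long, so one way to confirm that an algorithm changed the architecture is to tally modules by class name before and after applying it. The sketch below does this on a small stand-in network; `count_module_types` and `toy_model` are hypothetical names introduced for illustration and are not part of Composer.

```python
from collections import Counter

import torch


def count_module_types(module: torch.nn.Module) -> Counter:
    """Tally every submodule (including the root) by its class name."""
    return Counter(type(m).__name__ for m in module.modules())


# Toy stand-in for a larger network such as the ResNet-50 above.
toy_model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.Conv2d(8, 8, kernel_size=3),
    torch.nn.ReLU(),
)

counts = count_module_types(toy_model)
for name, n in sorted(counts.items()):
    print(f"{name}: {n}")
```

Running the same helper on a model before and after an algorithm is applied should show which layer types were introduced or replaced.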
ONNX#
ONNX is a popular model format that can then be consumed by many third-party tools (e.g. TensorRT, OpenVINO) to optimize the model for specific hardware devices.
Note: ONNX does not yet have a prebuilt wheel for Mac M1/M2 chips, so it is not pip-installable there. Skip this section if you are running on a Mac laptop with one of these chips.
[ ]:
%pip install onnx
%pip install onnxruntime
The ComposerClassifier's forward method takes as input a pair of tensors (input, label), so we create a dummy pair of tensors for the ONNX export:
[ ]:
import torch
input = (torch.rand(4, 3, 112, 112), torch.Tensor())
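If carrying a dummy label through the export is inconvenient, one option is to wrap the model so that it accepts only the input tensor. The sketch below illustrates the idea on a toy module that mimics a forward method taking an (input, label) pair. `PairInputModel` and `InputOnlyWrapper` are hypothetical names for this example, not Composer classes, and the toy forward is an assumption about how such a model unpacks its batch.

```python
from typing import Tuple

import torch


class PairInputModel(torch.nn.Module):
    """Toy stand-in whose forward takes an (input, label) pair."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 3)

    def forward(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        inputs, _ = batch  # the label is ignored at inference time
        return self.linear(inputs)


class InputOnlyWrapper(torch.nn.Module):
    """Hypothetical wrapper exposing a single-tensor forward."""

    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Supply an empty placeholder label so callers only pass the input.
        return self.model((inputs, torch.empty(0)))


wrapped = InputOnlyWrapper(PairInputModel()).eval()
x = torch.rand(4, 8)
with torch.no_grad():
    logits = wrapped(x)
print(logits.shape)
```

With a wrapper like this, downstream consumers of the exported model would only need to provide the image tensor.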
Then we are ready to run the export with:
[ ]:
import os
torch.onnx.export(
    model=model,
    args=(input,),  # the (input, label) pair is passed as one positional argument
    f='model.onnx',
    input_names=['input'],
    output_names=['output'],
)
Let's load the model and check that everything was exported properly.
[ ]:
import onnx
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
Lastly, we run inference with ONNX Runtime to check that the exported model indeed runs.
[ ]:
import onnxruntime as ort
import numpy as np
# run inference
ort_session = ort.InferenceSession('model.onnx')
outputs = ort_session.run(
    None,  # None requests all model outputs
    {'input': input[0].numpy()},
)
print(f"The predicted classes are {np.argmax(outputs[0], axis=1)}")
Note: As the model is randomly initialized, and the input tensor is random, the output classes in this example have no meaning.
TorchScript#
TorchScript creates models from PyTorch code that can be saved and optimized for deployment, and the tooling is native to PyTorch. The command below scripts the model:
[ ]:
import torch
import numpy as np
input = (torch.rand(4, 3, 112, 112), torch.Tensor())
scripted_model = torch.jit.script(model)
scripted_model.eval()
We can then run inference by passing in our input:
[ ]:
output = scripted_model(input)
output.shape
print(f"The predicted classes are {torch.argmax(output, dim=1)}")
The scripted model can also be saved using torch.jit.save:
[ ]:
torch.jit.save(scripted_model, 'scripted_model.pt')
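A saved TorchScript file can be loaded back with `torch.jit.load`, and it is worth checking that the reloaded model reproduces the eager model's outputs. The following is a minimal sketch of that round trip on a small stand-in network; `TinyNet` and the file name are hypothetical and used only for illustration.

```python
import os
import tempfile

import torch


class TinyNet(torch.nn.Module):
    """Small stand-in for the model being exported."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


eager_model = TinyNet().eval()
scripted = torch.jit.script(eager_model)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_scripted.pt")
    torch.jit.save(scripted, path)
    # Loading the archive does not need the TinyNet class definition.
    reloaded = torch.jit.load(path)

reloaded.eval()
x = torch.rand(4, 8)
with torch.no_grad():
    outputs_match = torch.allclose(eager_model(x), reloaded(x))
print(f"Outputs match: {outputs_match}")
```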
Torch.fx#
FX is a recent toolkit for transforming PyTorch modules that provides advanced graph manipulation and code generation capabilities. Eventually, PyTorch will be adding quantization with FX (e.g. see FX Graph Mode Quantization) and other optimization procedures. Composer is also starting to add algorithms that use torch.fx for graph optimization, so look forward to more of these in the future!
Tracing a model with torch.fx
is fairly straightforward:
[ ]:
traced_model = torch.fx.symbolic_trace(model)
Then, we can see all the nodes in the graph:
[ ]:
traced_model.graph.print_tabular()
And also run inference:
[ ]:
output = traced_model(input)
print(f"The predicted classes are {torch.argmax(output, dim=1)}")
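To give a flavor of the graph manipulation that torch.fx enables, here is a small sketch that rewrites one operation in a traced graph. It uses a toy module rather than the ResNet-50, and the transformation (swapping `torch.add` for `torch.mul`) is chosen purely for illustration; `AddModule` is a hypothetical name.

```python
import torch
import torch.fx


class AddModule(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.add(x, y)


traced = torch.fx.symbolic_trace(AddModule())

# Rewrite every call to torch.add into a call to torch.mul.
for node in traced.graph.nodes:
    if node.op == "call_function" and node.target == torch.add:
        node.target = torch.mul

traced.graph.lint()  # check that the edited graph is still well-formed
traced.recompile()   # regenerate forward() from the edited graph

x, y = torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])
result = traced(x, y)
print(result)  # element-wise product instead of sum
```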
torch.fx is powerful, but one of the key limitations of this method is that it does not support dynamic control flow (e.g. if statements or loops that are data-dependent). Therefore, some algorithms, such as BlurPool, are currently not supported. We have ongoing work to bring torch.fx support to all our algorithms.
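The control-flow limitation can be reproduced on a toy module. The sketch below is not BlurPool's actual code; `DataDependentBranch` is a hypothetical module whose branch depends on tensor values. Symbolic tracing is expected to raise a `TraceError`, while `torch.jit.script` compiles the same branch.

```python
import torch
import torch.fx
from torch.fx.proxy import TraceError


class DataDependentBranch(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Which branch runs depends on the values in x, not just its shape.
        if x.sum() > 0:
            return x * 2
        return x - 1


try:
    torch.fx.symbolic_trace(DataDependentBranch())
    trace_failed = False
except TraceError as err:
    trace_failed = True
    print(f"Tracing failed: {err}")

# TorchScript compiles the branch itself, so both paths stay available.
scripted_branch = torch.jit.script(DataDependentBranch())
positive_out = scripted_branch(torch.ones(2))
negative_out = scripted_branch(-torch.ones(2))
print(positive_out, negative_out)
```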