Tip

This tutorial is available as a Jupyter notebook.

Open in Colab

🥡 Exporting for Inference#

Composer provides model export support for inference using a dedicated export API and a callback. In this tutorial, we walk through how to export your models into common formats (ONNX and TorchScript) using both the dedicated export API and Composer's callback mechanism. For more detailed options and configuration settings, please consult the linked documentation. In addition, if for any reason the above methods are not sufficient for your use case, Composer models can be exported like any other PyTorch module, since Composer models are also torch.nn.Module instances.

Prerequisites#

First, we install composer:

[ ]:
%pip install mosaicml

Create the model#

Next, we create the model we'd like to export, which in this case is based on a ResNet-50 from torchvision, but with our SqueezeExcite algorithm applied, which adds SqueezeExcite modules after certain Conv2d layers.

[ ]:
from torchvision.models import resnet
from composer.models import ComposerClassifier
import composer.functional as cf

model = ComposerClassifier(module=resnet.resnet50())
model = cf.apply_squeeze_excite(model)

# switch to eval mode
model.eval()
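
If you'd like to confirm that the surgery actually modified the network, you can count the squeeze-excite wrappers it now contains. This is only an optional sanity check, and it assumes the SqueezeExciteConv2d class exposed by composer.algorithms.squeeze_excite is the wrapper that apply_squeeze_excite inserts:

[ ]:
from composer.algorithms.squeeze_excite import SqueezeExciteConv2d

# Count how many Conv2d layers were wrapped with squeeze-excite attention
num_se = sum(1 for m in model.modules() if isinstance(m, SqueezeExciteConv2d))
print(f'The model now contains {num_se} SqueezeExciteConv2d modules')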

TorchScript Export using the standalone API#

TorchScript creates models from PyTorch code that can be serialized and optimized for deployment, and the tooling is native to PyTorch.

The ComposerClassifier's forward method takes as input a pair of tensors (input, label), so we create dummy tensors to run the model.

[ ]:
import torch

input = (torch.rand(4, 3, 224, 224), torch.Tensor())

output = model(input)
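
As noted in the introduction, a ComposerClassifier is an ordinary torch.nn.Module, so plain PyTorch tooling also works on it directly. As a minimal sketch (not part of Composer's export API), we can trace the model with torch.jit.trace using the same dummy batch:

[ ]:
import os
import tempfile

# Trace and save the model with vanilla PyTorch; the rest of the tutorial
# uses Composer's export utilities instead.
plain_traced = torch.jit.trace(model, (input,))
plain_path = os.path.join(tempfile.gettempdir(), 'plain_traced_model.pt')
torch.jit.save(plain_traced, plain_path)
print(torch.allclose(output, torch.jit.load(plain_path)(input)))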

Now we run the export using the standalone export API. Composer also supports exporting to an object store such as S3. Please check out the full documentation for the export_for_inference API for help with using an object store.

[ ]:
import os
import tempfile
from composer.utils import export_for_inference

save_format = 'torchscript'
working_dir = tempfile.TemporaryDirectory()
model_save_path = os.path.join(working_dir.name, 'model.pt')

export_for_inference(model=model,
                     save_format=save_format,
                     save_path=model_save_path)

Let us check to make sure that the model exists in our working directory.

[ ]:
print(os.listdir(path=working_dir.name))

Let us reload the saved model and run inference with it. We also compare the results against the previously computed output on the same input to make sure the scripted model produces the same predictions.

[ ]:
scripted_model = torch.jit.load(model_save_path)
scripted_model.eval()
scripted_output = scripted_model(input)
print(torch.allclose(output, scripted_output))
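
As mentioned above, export_for_inference can also write the exported model to an object store such as S3 instead of the local filesystem. The snippet below is only a hedged sketch: the save_object_store argument and the S3ObjectStore helper are taken from Composer's object-store utilities and may differ across versions, and the bucket name is a placeholder, so consult the export_for_inference documentation and configure your AWS credentials before adapting it.

[ ]:
from composer.utils import S3ObjectStore, export_for_inference

# Sketch only: 'my-bucket' is a placeholder bucket name.
# The exported TorchScript model would be written to s3://my-bucket/exports/model.pt
export_for_inference(model=model,
                     save_format='torchscript',
                     save_path='exports/model.pt',
                     save_object_store=S3ObjectStore(bucket='my-bucket'))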

Export using a callback#

The Composer trainer also allows you to specify an export callback that automatically exports the model at the end of training. Since we will be training the model for a few epochs, we first create a dataloader with a synthetic dataset for this tutorial.

[ ]:
from composer.datasets.synthetic import SyntheticBatchPairDataset
from torch.utils.data import DataLoader

dataset = SyntheticBatchPairDataset(total_dataset_size=8, data_shape=(3, 224, 224), num_classes=1000)
dataloader = DataLoader(dataset=dataset, batch_size=4)

Create the model#

We create the model we are training, which in this case is based on ResNet-50 from torchvision.

[ ]:
import os
from torchvision.models import resnet
from composer.models import ComposerClassifier

model = ComposerClassifier(module=resnet.resnet50())

Create export callback#

Now we create the callback that the trainer uses to export the model for inference. Since we already saw TorchScript export with Composer's standalone export API, in this section we use ONNX as the export format to showcase both capabilities. However, both TorchScript and ONNX are supported by both ways of exporting; in either case, simply set save_format to 'onnx' or 'torchscript' for your desired format. ONNX is a popular model format that can be consumed by many third-party tools (e.g., TensorRT, OpenVINO) to optimize the model for specific hardware devices.

Note: ONNX does not yet have a prebuilt wheel for Mac M1/M2 chips, so it is not pip-installable there. Skip this section if you are running on a Mac laptop.

[ ]:
import composer.functional as cf
from composer.callbacks import ExportForInferenceCallback
# change to 'torchscript' for exporting to torchscript format
save_format = 'onnx'
model_save_path = os.path.join(working_dir.name, 'model1.onnx')
export_callback = ExportForInferenceCallback(save_format=save_format, save_path=model_save_path)

Run Training#

Now we construct the trainer using this callback. The model is exported at the end of training. Later in this tutorial we show how to export a model from a checkpoint, so we also supply the trainer's save_folder and save_interval to save some checkpoints.

[ ]:
import torch
from composer import Trainer
from composer.algorithms import SqueezeExcite

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

trainer = Trainer(
    model=model,
    train_dataloader=dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    save_folder=working_dir.name,
    algorithms=[SqueezeExcite()],
    callbacks=[export_callback],
    max_duration='2ep',
    save_interval='1ep')
trainer.fit()

Let us list the contents of the working_dir to check that the checkpoints and the exported model are available.

[ ]:
print(os.listdir(path=working_dir.name))

Alternative way of exporting with trainer.export_for_inference#

The trainer itself also exposes an export_for_inference method, which exports the model it holds:

[ ]:
model_save_path = os.path.join(working_dir.name, 'model2.onnx')

trainer.export_for_inference(save_format='onnx', save_path=model_save_path)

Let us list the content of the working_dir to see if this exported model is available.

[ ]:
print(os.listdir(path=working_dir.name))

Load and run exported ONNX model#

[ ]:
%pip install onnx
%pip install onnxruntime

Let's load the model and check that everything was exported properly.

[ ]:
import onnx

onnx_model = onnx.load(model_save_path)
onnx.checker.check_model(onnx_model)

Lastly, we can run inference with the model and check that the model indeed runs.

[ ]:
import onnxruntime as ort
import numpy as np

# run inference
ort_session = ort.InferenceSession(model_save_path)
outputs = ort_session.run(
    None,
    {'input': input[0].numpy()})
print(f"The predicted classes are {np.argmax(outputs[0], axis=1)}")

Note: As the model is randomly initialized, and the input tensor is random, the output classes in this example have no meaning.
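
As an optional sanity check (not part of the original workflow), we can also compare the ONNX Runtime output against the in-memory PyTorch model on the same input; the maximum absolute difference should be small.

[ ]:
# Compare the ONNX Runtime output with the trained PyTorch model's output,
# moving the dummy batch to whichever device the trainer left the model on.
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
    torch_output = model((input[0].to(device), input[1].to(device)))
print(f"Max absolute difference: {np.abs(outputs[0] - torch_output.cpu().numpy()).max()}")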

Exporting from an existing checkpoint#

In this part of the tutorial, we look at exporting a model from a previously created checkpoint that is stored locally. Composer also supports exporting from a checkpoint stored in an object store such as S3. Please check out the full documentation for the export_for_inference API for details on using an object store.

Some of our algorithms alter the model architecture. For example, SqueezeExcite adds a channel-wise attention operator to CNNs, modifying the model architecture. Therefore, we need to provide a function that takes the model and applies the algorithm before we can load the model weights from a checkpoint. The functional form of SqueezeExcite does exactly that, so we pass it as surgery_algs to the export_for_inference API.

[ ]:
from composer.utils import export_for_inference
# We call it model3.onnx to distinguish it from the previous exports
model_save_path = os.path.join(working_dir.name, 'model3.onnx')
checkpoint_path = os.path.join(working_dir.name, 'ep2-ba4-rank0.pt')

model = ComposerClassifier(module=resnet.resnet50())

export_for_inference(model=model,
                     save_format=save_format,
                     save_path=model_save_path,
                     sample_input=(input,),
                     surgery_algs=[cf.apply_squeeze_excite],
                     load_path=checkpoint_path)

Let us list the content of the working_dir to check if the newly exported model is available.

[ ]:
print(os.listdir(path=working_dir.name))

Make sure the model loaded from a checkpoint produces the same results as before

[ ]:
ort_session = ort.InferenceSession(model_save_path)
new_outputs = ort_session.run(
    None,
    {'input': input[0].numpy()},
)
print(np.allclose(outputs[0], new_outputs[0], atol=1e-07))
[ ]:
# Clean up working directory
working_dir.cleanup()

Torch.fx#

FX is a recent PyTorch toolkit for transforming modules that enables advanced graph manipulation and code generation. Eventually, PyTorch will be adding quantization with FX (e.g., see FX Graph Mode Quantization) and other optimization procedures. Composer is also starting to add algorithms that use torch.fx for graph optimization, so look forward to more of these in the future!

Tracing a model with torch.fx is fairly straightforward:

[ ]:
traced_model = torch.fx.symbolic_trace(model)

Then, we can see all the nodes in the graph:

[ ]:
traced_model.graph.print_tabular()

And also run inference:

[ ]:
output = traced_model(input)
print(f"The predicted classes are {torch.argmax(output, dim=1)}")

torch.fx is powerful, but one of its key limitations is that it does not support dynamic control flow (e.g., if statements or loops that are data-dependent). Therefore, some algorithms, such as BlurPool, are currently not supported. We have ongoing work to bring torch.fx support to all of our algorithms.
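
To make this limitation concrete, here is a small, self-contained example (unrelated to Composer) of a module with data-dependent control flow that symbolic tracing rejects:

[ ]:
class DataDependentModule(torch.nn.Module):
    def forward(self, x):
        # The branch taken depends on the values inside x, which symbolic
        # tracing cannot resolve, so torch.fx raises a TraceError.
        if x.sum() > 0:
            return x * 2
        return x - 1

try:
    torch.fx.symbolic_trace(DataDependentModule())
except torch.fx.proxy.TraceError as err:
    print(f"Tracing failed as expected: {err}")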

Algorithm compatibility#

Some of our algorithms alter the model architecture in ways that may render them incompatible with some of the export procedures above. For example, BlurPool replaces some instances of Conv2d with BlurConv2d layers, which are not compatible with torch.fx because they contain data-dependent control flow.

The following table shows which algorithms are compatible with which export formats for inference.

                         torchscript   torch.fx   ONNX
apply_blurpool           ✓                        ✓
apply_factorization                    ✓          ✓
apply_ghost_batchnorm    ✓                        ✓
apply_squeeze_excite     ✓             ✓          ✓
apply_stochastic_depth   ✓             ✓          ✓
apply_channels_last      ✓             ✓          ✓