Supported model formats
MAX Engine supports model formats provided by PyTorch, TensorFlow, and ONNX. However, we don't support every format from PyTorch and TensorFlow, and all models must be loaded from a file.
Currently, we support PyTorch's TorchScript format, TensorFlow's SavedModel format, and ONNX format.
If you've never heard of TorchScript, SavedModel, or ONNX, don't worry; that's what this page is here for.
Export a PyTorch model to TorchScript
You might be used to saving PyTorch models as a file with torch.save(), but this creates a Python pickle object, which can only be used with Python. This won't work because MAX Engine doesn't execute models with Python. You instead need to create a TorchScript file.
As the name implies, TorchScript is actually a language—it's a subset of the Python language (technically, it's an intermediate representation) that provides a serialization format for PyTorch models so they can run in non-Python environments. However, you can't load TorchScript code directly into MAX Engine—you must save it as a TorchScript file.
As we'll describe in the following sections, PyTorch provides two ways to convert a PyTorch model (a torch.nn.Module object) to TorchScript: you can either "script" or "trace" the model.
Script a PyTorch model
When you "script" a PyTorch model, it means that you parse the Python code and
convert it into the TorchScript representation using
torch.jit.script()
.
So, although it's accurate to say TorchScript is a language, you still write
everything in Python (your code lives in a .py
file). However, not all Python
code can be successfully converted into TorchScript.
In the best-case scenario, all the Python code in your PyTorch model is already compatible with TorchScript and calling torch.jit.script() just works. In other cases, it might require that you modify the Python code so it uses only the Python features that are available in the TorchScript language—most notably, TorchScript enforces static types.
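For example, TorchScript assumes every unannotated function argument is a Tensor, so a model that accepts any other type needs explicit annotations. Here's a minimal sketch (the Summer module is a hypothetical example written for illustration):

import torch
from typing import List

class Summer(torch.nn.Module):
    # Without the List[torch.Tensor] annotation, TorchScript would
    # assume xs is a single Tensor and scripting would fail.
    def forward(self, xs: List[torch.Tensor]) -> torch.Tensor:
        out = xs[0]
        for x in xs[1:]:
            out = out + x
        return out

scripted = torch.jit.script(Summer())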
When your PyTorch model is compatible with TorchScript, calling torch.jit.script() returns either a ScriptModule or ScriptFunction, which you can save as a file with torch.jit.save().
Fortunately, many PyTorch models are already compatible with TorchScript, so you can simply instantiate them, convert them, and save them as a TorchScript file like this:
import torch
import torchvision.models as models

# Load a pre-trained ResNet-50 (the weights argument replaces the
# deprecated pretrained=True)
r50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Convert the model to TorchScript and save it as a file
r50_scripted = torch.jit.script(r50)
torch.jit.save(r50_scripted, 'resnet50.torchscript')
This resnet50.torchscript file is now ready to load into MAX Engine.
If you're writing your own PyTorch model and want to make it compatible with TorchScript, see the PyTorch docs about TorchScript for more details.
Trace a PyTorch model
In some cases, scripting a model might not work so easily, and making it work could require significant code rewrites. In this case, you can instead "trace" the graph with torch.jit.trace().
Tracing the model means PyTorch actually executes the model with sample inputs and records all operations that are invoked. PyTorch adds the recorded operations to a TorchScript representation of the graph—either a ScriptModule or ScriptFunction—that you can save as a TorchScript file with torch.jit.save().
For example, the following code shows how you can trace a model from 🤗 Transformers and save it as a TorchScript file. Notice that tracing a model requires that you provide sample input data so that torch.jit.trace() can actually execute the model (it can be random data as long as it matches the input shape and type).
import torch
from transformers import RobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = RobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)

# Sample input data that matches the model's input shapes and types
batch = 1
seqlen = 128
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}

# Trace the model with the sample inputs and save it as a TorchScript file
with torch.no_grad():
    traced_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )

torch.jit.save(traced_model, "roberta.torchscript")
This roberta.torchscript file is now ready to load into MAX Engine.
For more information, see the PyTorch docs about TorchScript.
Export a TensorFlow model to SavedModel
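TensorFlow's SavedModel is the format MAX Engine expects for TensorFlow models. Here's a minimal sketch that exports a model with tf.saved_model.save() (assuming a pre-trained Keras ResNet-50 for illustration; any Keras model or tf.Module works the same way). Note that a SavedModel is a directory, not a single file:

import tensorflow as tf

# Load a pre-trained Keras model
model = tf.keras.applications.ResNet50(weights="imagenet")

# Save it in the SavedModel format (this creates a directory)
tf.saved_model.save(model, "resnet50_savedmodel")

This resnet50_savedmodel directory is now ready to load into MAX Engine.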
Convert a model to ONNX
MAX Engine supports models in the ONNX format, which you can create from either PyTorch or TensorFlow. If you already have an ONNX model, you can directly load it into MAX Engine.
If you don't already have an ONNX model, then we recommend that you instead create a TorchScript file or create a SavedModel, as described above. This will save you a bit of time and confusion.
That's not to say ONNX isn't good; it provides plenty of value for a wide range of production use cases. However, if your intent is to use MAX Engine, then ONNX doesn't really help you, because you can get the format you need straight from PyTorch or TensorFlow, as shown above.
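If you do need an ONNX file anyway, here's a minimal sketch that exports the same ResNet-50 from PyTorch with torch.onnx.export(). Like torch.jit.trace(), the exporter executes the model, so it needs sample input data:

import torch
import torchvision.models as models

r50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
r50.eval()

# torch.onnx.export() traces the model, so it needs sample input data
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(r50, dummy_input, "resnet50.onnx")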
Load a model into MAX Engine
Once you have your model as a TorchScript, SavedModel, or ONNX file, you can load it for execution in MAX Engine.
If you're using a TorchScript file, you first need to specify the model's input specs, as described below.
If you're using a SavedModel or ONNX file, just pass the file path to MAX Engine, as shown in the guides to run inference with Python, with C, and with Mojo.
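For example, loading an ONNX file from Python might look like this (a sketch based on the InferenceSession API referenced below; the import path is an assumption, so check the Python inferencing guide for the exact calls):

from max import engine

# SavedModel and ONNX files carry their own input metadata,
# so the file path is all you need
session = engine.InferenceSession()
model = session.load("resnet50.onnx")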
Specify TorchScript input specs
Loading a TorchScript model requires an extra step because the MAX Engine compiler must know the model's input shape, rank, and data type, information that a TorchScript file doesn't include. Thus, when loading a TorchScript model, you must provide the shape, rank, and data type as "input specs." If the model supports inputs with dynamic shapes, you can specify those in the input specs and MAX Engine will optimize the model for any inputs that match the shape.
The exact syntax to specify the input specs is different for each API:
- In Mojo and Python, you need to specify the input_specs keyword argument to InferenceSession.load (Mojo doc, Python doc), as shown in the sketch after this list. For details, see the Mojo inferencing guide or Python inferencing guide.
- In C, you need to call M_setTorchInputSpecs(). For details, see the C inferencing guide.
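In Python, that might look like this (a sketch: InferenceSession.load and input_specs come from the API linked above, but the TorchInputSpec and DType names are assumptions, so check the Python API reference):

from max import engine

session = engine.InferenceSession()
model = session.load(
    "resnet50.torchscript",
    # Tell the compiler the shape, rank, and data type of each input
    input_specs=[
        engine.TorchInputSpec(
            shape=[1, 3, 224, 224], dtype=engine.DType.float32
        )
    ],
)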