Model formats

We designed MAX to simplify AI deployment for everybody, which means it can accelerate all kinds of models on all kinds of hardware. Depending on your task, that might mean running a model from PyTorch or ONNX, or running the latest generative AI (GenAI) models. In all cases, MAX provides a way to deploy with hardware flexibility and state-of-the-art inference performance.

This page explains each of the model formats that MAX supports, and how you can get them. Put simply, you can use MAX with the following formats: TorchScript, ONNX, and MAX Graph.

TorchScript

If you're familiar with PyTorch, you might be used to saving your model with torch.save(). But this creates a Python pickle object, which can only be loaded from Python. That won't work with MAX because, under the hood, MAX doesn't execute models with Python. Instead, you need to save your model as a TorchScript file, which MAX can compile into its own executable format.

As we'll describe in the following sections, PyTorch provides two ways to convert a PyTorch model (a torch.nn.Module object) to TorchScript: you can either "script" or "trace" the model.

Script a PyTorch model

When you "script" a PyTorch model, it means that you parse the Python code and convert it into the TorchScript representation using torch.jit.script(). Although it's accurate to say TorchScript is a language, you still write everything in Python (your code lives in a .py file). However, not all Python code can successfully convert to TorchScript.

In the best-case scenario, all the Python code in your PyTorch model is already compatible with TorchScript and calling torch.jit.script() just works. In other cases, you might need to modify the Python code so it uses only the Python features that are available in the TorchScript language—most notably, TorchScript enforces static types.
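
For example, TorchScript assumes that any unannotated function argument is a Tensor, so a function that takes other types needs explicit annotations before torch.jit.script() will accept it. Here's a minimal sketch; the function and its parameters are purely illustrative:

import torch

# Illustrative sketch: unannotated arguments default to Tensor in
# TorchScript, so the float parameters here need explicit annotations.
@torch.jit.script
def scale_and_clamp(x: torch.Tensor, scale: float, max_val: float) -> torch.Tensor:
    return torch.clamp(x * scale, max=max_val)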

When your PyTorch model is compatible with TorchScript, calling torch.jit.script() returns either a ScriptModule or ScriptFunction, which you can save as a file with torch.jit.save().

Fortunately, a lot of PyTorch models are already compatible with TorchScript, so you can simply instantiate them, convert them, and save them as a TorchScript file like this:

import torch
import torchvision.models as models

# Load ResNet-50 with pretrained weights. (The weights argument replaces
# the deprecated pretrained=True flag in recent torchvision releases.)
r50 = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)

# Convert the model to TorchScript and save it to disk.
r50_scripted = torch.jit.script(r50)
torch.jit.save(r50_scripted, 'resnet50.torchscript')

This resnet50.torchscript file is now ready to load and execute with the MAX Engine API.
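
As a rough sketch, loading it looks something like the following. Treat this as a sketch under assumptions: TorchScript files don't store input shapes, so MAX asks for them at load time, and the exact names here (TorchInputSpec and the load() arguments) may differ across MAX versions.

from max import engine
from max.dtype import DType

# Sketch only: input_specs tells MAX the input shapes and dtypes that the
# TorchScript file doesn't carry; check your MAX version for exact names.
session = engine.InferenceSession()
model = session.load(
    "resnet50.torchscript",
    input_specs=[engine.TorchInputSpec(shape=[1, 3, 224, 224], dtype=DType.float32)],
)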

Trace a PyTorch model

In some cases, scripting a model as shown above might require significant code rewrites. If so, you can instead trace the model with torch.jit.trace().

Tracing the model means PyTorch executes the model with sample inputs and records all operations that are invoked. PyTorch adds the recorded operations to a TorchScript representation of the graph—either a ScriptModule or ScriptFunction—that you can save as a TorchScript file with torch.jit.save().

For example, the following code shows how you can trace a model from Hugging Face and save it as a TorchScript file.

import torch
from transformers import RobertaForSequenceClassification

HF_MODEL_NAME = "cardiffnlp/twitter-roberta-base-emotion-multilabel-latest"
model = RobertaForSequenceClassification.from_pretrained(HF_MODEL_NAME)

# Sample inputs for tracing; the values are dummies, but the shapes and
# dtypes must match what the model expects.
batch = 1
seqlen = 128
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.ones((batch, seqlen), dtype=torch.float32),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}

# Trace the model (no gradients needed), recording the ops it executes.
with torch.no_grad():
    traced_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )

torch.jit.save(traced_model, "roberta.torchscript")

Notice that tracing a model requires that you provide sample input data so that torch.jit.trace() can execute the model (it can be random data as long as it matches the input shape and type). For a complete code example of how to trace a model and run it with MAX, see our tutorial to run a TorchScript model with Python.

ONNX

ONNX (Open Neural Network Exchange) is a model format you can export from a variety of machine learning tools, including PyTorch, TensorFlow, Keras, scikit-learn, and more.

If you don't have an ONNX model yet, refer to the appropriate ML framework documentation about how to convert your model to ONNX. For example, if using TensorFlow, see the ONNX guide about how to convert TensorFlow to ONNX.

It's also easy to export a Hugging Face model to ONNX, using either a CLI tool or a Python API.
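
For instance, with Hugging Face's Optimum library (assuming here that you have optimum installed with its onnxruntime extras), the Python export can be just a few lines; the model name below is only an example:

from optimum.onnxruntime import ORTModelForSequenceClassification

# Sketch: download a Hugging Face checkpoint and export it to ONNX.
# The model name is an example; any compatible checkpoint works.
model = ORTModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english", export=True
)
model.save_pretrained("onnx_model")  # writes model.onnx alongside its config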

For a complete code example of how to export a Hugging Face model to ONNX with Python and then run inference with MAX, see our tutorial to run an ONNX model with Python.

MAX Graph

MAX Graph is our solution to build high-performance GenAI models such as large language models (LLMs) in Python. When paired with the MAX compiler and runtime, a MAX Graph model delivers the state-of-the-art performance you'd expect from point-solution AI libraries written in C or C++, but with less code that's more readable.

We built MAX Graph because, although MAX can execute off-the-shelf models from PyTorch and ONNX faster than the default runtimes, these ML frameworks can't do everything. Specifically, they have not kept up with the performance demands of GenAI models.

It all starts with the Graph object for creating acyclic computation graphs:

  1. Instantiate a graph, specifying the input shape as a TensorType.
  2. Build the graph by chaining ops functions. Each function takes and returns a Value object.
  3. Add the final Value to the graph using the output() method.

This structure allows you to define complex computations as a series of interconnected operations, forming an AI model graph.
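
Here's a minimal sketch of those three steps, based on the MAX Graph Python API (exact signatures, such as required device arguments, can change between MAX releases):

from max.dtype import DType
from max.graph import Graph, TensorType, ops

# Sketch: a graph that adds two tensors. Signatures may vary by version.
input_type = TensorType(dtype=DType.float32, shape=(2, 3))   # step 1
with Graph("simple_add", input_types=(input_type, input_type)) as graph:
    lhs, rhs = graph.inputs
    out = ops.add(lhs, rhs)                                  # step 2
    graph.output(out)                                        # step 3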

For a step-by-step guide to building a model with MAX Graph, see our tutorial to get started with MAX Graph in Python. Or, check out our implementations of GenAI models such as Llama3 and Mistral on GitHub.

Get started