Run a TorchScript model with Mojo

The Mojo API for MAX Engine helps extend performance gains from the MAX Engine runtime into your application, by enabling you to build your entire inference application in high-performance Mojo code.

With any of the MAX Engine API libraries (also available in Python and C), you can run inference with models from PyTorch and ONNX at incredible speeds on a wide range of hardware.

As we'll show in more detail below, there are 3 essential lines of code that you need to execute your model:

from max import engine

def main():
    # Load your model:
    session = engine.InferenceSession()
    model = session.load(model_path)

    # Get the inputs, then run an inference:
    outputs = model.execute(inputs)
    # Process the output here.

There's certainly more ceremony required to support these calls, but these are the basic APIs you need to know.

Now, let's walk through each step required to load and run a BERT model from PyTorch.

Set up the project environment

After you install Magic, create a new Mojo project:

magic init bert-project --format mojoproject && cd bert-project

Add MAX as a conda package:

magic add max

We also use the Python transformers library to process inputs and outputs, and PyTorch to convert the model to TorchScript. So let's add these packages, pinned to the versions used in this tutorial:

magic add "max~=24.5" "pytorch==2.4.0" "transformers==4.40.1"

Now you can start a shell in the environment and see your Mojo version:

magic shell
mojo --version

Download the TorchScript model

You can download the BERT TorchScript model used in this tutorial from our GitHub repo with these commands, which save the model to your current directory:

git clone https://github.com/modularml/max.git
python3 max/examples/inference/common/bert-torchscript/download-model.py \
    -o bert-mlm.torchscript --mlm

Confirm that bert-mlm.torchscript was created. Then you can delete the max repo:

rm -rf max

Import Mojo modules

Everything we need to run an inference with MAX Engine comes from the max.engine and max.tensor packages. The rest are supporting APIs from the Mojo standard library.

from max.engine import InputSpec, InferenceSession
from pathlib import Path
from python import Python
# Tensor is used later by the argmax() helper
from max.tensor import Tensor, TensorSpec

In case you're new to Mojo, it's important to know that Mojo requires a main() function as the program entry point.

So, from here on out, imagine that all the code runs inside this function:

def main():
    # The rest of the code goes here

The first thing we want to do inside the main() function is load any Python modules we plan to use. In this case, we're going to use Hugging Face Transformers to encode and decode our text strings, so let's load that Python module:

# This is equivalent to `import transformers` in Python
transformers = Python.import_module("transformers")

The transformers variable behaves like a Python module name from now on but it is still just a variable, which is scoped to the current function. If you want to use transformers in a function other than main(), then you need to put this line inside that function instead of main().

Load the model

First, let's load and compile the model in MAX Engine using an InferenceSession.

Define input specs

When you're using a PyTorch model (it must be in TorchScript format), you need to specify the input specifications for each of the model inputs before you can compile the model.

For each input, you need to create a TensorSpec that describes its data type and shape. Then collect the specs in a list and pass that list to InferenceSession.load().

Here's how you can specify the input specs for the BERT model:

batch = 1
seqlen = 128

input_ids_spec = TensorSpec(DType.int64, batch, seqlen)
token_type_ids_spec = TensorSpec(DType.int64, batch, seqlen)
attention_mask_spec = TensorSpec(DType.int64, batch, seqlen)
input_specs = List[InputSpec]()

input_specs.append(input_ids_spec)
input_specs.append(attention_mask_spec)
input_specs.append(token_type_ids_spec)

Next, you'll load the model with these input specs.

Load and compile the model

Now we instantiate an InferenceSession and then load-and-compile the model by passing the model path to load() (if you're loading a TorchScript model, also pass in input_specs):

model_path = Path("bert-mlm.torchscript")
session = InferenceSession()
model = session.load(model_path, input_specs=input_specs)

Prepare the input

This is your usual pre-processing step to prepare input for the model. For the BERT model, we need to process the text input into a sequence of tokens.

Because Mojo can import Python modules, we can use existing Python libraries in our Mojo project. For this task, we need a string tokenizer, so we're using transformers.AutoTokenizer from the 🤗 Transformers API:

INPUT = String("Paris is the [MASK] of France.")

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "bert-base-uncased"
)

# Get the maximum sequence length from the model's output metadata
output_spec = model.get_model_output_metadata()[0]
max_seqlen = output_spec[1].value()

# Tokenize the input text
inputs = tokenizer(
    text=INPUT,
    add_special_tokens=True,
    padding="max_length",
    truncation=True,
    max_length=max_seqlen,
    return_tensors="np",
)
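
To make the effect of padding="max_length" and truncation=True more concrete, here is a minimal plain-Python sketch (not Mojo, and not the Transformers implementation) that pads one token sequence to a fixed length and builds the matching attention mask. The helper name pad_to_max_length, the placeholder word IDs, and the special-token constants are assumptions for illustration only; in real code, read the IDs from the tokenizer (for example, tokenizer.pad_token_id and tokenizer.mask_token_id).

```python
# Assumed special-token IDs, for illustration only.
# In real code, read them from the tokenizer instead of hard-coding them.
PAD_ID = 0
CLS_ID = 101
SEP_ID = 102
MASK_ID = 103


def pad_to_max_length(token_ids, max_length, pad_id=PAD_ID):
    """Truncate or right-pad one sequence and build its attention mask.

    Hypothetical helper: truncation here is a plain slice, whereas a real
    tokenizer also keeps the trailing separator token when it truncates.
    """
    kept = token_ids[:max_length]
    n_pad = max_length - len(kept)
    return {
        # Each value has a leading batch dimension of 1: shape [1, max_length]
        "input_ids": [kept + [pad_id] * n_pad],
        # A single-sentence input uses segment 0 everywhere
        "token_type_ids": [[0] * max_length],
        # 1 marks real tokens, 0 marks padding the model should ignore
        "attention_mask": [[1] * len(kept) + [0] * n_pad],
    }


# "Paris is the [MASK] of France." with placeholder word IDs (11..16)
example_ids = [CLS_ID, 11, 12, 13, MASK_ID, 14, 15, 16, SEP_ID]
encoded = pad_to_max_length(example_ids, max_length=128)
```

Each of the three values has shape [1, 128], which matches the batch and seqlen values used in the input specs above.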

Run an inference

The tokenizer returns a dictionary in which each key (an input name) is mapped to a NumPy array (the input tensor). Currently, Mojo doesn't have complete support for keyword arguments in functions, so we need to manually unpack this dictionary and pass each input to execute().

In this case, we're calling the overloaded version of execute() that accepts each input as a name and a PythonObject value (the NumPy array):

input_ids = inputs["input_ids"]
token_type_ids = inputs["token_type_ids"]
attention_mask = inputs["attention_mask"]

# Now we can run inference
outputs = model.execute("input_ids", input_ids,
                        "token_type_ids", token_type_ids,
                        "attention_mask", attention_mask)

The output from execute() is a TensorMap, which we'll now process to get our results.

Process the outputs

Now we need to decode the predicted token into a string. First, we need to write an argmax() function that takes all the output logits and finds the vocabulary index position for each predicted token:

def argmax(t: Tensor) -> List[Int]:
    var res = List[Int](capacity=t.dim(1))
    for i in range(t.dim(1)):
        var max_val = Scalar[t.type].MIN
        var max_idx = 0
        for j in range(t.dim(2)):
            if t[0, i, j] > max_val:
                max_val = t[0, i, j]
                max_idx = j
        res.append(max_idx)
    return res
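
If you're more familiar with Python, the same per-position argmax can be sketched in plain Python on a tiny hand-written tensor. This is only an illustration of the logic, not part of the Mojo program; the function name argmax_per_position and the sample logits are made up for this example.

```python
def argmax_per_position(logits):
    """Return, for each sequence position, the index of the largest logit.

    `logits` is a nested list shaped [batch][seqlen][vocab]. Only batch 0
    is read, mirroring the Mojo function.
    """
    result = []
    for position_scores in logits[0]:
        best_idx = 0
        best_val = float("-inf")
        for idx, score in enumerate(position_scores):
            # Strict comparison: on a tie, the first index wins
            if score > best_val:
                best_val = score
                best_idx = idx
        result.append(best_idx)
    return result


# One batch, two sequence positions, a "vocabulary" of three entries
tiny_logits = [[[0.1, 2.5, -1.0], [3.0, 0.2, 0.4]]]
predicted = argmax_per_position(tiny_logits)  # [1, 0]
```

The result has one vocabulary index per sequence position, so indexing it with the mask position yields the predicted token ID.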

Now, back inside the main() function, we'll use this new argmax() function and the Transformers API to decode the masked token:

logits = outputs.get[DType.float32]("result0")

# Find the index of the mask token
mask_idx = -1
for i in range(len(input_ids[0])):
    if input_ids[0][i] == tokenizer.mask_token_id:
        mask_idx = i

# Decode the predicted token
predicted_token_id = argmax(logits)[mask_idx]
decoded_result = tokenizer.decode(
    predicted_token_id,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

print("input text: ", INPUT)
print("filled mask: ", INPUT.replace("[MASK]", str(decoded_result)))

The following is the expected output:

input text:  Paris is the [MASK] of France.
filled mask: Paris is the capital of France.
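
The mask lookup in this tutorial starts from mask_idx = -1, so an input without a [MASK] token would leave the index at -1, which does not point at a real mask position. The following plain-Python sketch shows one way to make that case explicit. The helper name find_mask_index and the token IDs are assumptions for illustration, and unlike the loop in the tutorial (which keeps the last match), this version returns the first match.

```python
def find_mask_index(token_ids, mask_token_id):
    """Return the position of the first mask token, or raise if absent."""
    for i, token_id in enumerate(token_ids):
        if token_id == mask_token_id:
            return i
    raise ValueError("input contains no mask token")


# Placeholder IDs: 103 stands in for tokenizer.mask_token_id
token_ids = [101, 11, 12, 13, 103, 14, 15, 16, 102]
mask_idx = find_mask_index(token_ids, mask_token_id=103)  # 4
```

For an input with exactly one [MASK] token, as in this tutorial, both approaches find the same position.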

Now you're running models with Mojo! 🔥

For more details about the API, see the Mojo MAX Engine reference.

For more example code, see our GitHub repo.
