Run a TorchScript model with Mojo
The Mojo API for MAX Engine extends the performance gains of the MAX Engine runtime into your application by enabling you to build your entire inference pipeline in high-performance Mojo code.
With any of the MAX Engine API libraries (also available in Python and C), you can run inference with models from PyTorch and ONNX at incredible speeds on a wide range of hardware.
As we'll show in more detail below, there are just three essential lines of code needed to execute your model:
from max import engine

def main():
    # Load your model:
    session = engine.InferenceSession()
    model = session.load(model_path)

    # Get the inputs, then run an inference:
    outputs = model.execute(inputs)

    # Process the output here.
There's certainly more ceremony required to support these calls, but these are the basic APIs you need to know.
Now, let's walk through each step required to load and run a BERT model from PyTorch.
Set up the project environment
After you install Magic, create a new Python project and install the dependencies:
magic init bert-project --format mojoproject && cd bert-project
Add MAX as a conda package:
magic add max
We also use the Python transformers library to process inputs and outputs, and PyTorch to convert the model to TorchScript, so let's add those packages as well:
magic add "max~=24.5" "pytorch==2.4.0" "transformers==4.40.1"
magic add "max~=24.5" "pytorch==2.4.0" "transformers==4.40.1"
Now you can start a shell in the environment and see your Mojo version:
magic shell
mojo --version
Download the TorchScript model
You can download the BERT TorchScript model used in this tutorial from our GitHub repo with these commands, which save the model to your current directory:
git clone https://github.com/modularml/max.git
python3 max/examples/inference/common/bert-torchscript/download-model.py \
    -o bert-mlm.torchscript --mlm
Confirm that bert-mlm.torchscript
was created. Then you can delete the max
repo:
rm -rf max
Import Mojo modules
Everything we need to run an inference with MAX Engine comes from the
max.engine
package. The rest are supporting
APIs from the Mojo standard library.
from max.engine import InputSpec, InferenceSession
from pathlib import Path
from python import Python
from max.tensor import Tensor, TensorSpec
In case you're new to Mojo, it's important to know that Mojo requires a
main()
function as the program entry point.
So, from here on out, imagine that all the code runs inside this function:
def main():
    # The rest of the code goes here
The first thing we want to do inside the main()
function is load any Python
modules we plan to use. In this case, we're going to use HuggingFace
Transformers to encode/decode our text strings, so let's load that Python
module:
# This is equivalent to `import transformers` in Python
transformers = Python.import_module("transformers")
The transformers variable now behaves like a Python module name, but it's still just a variable, scoped to the current function. If you want to use transformers in a function other than main(), you need to put this import line inside that function instead.
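For example, here's a minimal sketch of what that looks like; the tokenize_elsewhere() helper is hypothetical and not part of this tutorial's code:
# Hypothetical helper function (requires `from python import Python, PythonObject`)
def tokenize_elsewhere(text: String) -> PythonObject:
    # The `transformers` variable from main() isn't visible here, so we
    # import the Python module again inside this function's own scope.
    transformers = Python.import_module("transformers")
    tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
    return tokenizer(text=text, return_tensors="np")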
Load the model
First, let's load and compile the model in MAX Engine using an
InferenceSession
.
Define input specs
When you're using a PyTorch model (it must be in TorchScript format), you need to specify the input specifications for each of the model inputs before you can compile the model.
For each input, create a TensorSpec value, collect them all into a list, and pass that list to InferenceSession.load().
Here's how you can specify the input specs for the BERT model:
batch = 1
seqlen = 128
input_ids_spec = TensorSpec(DType.int64, batch, seqlen)
token_type_ids_spec = TensorSpec(DType.int64, batch, seqlen)
attention_mask_spec = TensorSpec(DType.int64, batch, seqlen)
input_specs = List[InputSpec]()
input_specs.append(input_ids_spec)
input_specs.append(attention_mask_spec)
input_specs.append(token_type_ids_spec)
Next, you'll load the model with these input specs.
Load and compile the model
Now we instantiate an InferenceSession
and then load-and-compile the model by passing the model path to
load()
(if you're
loading a TorchScript model, also pass in input_specs
):
model_path = Path("bert-mlm.torchscript")
session = InferenceSession()
model = session.load(model_path, input_specs=input_specs)
Prepare the input
This is your usual pre-processing step to prepare input for the model. For the BERT model, we need to process the text input into a sequence of tokens.
Because Mojo is designed as the best way to extend Python, we can leverage all of the
world's amazing Python libraries in our Mojo project.
For this task, we need a string tokenizer, so we're using the
transformers.AutoTokenizer
from the 🤗 Transformers API:
INPUT = String("Paris is the [MASK] of France.")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "bert-base-uncased"
)

# Get the maximum sequence length from the model's output metadata
output_spec = model.get_model_output_metadata()[0]
max_seqlen = output_spec[1].value()

# Tokenize the input text
inputs = tokenizer(
    text=INPUT,
    add_special_tokens=True,
    padding="max_length",
    truncation=True,
    max_length=max_seqlen,
    return_tensors="np",
)
Run an inference
The tokenized inputs we get from Transformers are stored in a dictionary in which each input name (each key) maps to a NumPy array (the input tensor).
Currently, Mojo doesn't have complete support for keyword
arguments in functions, so we need to manually unpack
this dictionary and pass each input to
execute()
.
In this case, we're calling the overloaded version of
execute()
that
accepts each input as a name and a PythonObject
value (the NumPy array):
input_ids = inputs["input_ids"]
token_type_ids = inputs["token_type_ids"]
attention_mask = inputs["attention_mask"]

# Now we can run inference
outputs = model.execute("input_ids", input_ids,
    "token_type_ids", token_type_ids,
    "attention_mask", attention_mask)
The output from execute()
is a TensorMap
, which
we'll now process to get our results.
Process the outputs
Now we need to decode the predicted token into a string. First, we need to
write an argmax()
function that takes all the output logits and finds the
vocabulary index position for each predicted token:
def argmax(t: Tensor) -> List[Int]:
    var res = List[Int](capacity=t.dim(1))
    for i in range(t.dim(1)):
        var max_val = Scalar[t.type].MIN
        var max_idx = 0
        for j in range(t.dim(2)):
            if t[0, i, j] > max_val:
                max_val = t[0, i, j]
                max_idx = j
        res.append(max_idx)
    return res
Now, back inside the main()
function, we'll use this new
argmax()
function and the Transformers API to decode the masked token:
logits = outputs.get[DType.float32]("result0")

# Find the index of the mask token
mask_idx = -1
for i in range(len(input_ids[0])):
    if input_ids[0][i] == tokenizer.mask_token_id:
        mask_idx = i

# Decode the predicted token
predicted_token_id = argmax(logits)[mask_idx]
decoded_result = tokenizer.decode(
    predicted_token_id,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
print("input text: ", INPUT)
print("filled mask: ", INPUT.replace("[MASK]", str(decoded_result)))
The following is the expected output:
input text: Paris is the [MASK] of France.
filled mask: Paris is the capital of France.
Now you're running models with Mojo! 🔥
For more details about the API, see the Mojo MAX Engine reference.
For more example code, see our GitHub repo.