Python module

engine

The APIs in this module allow you to run inference with MAX Engine—a graph compiler and runtime that accelerates your AI models on a wide variety of hardware. You can load models from PyTorch or ONNX, or those built with the Graph API.

Regardless of your model format, you can run an inference with just a few lines of code:

  1. Create an InferenceSession.
  2. Load a model with InferenceSession.load(), which returns a Model.
  3. Run the model by passing your input to Model.execute(), which returns the output.
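
The three steps above can be sketched as a small helper. This is only an illustrative sketch: `run_inference` is a hypothetical name, and `session` is assumed to be any object that behaves like an `InferenceSession` (its `load()` returns an object with `execute()`), so the snippet does not import MAX itself.

```python
from pathlib import Path


def run_inference(session, model_path, *inputs):
    """Load a model through `session` and execute it once.

    `session` is assumed to behave like max.engine.InferenceSession:
    its load() returns an object that has an execute() method.
    """
    # Step 2: load (and compile) the model; the result is a Model.
    model = session.load(Path(model_path))
    # Step 3: pass the inputs positionally and return the outputs.
    return model.execute(*inputs)
```

With MAX installed, step 1 would be `session = engine.InferenceSession()`, followed by `run_inference(session, model_path, input_tensor)`.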

That’s it! For all the code details, check out the following tutorials:

InferenceSession

class max.engine.InferenceSession(num_threads: int | None = None, devices: list[max.driver.driver.Device] = [Device(_device=<max._driver.core.Device object>, id=-1)], *, custom_extensions: List[str | Path | Any] | str | Path | Any | None = None)

Manages an inference session in which you can load and run models.

You need an instance of this to load a model as a Model object. For example:

from pathlib import Path
from max import engine

session = engine.InferenceSession()
model_path = Path('bert-base-uncased')
model = session.load(model_path)
  • Parameters:

    num_threads – Number of threads to use for the inference session. This defaults to the number of physical cores on your machine.

load()

load(model: str | Path | Any, *, custom_extensions: List[str | Path | Any] | str | Path | Any | None = None, custom_ops_path: str | None = None, input_specs: list[max.engine.api.TorchInputSpec] | None = None, weights_registry: dict[str, max.driver.tensor.DLPackArray] | None = None) → Model

Loads a trained model and compiles it for inference.

Note: PyTorch models must be in TorchScript format.

  • Parameters:

    • model – Path to a model, or a TorchScript model instance. May be a TorchScript model or an ONNX model.

    • custom_extensions – The extensions to load for the model. Supports paths to .mojopkg custom ops, .so custom op libraries for PyTorch and .pt torchscript files for torch metadata libraries. Supports TorchMetadata and torch.jit.ScriptModule objects for torch metadata libraries without serialization.

    • custom_ops_path – The path to your custom ops Mojo package. Deprecated, use custom_extensions instead.

    • input_specs

      The tensor specifications (shape and data type) for each of the model inputs. This is required when loading serialized TorchScript models because they do not include type and shape annotations.

      For example:

      from max import engine
      from max.dtype import DType

      session = engine.InferenceSession()
      model = session.load(
          "clip-vit.torchscript",
          # One TorchInputSpec per model input, in input order.
          input_specs=[
              engine.TorchInputSpec(shape=[1, 16], dtype=DType.int32),
              engine.TorchInputSpec(
                  shape=[1, 3, 224, 224], dtype=DType.float32
              ),
              engine.TorchInputSpec(shape=[1, 16], dtype=DType.int32),
          ],
      )

      If the model supports an input with dynamic shapes, use None as the dimension size in shape.

    • weights_registry – A mapping from model weight names to their values. The values are currently expected to be DLPack arrays. If an array is a read-only NumPy array, you must ensure that its lifetime extends beyond the lifetime of the model.

  • Returns:

    The loaded model, compiled and ready to execute.

  • Raises:

    RuntimeError – If the path provided is invalid.
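
The input_specs description above says that None marks a dynamic dimension in shape. The following library-independent sketch illustrates what that convention means when a concrete input shape is checked against a spec. `shape_matches` is a hypothetical helper, not part of max.engine, and it handles only integer and None dimensions (not the string dimensions the type signature also permits).

```python
from typing import Optional, Sequence


def shape_matches(spec_shape: Sequence[Optional[int]],
                  actual_shape: Sequence[int]) -> bool:
    """Return True if `actual_shape` is compatible with `spec_shape`.

    A None entry in `spec_shape` is treated as a dynamic dimension
    that accepts any size; integer entries must match exactly.
    """
    if len(spec_shape) != len(actual_shape):
        return False
    return all(
        expected is None or expected == actual
        for expected, actual in zip(spec_shape, actual_shape)
    )


# A dynamic batch dimension accepts any batch size.
print(shape_matches([None, 3, 224, 224], [8, 3, 224, 224]))  # True
# A fixed dimension must match exactly.
print(shape_matches([1, 16], [2, 16]))  # False
```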

set_debug_print_options()

set_debug_print_options(style: str | max._engine.core.PrintStyle = <PrintStyle.COMPACT: 0>, precision: int = 6, output_directory: str = '')

Sets the debug print options.

See Value.print.

This affects debug printing across all model execution using the same InferenceSession.

Tensors saved with BINARY can be loaded using max.driver.MemmapTensor(), but you will have to provide the expected dtype and shape.

Tensors saved with BINARY_MAX_CHECKPOINT are saved with the shape and dtype information, and can be loaded with max.driver.tensor.load_max_tensor().

Warning: Even with style set to NONE, debug print ops in the graph can stop optimizations. If you see performance issues, try fully removing debug print ops.

  • Parameters:

    • style – How the values will be printed. Can be COMPACT, FULL, BINARY, BINARY_MAX_CHECKPOINT or NONE.
    • precision – If the style is FULL, the digits of precision in the output.
    • output_directory – If the style is BINARY, the directory to store output tensors.
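
As a hedged illustration of the parameters above, the sketch below checks a style name against the documented set before forwarding it to the session. `configure_debug_printing` and `DOCUMENTED_STYLES` are hypothetical names, `session` is assumed to behave like an `InferenceSession`, and passing the style as an upper-case string is an assumption based on the `str` type in the signature.

```python
# Style names listed in the parameter description above.
DOCUMENTED_STYLES = {
    "COMPACT", "FULL", "BINARY", "BINARY_MAX_CHECKPOINT", "NONE",
}


def configure_debug_printing(session, style="COMPACT", precision=6,
                             output_directory=""):
    """Validate the style name, then forward the options to the session.

    `session` is assumed to expose set_debug_print_options() with the
    signature shown above (as max.engine.InferenceSession does).
    """
    if style not in DOCUMENTED_STYLES:
        raise ValueError(
            f"Unknown style {style!r}; expected one of "
            f"{sorted(DOCUMENTED_STYLES)}"
        )
    session.set_debug_print_options(
        style=style,
        precision=precision,
        output_directory=output_directory,
    )
```

Validating up front keeps a typo in the style name from surfacing later, during model execution.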

stats_report

property stats_report: Dict[str, Any]

Metadata about model compilation (PyTorch only).

Prints a list of “fallback ops”, which are ops that could not be lowered to our internal dialect MO. Fallback ops have to be executed using the original framework (i.e. PyTorch), which makes the model much slower. This function is a good starting point for debugging model performance.

Model

class max.engine.Model

A loaded model that you can execute.

Do not instantiate this class directly. Instead, create it with InferenceSession.load().

devices

property devices: list[max.driver.driver.Device]

Returns the device objects used in the Model.

execute()

execute(*args: Tensor | MojoValue | int | float | bool | generic, copy_inputs_to_device: bool = True, output_device: Device | None = None) → list[Union[max.driver.tensor.Tensor, max._engine.core.MojoValue]]

Executes the model with the provided input and returns the outputs.

For example, if the model has one input tensor:

import numpy as np

input_tensor = np.random.rand(1, 224, 224, 3)
model.execute(input_tensor)
  • Parameters:

    • args – The input tensors, passed as positional arguments. We currently support ndarray, torch.Tensor, and max.driver.Tensor inputs. By default, all inputs are copied to the device that the model is resident on before execution (see copy_inputs_to_device).
    • copy_inputs_to_device – Whether to copy all input tensors to the model’s device. Defaults to True. If set to False, input tensors will remain on whatever device they’re currently on, which the model must be prepared for.
    • output_device – The device to copy output tensors to. Defaults to None, in which case the tensors will remain resident on the same device as the model.
  • Returns:

    A list of output tensors and Mojo values. The output tensors will be resident on the execution device by default.

  • Raises:

    • RuntimeError – If the given input tensors’ shapes don’t match what the model expects.
    • TypeError – If the given input tensors’ dtype cannot be cast to what the model expects.
    • ValueError – If positional inputs are not one of the supported types, i.e. ndarray, torch.Tensor, and max.driver.Tensor.
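
The exceptions listed above can be handled separately, as in the following minimal sketch. `try_execute` is a hypothetical helper, and `model` is assumed to behave like a `Model`. The messages repeat the documented meaning of each exception; a real failure may have another cause.

```python
def try_execute(model, *inputs):
    """Run model.execute() and turn documented failures into messages.

    Returns (outputs, None) on success, or (None, message) on failure.
    `model` is assumed to behave like max.engine.Model.
    """
    try:
        return model.execute(*inputs), None
    except RuntimeError as err:
        # Documented cause: input shapes don't match the model.
        return None, f"RuntimeError (check input shapes): {err}"
    except TypeError as err:
        # Documented cause: dtype cannot be cast to the expected one.
        return None, f"TypeError (check input dtypes): {err}"
    except ValueError as err:
        # Documented cause: an unsupported positional input type.
        return None, f"ValueError (check input types): {err}"
```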

execute_legacy()

execute_legacy(**kwargs: Any) → Dict[str, ndarray | dict | list | tuple]

Executes the model with a set of named tensors. This API is maintained primarily to support frameworks that require named inputs (e.g. ONNX).

NOTICE: This API does not support GPU inputs and is slated for deprecation.

For example, if the model has one input tensor named input0:

import numpy as np

input_tensor = np.random.rand(1, 224, 224, 3)
model.execute_legacy(input0=input_tensor)
  • Parameters:

    kwargs – The input tensors, each specified with the appropriate tensor name as a keyword and its value as an ndarray. You can find the tensor names to use as keywords from input_metadata.

  • Returns:

    A dictionary of output values, each as an ndarray, Dict, List, or Tuple identified by its output name.

  • Raises:

    • RuntimeError – If the given input tensors’ names and shapes don’t match what the model expects.
    • TypeError – If the given input tensors’ dtype cannot be cast to what the model expects.
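
Because the keyword names come from input_metadata, the keyword arguments can be built programmatically. The sketch below is one possible approach, not the library's own method: `execute_by_name` is a hypothetical helper, `model` is assumed to behave like a `Model`, and it is an assumption that input_metadata lists the inputs in the same order as the tensors you supply.

```python
def execute_by_name(model, tensors):
    """Pair positional tensors with input names and call execute_legacy().

    `model` is assumed to behave like max.engine.Model: each entry of
    input_metadata has a `name` attribute (assumed to be in input order).
    """
    names = [spec.name for spec in model.input_metadata]
    if len(names) != len(tensors):
        raise ValueError(
            f"model expects {len(names)} inputs, got {len(tensors)}"
        )
    # Build {"input0": tensor0, ...} and pass it as keyword arguments.
    named_inputs = dict(zip(names, tensors))
    return model.execute_legacy(**named_inputs)
```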

input_devices

property input_devices: List[Device]

Device of the model’s input tensors, as a list of Device objects.

input_metadata

property input_metadata: list[max.engine.api.TensorSpec]

Metadata about the model’s input tensors, as a list of TensorSpec objects.

For example, you can print the input tensor names, shapes, and dtypes:

for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

output_devices

property output_devices: List[Device]

Device of the model’s output tensors, as a list of Device objects.

output_metadata

property output_metadata: list[max.engine.api.TensorSpec]

Metadata about the model’s output tensors, as a list of TensorSpec objects.

For example, you can print the output tensor names, shapes, and dtypes:

for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

signature

property signature: Signature

The input signature of the model.

TensorSpec

class max.engine.TensorSpec(shape: List[int | str | None] | None, dtype: DType, name: str)

Defines the properties of a tensor, including its name, shape and data type.

For usage examples, see Model.input_metadata.

dtype

property dtype: DType

A tensor data type.

name

property name: str

A tensor name.

shape

property shape: list[int] | None

The shape of the tensor as a list of integers.

If a dimension size is unknown/dynamic (such as the batch size), its value is None.
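
When you need a concrete shape, for instance to build a dummy input, each None dimension has to be given a size. The following is a small sketch of that step: `concretize_shape` is a hypothetical helper, not part of max.engine, and it operates on a plain list such as the one returned by `TensorSpec.shape`.

```python
from typing import List, Optional, Sequence


def concretize_shape(shape: Sequence[Optional[int]],
                     dynamic_size: int = 1) -> List[int]:
    """Replace each None (dynamic) dimension with `dynamic_size`."""
    return [dynamic_size if dim is None else dim for dim in shape]


print(concretize_shape([None, 3, 224, 224]))         # [1, 3, 224, 224]
print(concretize_shape([None, 16], dynamic_size=4))  # [4, 16]
```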

TorchInputSpec

class max.engine.TorchInputSpec(shape: List[int | str | None] | None, dtype: DType)

Specifies a valid input specification for a TorchScript model.

Before you load a TorchScript model, you must create an instance of this class for each input tensor, and pass them to the input_specs argument of InferenceSession.load().

For example code, see InferenceSession.load().

dtype

property dtype: DType

A torch input tensor data type.

shape

property shape: List[int | str | None] | None

The shape of the torch input tensor as a list of integers.

If a dimension size is unknown/dynamic (such as the batch size), its value in the shape should be None.