Python module
engine
The APIs in this module allow you to run inference with MAX Engine—a graph compiler and runtime that accelerates your AI models on a wide variety of hardware.
InferenceSession
class max.engine.InferenceSession(num_threads: int | None = None, devices: list[Device] = [CPU], *, custom_extensions: CustomExtensionsType | None = None)
Manages an inference session in which you can load and run models.
You need an instance of this to load a model as a Model object.
For example:
from pathlib import Path
from max import engine

session = engine.InferenceSession()
model_path = Path('bert-base-uncased')
model = session.load(model_path)
Parameters:
- num_threads – Number of threads to use for the inference session. This defaults to the number of physical cores on your machine.
- devices – A list of devices on which to run inference. Default is the host CPU only.
- custom_extensions – The extensions to load for the model. Supports paths to .mojopkg custom ops, .so custom op libraries for PyTorch, and .pt TorchScript files for torch metadata libraries. Also supports TorchMetadata and torch.jit.ScriptModule objects for torch metadata libraries without serialization.
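For example, a minimal sketch of configuring a session explicitly (assuming CPU is importable from max.driver, matching the Device type in the signature):

from max import engine
from max.driver import CPU

# Pin the session to the host CPU and use four threads.
session = engine.InferenceSession(num_threads=4, devices=[CPU()])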
devices
property devices: list[max._core.driver.Device]
A list of available devices.
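For example, you can print the devices available to a session:

for device in session.devices:
    print(device)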
gpu_profiling()
gpu_profiling(mode: GPUProfilingMode)
Enables end-to-end GPU profiling configuration.
load()
load(model: str | Path | Any, *, custom_extensions: list[Union[str, pathlib.Path, Any]] | str | Path | Any | None = None, custom_ops_path: str | None = None, input_specs: list[max._core.engine.TorchInputSpec] | None = None, weights_registry: Mapping[str, DLPackArray | ndarray[Any, dtype[_ScalarType_co]]] | None = None) → Model
Loads a trained model and compiles it for inference.
Note: PyTorch models must be in TorchScript format.
Parameters:
- model – Path to a model, or a TorchScript model instance. May be a TorchScript model or an ONNX model.
- custom_extensions – The extensions to load for the model. Supports paths to .mojopkg custom ops, .so custom op libraries for PyTorch, and .pt TorchScript files for torch metadata libraries. Also supports TorchMetadata and torch.jit.ScriptModule objects for torch metadata libraries without serialization.
- custom_ops_path – The path to your custom ops Mojo package. Deprecated; use custom_extensions instead.
- input_specs – The tensor specifications (shape and data type) for each of the model inputs. This is required when loading serialized TorchScript models because they do not include type and shape annotations.
For example:
session = engine.InferenceSession()
model = session.load(
    "clip-vit.torchscript",
    input_specs = [
        engine.TorchInputSpec(
            shape=[1, 16], dtype=DType.int32
        ),
        engine.TorchInputSpec(
            shape=[1, 3, 224, 224], dtype=DType.float32
        ),
        engine.TorchInputSpec(
            shape=[1, 16], dtype=DType.int32
        ),
    ],
)

If the model supports an input with dynamic shapes, use None as the dimension size in shape.
- weights_registry – A mapping from model weight names to their values. The values are currently expected to be DLPack arrays. If an array is a read-only NumPy array, the user must ensure that its lifetime extends beyond the lifetime of the model.
Returns:
The loaded model, compiled and ready to execute.
Raises:
RuntimeError – If the path provided is invalid.
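As a minimal sketch of the weights_registry parameter (my_graph and the weight name "linear.weight" are hypothetical placeholders for a graph whose weights are declared by name):

import numpy as np

# Hypothetical weight name; the registry maps declared weight names to arrays.
weights = {"linear.weight": np.zeros((768, 768), dtype=np.float32)}
model = session.load(my_graph, weights_registry=weights)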
reset_stats_report()
reset_stats_report() → None
Clears all entries in stats_report.
set_mojo_assert_level()
set_mojo_assert_level(level: str | AssertLevel)
Sets which Mojo asserts are kept in the compiled model.
set_mojo_log_level()
set_mojo_log_level(level: str | LogLevel)
Sets the verbosity of Mojo logging in the compiled model.
set_split_k_reduction_precision()
set_split_k_reduction_precision(precision: str | SplitKReductionPrecision)
Sets the accumulation precision for split-K reductions in large matmuls.
stats_report
Metadata about model compilation (PyTorch only).
Reports a list of “fallback ops”, which are ops that could not be lowered to our internal dialect, MO. Fallback ops have to be executed using the original framework (i.e. PyTorch), which makes the model much slower. This report is a good starting point for debugging model performance.
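For example, after compiling a PyTorch model you can inspect the report and then clear it:

print(session.stats_report)   # lists any fallback ops from the last compilation
session.reset_stats_report()  # clears the report for the next compilation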
Model
class max.engine.Model
A loaded model that you can execute.
Do not instantiate this class directly. Instead, create it with InferenceSession.
__call__()
__call__(*args: DLPackArray | ndarray[Any, dtype[_ScalarType_co]] | Tensor | MojoValue | int | float | bool | generic, **kwargs: DLPackArray | ndarray[Any, dtype[_ScalarType_co]] | Tensor | MojoValue | int | float | bool | generic) → list[max._core.driver.Tensor | max._core.engine.MojoValue]
Call self as a function.
execute()
execute(*args: DLPackArray | ndarray[Any, dtype[_ScalarType_co]] | Tensor | MojoValue | int | float | bool | generic) → list[max._core.driver.Tensor | max._core.engine.MojoValue]
Executes the model with the given inputs and returns the output values.
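For example, a minimal sketch of running inference with a NumPy input (the shape and dtype here are hypothetical; match them to model.input_metadata):

import numpy as np

# Hypothetical input; check model.input_metadata for the real shape and dtype.
input_ids = np.zeros((1, 16), dtype=np.int32)
outputs = model.execute(input_ids)
print(outputs)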
execute_legacy()
execute_legacy(**kwargs: Any) → dict[str, Union[numpy.ndarray, dict, list, tuple]]
Legacy execution API that takes inputs as keyword arguments and returns outputs keyed by name.
input_metadata
property input_metadata
Metadata about the model’s input tensors, as a list of TensorSpec objects.
For example, you can print the input tensor names, shapes, and dtypes:
for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
output_metadata
property output_metadata
Metadata about the model’s output tensors, as a list of TensorSpec objects.
For example, you can print the output tensor names, shapes, and dtypes:
for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
MojoValue
class max.engine.MojoValue
This is work in progress and you should ignore it for now.
TensorSpec
class max.engine.TensorSpec(self, shape: collections.abc.Sequence[int | None] | None, dtype: max._core.dtype.DType, name: str)
Defines the properties of a tensor, including its name, shape and data type.
For usage examples, see Model.input_metadata.
Parameters:
- shape – The tensor shape.
- dtype – The tensor data type.
- name – The tensor name.
dtype
property dtype
The tensor data type.
name
property name
The tensor name.
shape
property shape
The shape of the tensor as a list of integers.
If a dimension size is unknown/dynamic (such as the batch size), its value is None.
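For example, a minimal sketch of constructing a spec with a dynamic batch dimension (assuming DType is importable from max.dtype):

from max.engine import TensorSpec
from max.dtype import DType

# None marks a dynamic dimension such as the batch size.
spec = TensorSpec(shape=[None, 3, 224, 224], dtype=DType.float32, name="image")
print(spec.shape)  # [None, 3, 224, 224]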
TorchInputSpec
class max.engine.TorchInputSpec(self, shape: collections.abc.Sequence[int | str | None] | None, dtype: max._core.dtype.DType, device: str = '')
Specifies an input specification for a TorchScript model.
Before you load a TorchScript model, you must create an instance of this class for each input tensor, and pass them to the input_specs argument of InferenceSession.load().
For example code, see InferenceSession.load().
Parameters:
- shape – The input tensor shape.
- dtype – The input data type.
- device – The device on which this tensor should be loaded.
device
property device
The torch device.
dtype
property dtype
The torch input tensor data type.
shape
property shape
The shape of the tensor as a list of integers.
If a dimension size is unknown/dynamic (such as the batch size), its value is None.
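For example, a sketch of a spec with a dynamic batch dimension; the "cpu" device string is an assumption based on standard torch device names:

from max.engine import TorchInputSpec
from max.dtype import DType

# None marks a dynamic batch dimension; device follows torch naming ("cpu", "cuda:0", ...).
spec = TorchInputSpec(shape=[None, 3, 224, 224], dtype=DType.float32, device="cpu")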