
Python module

torch

CustomOpLibrary

class max.torch.CustomOpLibrary(kernel_library)

A PyTorch interface to custom operations implemented in Mojo.

This API allows for easy passing of PyTorch data as torch.Tensor values to the corresponding custom op. CustomOpLibrary handles the compilation of the Mojo custom ops and marshalling of data between PyTorch and the executable Mojo code.

For example, consider a grayscale operation implemented in Mojo:

my_library/grayscale.mojo
@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        # The kind of device this is running on: "cpu" or "gpu"
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        ...

You can then use CustomOpLibrary to invoke the Mojo operation like so:

import torch
from max.torch import CustomOpLibrary

op_library = CustomOpLibrary("my_library")
grayscale_op = op_library.grayscale

def grayscale(pic: torch.Tensor) -> torch.Tensor:
    # Allocate the output without the trailing color-channel dimension
    result = pic.new_empty(pic.shape[:-1])
    # Outputs come first and are written in place
    grayscale_op(result, pic)
    return result

img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)

The custom operation produced by op_library.<opname> will have the same interface as the backing Mojo operation. Each InputTensor or OutputTensor argument corresponds to a torch.Tensor value in Python. Each argument corresponding to an OutputTensor in the Mojo operation will be modified in-place.

Parameters:

kernel_library (Path | KernelLibrary) – The path to a .mojo file or a .mojopkg with your custom op kernels, or the corresponding library object.

graph_op()

max.torch.graph_op(fn=None, name=None, kernel_library=None, input_types=None, output_types=None, num_outputs=None)

A decorator to create PyTorch custom operations using MAX graph operations.

This decorator allows you to define larger graphs using MAX graph ops or max.nn modules and call them with PyTorch tensors, or integrate them into PyTorch modules. These custom ops can be called eagerly, and support compilation with torch.compile and the Inductor backend.

The resulting custom operation uses destination-passing style, where output tensors are passed as the first arguments and modified in-place. This allows PyTorch to manage the memory and streams of the output tensors. Tensors internal to the computation are managed via MAX’s graph compiler and memory planning.
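To make the calling convention concrete, here is a minimal pure-Python sketch of destination-passing style. It does not use MAX or PyTorch: `grayscale_into` and `grayscale` are hypothetical helpers written for illustration, nested lists stand in for tensors, and the weights are copied from the grayscale example on this page.

```python
from typing import List

# Luma weights copied from the grayscale example in this document.
WEIGHTS = (0.21, 0.71, 0.07)


def grayscale_into(out: List[List[int]], pic: List[List[List[int]]]) -> None:
    """Write the grayscale of `pic` (H x W x 3) into `out` (H x W) in place.

    Hypothetical illustration of destination-passing style: the caller owns
    `out`, it is the first argument, and nothing is returned.
    """
    if len(out) != len(pic) or any(len(o) != len(p) for o, p in zip(out, pic)):
        raise ValueError("out must have the shape of pic without the channel axis")
    for y, row in enumerate(pic):
        for x, pixel in enumerate(row):
            # Weighted sum over the channel axis, truncated to an integer.
            out[y][x] = int(sum(w * c for w, c in zip(WEIGHTS, pixel)))


def grayscale(pic: List[List[List[int]]]) -> List[List[int]]:
    """Functional wrapper: allocate the destination, fill it, and return it."""
    out = [[0] * len(row) for row in pic]
    grayscale_into(out, pic)
    return out
```

In the real operations, `pic.new_empty(pic.shape[:-1])` plays the role of the allocation step, so PyTorch rather than the op decides where the output lives.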

The default behavior is to JIT-compile for the specific input and output shapes needed. If you are passing variable-sized inputs, for instance a batch size or sequence length which may take on many different values between calls, you should specify this dimension as a symbolic dimension via input_types and output_types. Otherwise you will end up compiling specialized graphs for each possible variation of inputs, which may use a lot of memory.
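The cost of shape specialization can be pictured with a toy cache. The sketch below is a conceptual model only, not how MAX implements compilation: `ToyJitCache` is a hypothetical class, a string stands in for a compiled graph, and a `str` entry in `declared_shape` plays the role of a symbolic dimension.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

# An int is a static size; a str names a symbolic dimension.
Dim = Union[int, str]


class ToyJitCache:
    """Hypothetical stand-in for a shape-specializing JIT (not MAX's implementation)."""

    def __init__(self, declared_shape: Optional[Sequence[Dim]] = None) -> None:
        self.declared_shape = None if declared_shape is None else tuple(declared_shape)
        self.compiled: Dict[Tuple[Dim, ...], str] = {}

    def _key(self, shape: Tuple[int, ...]) -> Tuple[Dim, ...]:
        if self.declared_shape is None:
            # No declared types: specialize on the exact runtime shape.
            return shape
        if len(shape) != len(self.declared_shape):
            raise ValueError("rank mismatch with declared shape")
        for actual, declared in zip(shape, self.declared_shape):
            # Static dimensions must match; symbolic ones accept any size.
            if isinstance(declared, int) and actual != declared:
                raise ValueError(f"expected dimension {declared}, got {actual}")
        # One cache key covers every shape that fits the declaration.
        return self.declared_shape

    def call(self, shape: Sequence[int]) -> str:
        key = self._key(tuple(shape))
        if key not in self.compiled:
            # Stand-in for an expensive compile step.
            self.compiled[key] = f"graph{len(self.compiled)}"
        return self.compiled[key]


specialized = ToyJitCache()
symbolic = ToyJitCache(declared_shape=("batch", 128))
for batch in (1, 2, 4, 8):
    specialized.call((batch, 128))
    symbolic.call((batch, 128))
```

After the loop, `specialized.compiled` holds one entry per distinct batch size, while `symbolic.compiled` holds a single entry that is reused for every call.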

If neither output_types nor num_outputs is specified, the operation defaults to one output.

Example usage to create a functional-style PyTorch op backed by MAX:

import torch
import numpy as np
import max
from max.dtype import DType
from max.graph import ops

# Run on the GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

@max.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):
    scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
    grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
    # max reductions don't remove the dimension, need to squeeze
    return ops.squeeze(grayscaled, axis=-1)

@torch.compile
def grayscale(pic: torch.Tensor):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    max_grayscale(output, pic)  # Call as destination-passing style
    return output

img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)

Parameters:

  • fn – The function to decorate. If None, returns a decorator.
  • name (str | None) – Optional name for the custom operation. Defaults to the function name.
  • kernel_library (Path | KernelLibrary | None) – Optional kernel library to use for compilation. Useful for creating graphs with custom Mojo ops.
  • input_types (Sequence[TensorType] | None) – Optional sequence of input tensor types for compilation. If None, types are inferred from runtime arguments.
  • output_types (Sequence[TensorType] | None) – Optional sequence of output tensor types for compilation. If None, types are inferred from runtime arguments.
  • num_outputs (int | None) – The number of outputs of the graph. This must be known ahead of time so the operation can be registered with PyTorch before the final kernels are compiled.

Returns:

A PyTorch custom operation that can be called with torch.Tensor arguments.
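The `fn=None` convention in the parameter list means the decorator can be applied either bare or with keyword arguments. Below is a minimal sketch of that general Python pattern, also mirroring the documented defaults for `name` and the output count. `toy_graph_op` is a hypothetical stand-in, not the MAX implementation; deriving the output count from `len(output_types)` is an assumption made for illustration.

```python
from typing import Callable, Optional, Sequence


def toy_graph_op(
    fn: Optional[Callable] = None,
    *,
    name: Optional[str] = None,
    output_types: Optional[Sequence[object]] = None,
    num_outputs: Optional[int] = None,
):
    """Hypothetical decorator usable as @toy_graph_op or @toy_graph_op(...)."""

    def decorate(func: Callable) -> Callable:
        if output_types is not None:
            # Assumption for this sketch: one output per declared type.
            n_out = len(output_types)
        elif num_outputs is not None:
            n_out = num_outputs
        else:
            # Documented default when neither is specified.
            n_out = 1
        # Documented default: the op name falls back to the function name.
        func.op_name = name if name is not None else func.__name__
        func.num_outputs = n_out
        return func

    if fn is None:
        # Called with arguments: return the decorator itself.
        return decorate
    # Applied bare: decorate the function directly.
    return decorate(fn)


@toy_graph_op
def double(x):
    return x * 2


@toy_graph_op(name="split_pair", num_outputs=2)
def split(x):
    return x, x
```

Here `double` is registered under its own name with one output, while `split` takes an explicit name and output count.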