Python module
torch
CustomOpLibrary
class max.torch.CustomOpLibrary(kernel_library)
A PyTorch interface to custom operations implemented in Mojo.
This API allows for easy passing of PyTorch data as torch.Tensor values to the corresponding custom op. CustomOpLibrary handles the compilation of the Mojo custom ops and marshalling of data between PyTorch and the executable Mojo code.
For example, consider a grayscale operation implemented in Mojo:
@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        # The kind of device this is running on: "cpu" or "gpu"
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        ...
You can then use CustomOpLibrary to invoke the Mojo operation like so:
import torch
from max.torch import CustomOpLibrary

op_library = CustomOpLibrary("my_library")
grayscale_op = op_library.grayscale

def grayscale(pic: torch.Tensor) -> torch.Tensor:
    # Allocate the output without the trailing color-channel dimension.
    result = pic.new_empty(pic.shape[:-1])
    # The op writes into `result` in place (it maps to the OutputTensor).
    grayscale_op(result, pic)
    return result

img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)
The custom operation produced by op_library.<opname> will have the same interface as the backing Mojo operation. Each InputTensor or OutputTensor argument corresponds to a torch.Tensor value in Python. Each argument corresponding to an OutputTensor in the Mojo operation will be modified in-place.
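The in-place output convention can be illustrated without MAX or PyTorch. The following is a minimal sketch in plain Python, using nested lists in place of tensors. The function name grayscale_into is hypothetical, and the luminance weights are an assumption (the body of the Mojo kernel is elided above), so treat it as an illustration of the calling convention rather than of the kernel itself.

```python
# Plain-Python illustration of an op that writes into a caller-owned output.
# This is not the MAX API. The weights are an assumption, since the Mojo
# kernel body is not shown on this page.
WEIGHTS = (0.21, 0.71, 0.07)


def grayscale_into(img_out, img_in):
    """Fill img_out (rows x cols) in place from img_in (rows x cols x 3)."""
    for r, row in enumerate(img_in):
        for c, (red, green, blue) in enumerate(row):
            value = WEIGHTS[0] * red + WEIGHTS[1] * green + WEIGHTS[2] * blue
            # Truncate to an integer and clamp to the uint8 range.
            img_out[r][c] = max(0, min(255, int(value)))
    # Nothing is returned: the result lives in the caller's buffer.


img = [[(255, 255, 255), (0, 0, 0)], [(100, 50, 200), (10, 20, 30)]]
out = [[0] * len(img[0]) for _ in img]  # the caller allocates the output
grayscale_into(out, img)
print(out)
```

As with the OutputTensor argument, the caller owns the output buffer and the function only mutates it.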
Parameters:
- kernel_library (Path | KernelLibrary) – The path to a .mojo file or a .mojopkg with your custom op kernels, or the corresponding library object.
graph_op()
max.torch.graph_op(fn=None, name=None, kernel_library=None, input_types=None, output_types=None, num_outputs=None)
A decorator to create PyTorch custom operations using MAX graph operations.
This decorator allows you to define larger graphs using MAX graph ops or max.nn modules and call them with PyTorch tensors, or integrate them into PyTorch modules. These custom ops can be called eagerly, and support compilation with torch.compile and the Inductor backend.
The resulting custom operation uses destination-passing style, where output tensors are passed as the first arguments and modified in-place. This allows PyTorch to manage the memory and streams of the output tensors. Tensors internal to the computation are managed via MAX’s graph compiler and memory planning.
The default behavior is to JIT-compile for the specific input and output shapes needed. If you are passing variable-sized inputs, for instance a batch size or sequence length which may take on many different values between calls, you should specify this dimension as a symbolic dimension via input_types and output_types. Otherwise you will end up compiling specialized graphs for each possible variation of inputs, which may use a lot of memory.
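To see why shape specialization can multiply compiled graphs, here is a toy model in plain Python. It is only a sketch of the caching idea and not how MAX is implemented: the names cache_key, get_compiled, and symbolic_axes are hypothetical, and "compiling" is simulated by storing a string.

```python
# Toy model of shape-specialized compilation. Not the MAX implementation.
compiled = {}  # cache key -> stand-in for a compiled graph


def cache_key(shape, symbolic_axes=()):
    """Replace sizes on symbolic axes with a placeholder so they share a key."""
    return tuple(
        "?" if axis in symbolic_axes else size
        for axis, size in enumerate(shape)
    )


def get_compiled(shape, symbolic_axes=()):
    """Return the cached entry for this shape, 'compiling' it on a miss."""
    key = cache_key(shape, symbolic_axes)
    if key not in compiled:
        compiled[key] = f"graph specialized for {key}"  # pretend to compile
    return compiled[key]


# Fully static shapes: every new batch size adds a cache entry.
for batch in (1, 2, 4, 8):
    get_compiled((batch, 64, 3))
static_count = len(compiled)

# Axis 0 treated as symbolic: all batch sizes share one entry.
compiled.clear()
for batch in (1, 2, 4, 8):
    get_compiled((batch, 64, 3), symbolic_axes=(0,))
symbolic_count = len(compiled)

print(static_count, symbolic_count)
```

In this model, four batch sizes produce four entries when every dimension is static and a single entry when the batch axis is symbolic, which is the memory trade-off described above.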
If neither output_types nor num_outputs is specified, the operation defaults to one output.
Example usage to create a functional-style PyTorch op backed by MAX:
import torch
import numpy as np
import max.torch  # binds `max` and loads the max.torch submodule
from max.dtype import DType
from max.graph import ops

# Run on the GPU if one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

@max.torch.graph_op
def max_grayscale(pic: max.graph.TensorValue):
    scaled = pic.cast(DType.float32) * np.array([0.21, 0.71, 0.07])
    grayscaled = ops.sum(scaled, axis=-1).cast(pic.dtype)
    # MAX reductions keep the reduced dimension, so squeeze it away.
    return ops.squeeze(grayscaled, axis=-1)

@torch.compile
def grayscale(pic: torch.Tensor):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    max_grayscale(output, pic)  # Call in destination-passing style
    return output

img = (torch.rand(64, 64, 3, device=device) * 255).to(torch.uint8)
result = grayscale(img)
Parameters:
- fn – The function to decorate. If None, returns a decorator.
- name (str | None) – Optional name for the custom operation. Defaults to the function name.
- kernel_library (Path | KernelLibrary | None) – Optional kernel library to use for compilation. Useful for creating graphs with custom Mojo ops.
- input_types (Sequence[TensorType] | None) – Optional sequence of input tensor types for compilation. If None, types are inferred from runtime arguments.
- output_types (Sequence[TensorType] | None) – Optional sequence of output tensor types for compilation. If None, types are inferred from runtime arguments.
- num_outputs (int | None) – The number of outputs of the graph. We need to know this ahead of time to register with PyTorch before we’ve compiled the final kernels.
Returns:
A PyTorch custom operation that can be called with torch.Tensor arguments.
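The "If None, returns a decorator" behavior of fn follows a common Python pattern that lets a decorator be applied both bare and with arguments. The sketch below shows that pattern in isolation. op_decorator and the op_name attribute are hypothetical stand-ins rather than part of max.torch, and name is keyword-only here for simplicity, which differs from the graph_op signature.

```python
import functools


def op_decorator(fn=None, *, name=None):
    """Hypothetical decorator that works with or without arguments."""

    def wrap(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        # Default the op name to the function name when none is given.
        wrapper.op_name = name if name is not None else func.__name__
        return wrapper

    if fn is None:
        # Called as @op_decorator(...): return the real decorator.
        return wrap
    # Called as bare @op_decorator: decorate immediately.
    return wrap(fn)


@op_decorator
def double(x):
    return 2 * x


@op_decorator(name="my_triple")
def triple(x):
    return 3 * x


print(double.op_name, triple.op_name)
```

Both forms yield a wrapped callable. Only the second overrides the default name, mirroring how name defaults to the function name in the parameter list above.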