
Write hardware-agnostic custom ops for PyTorch

When working with PyTorch models, you might encounter performance bottlenecks in specific operations that could benefit from custom optimization. You might also want to experiment with novel GPU algorithms or implement cutting-edge research ideas that aren't yet available in standard frameworks. Rather than rewriting your entire model or switching frameworks, you can write high-performance kernels in Mojo and integrate them into your existing PyTorch workflows, enabling both optimization and experimentation in your familiar development environment.

This tutorial demonstrates how to enhance a PyTorch model by implementing a custom grayscale image conversion operation in Mojo. You'll discover how to keep your familiar PyTorch development experience while unlocking the performance benefits that Mojo provides for compute-intensive operations.

In this tutorial, you'll learn to convert an RGB image into a grayscale image by integrating a custom op in Mojo and running it in PyTorch.

Set up

Let's start by creating a Python project and installing the necessary tools.

  1. Create a project folder:
    mkdir pytorch_custom_ops && cd pytorch_custom_ops
  2. Create and activate a virtual environment:
    python3 -m venv .venv/pytorch_custom_ops \
    && source .venv/pytorch_custom_ops/bin/activate
  3. Install the modular Python package:
    pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --extra-index-url https://dl.modular.com/public/nightly/python/simple/

When you install the modular package, you'll get access to the max Python APIs and the Mojo compiler—everything needed to build high-performance custom operations.
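To confirm the install, you can try importing both packages. This is an optional check that uses only the torch and max.torch modules this tutorial already relies on:

    python -c "import torch, max.torch; print('PyTorch', torch.__version__, 'with MAX is ready')"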

Build the PyTorch interface

Let's start by creating the PyTorch side of our integration. We'll build a simple function that uses our custom operation, but first we need to establish the interface.

Create a new file called grayscale.py and add the following code:

🐍 grayscale.py
from pathlib import Path
import torch

@torch.compile
def grayscale(pic):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    # We'll call our custom operation here
    return output

The grayscale function transforms a color image (with red, green, and blue channels) into a single-channel grayscale image. While PyTorch has built-in operations for this, implementing it as a Mojo custom op demonstrates the integration pattern you can apply to any performance-critical computation. Now we need to bridge to our Mojo implementation.
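If you'd like a baseline to compare against later, here's an optional plain-PyTorch sketch of the same conversion, using the same channel weights as the Mojo kernel you'll write below. It isn't required for the custom op to work; you can add it to grayscale.py as a reference implementation:

🐍 grayscale.py (optional reference helper)
import torch

def grayscale_reference(pic: torch.Tensor) -> torch.Tensor:
    # Same weights as the Mojo kernel: 0.21 R + 0.71 G + 0.07 B
    weights = torch.tensor([0.21, 0.71, 0.07], device=pic.device)
    gray = (pic.to(torch.float32) * weights).sum(dim=-1)
    return gray.clamp(max=255).to(torch.uint8)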

Integrate the custom operation

The max.torch module provides CustomOpLibrary, which allows you to load and use compiled Mojo operations directly in PyTorch.

Update the grayscale.py file to include the following code:

🐍 grayscale.py
from pathlib import Path
import torch
from max.torch import CustomOpLibrary

# Load the compiled Mojo package containing our custom operations
mojo_kernels = Path(__file__).parent / "operations"
ops = CustomOpLibrary(mojo_kernels)

@torch.compile
def grayscale(pic):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    ops.grayscale(output, pic)  # Call our Mojo custom op
    return output

The CustomOpLibrary loads operations from a compiled Mojo package file (.mojopkg). Once loaded, you can call these operations just like any other PyTorch function. As the code above shows, the op writes its result into an output tensor that you allocate up front, data movement between PyTorch and MAX is handled automatically, and the call works inside a @torch.compile-decorated function.

The compilation of your Mojo code into a .mojopkg file is handled when you run your Python script. You don't need to manually invoke the Mojo compiler or manage build steps.
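For reference, here's the layout the project will have by the end of the tutorial (the operations folder and grayscale.mojo are created in the next step, and main.py comes later):

    pytorch_custom_ops/
    ├── grayscale.py
    ├── main.py
    └── operations/
        └── grayscale.mojo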

Implement the Mojo kernel

Now for the core implementation: the high-performance Mojo kernel. We'll create a Mojo package that defines our grayscale conversion operation.

Create an operations folder in the root of your project, then create a file called grayscale.mojo inside it and add the following code:

🔥 grayscale.mojo
from compiler import register
from max.tensor import InputTensor, OutputTensor, foreach
from runtime.asyncrt import DeviceContextPtr
from utils.index import IndexList

@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        @parameter
        @always_inline
        fn color_to_grayscale[
            simd_width: Int
        ](idx: IndexList[img_out.rank]) -> SIMD[DType.uint8, simd_width]:
            @parameter
            fn load(idx: IndexList[img_in.rank]) -> SIMD[DType.float32, simd_width]:
                return img_in.load[simd_width](idx).cast[DType.float32]()

            var row = idx[0]
            var col = idx[1]

            # Load RGB values
            var r = load(IndexList[3](row, col, 0))
            var g = load(IndexList[3](row, col, 1))
            var b = load(IndexList[3](row, col, 2))

            # Apply standard grayscale conversion formula
            var gray = 0.21 * r + 0.71 * g + 0.07 * b
            return min(gray, 255).cast[DType.uint8]()

        foreach[color_to_grayscale, target=target, simd_width=1](img_out, ctx)

First, the @register("grayscale") decorator makes this operation available to PyTorch under the name grayscale. This is the operation invoked by the ops.grayscale(output, pic) call in grayscale.py.

Then, the color_to_grayscale function takes an index into the output tensor and returns the grayscale value for that pixel as a SIMD vector. The foreach primitive applies it to every output element, automatically parallelizing the work across the available compute units.

Finally, the target parameter allows the same code to run on both CPU and GPU.

This approach lets Mojo handle the low-level optimization details (memory layout, vectorization, and parallelization) while you focus on the algorithm.
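As a quick sanity check, you can compare the custom op against plain PyTorch on random data. This is only a sketch: it assumes you kept the optional grayscale_reference helper from earlier in grayscale.py and that you're on a machine with a CUDA GPU.

import torch
from grayscale import grayscale, grayscale_reference  # grayscale_reference is the optional helper

pic = torch.randint(0, 256, (64, 64, 3), dtype=torch.uint8).cuda()
custom = grayscale(pic)
reference = grayscale_reference(pic)
# Allow off-by-one differences from floating-point rounding
assert (custom.to(torch.int16) - reference.to(torch.int16)).abs().max() <= 1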

Run the example

Now that we have a PyTorch model and a Mojo kernel, we can test it with a real image.

Create a new file called main.py in the root of your project and add the following code:

🐍 main.py
import numpy as np
import torch
from PIL import Image

from grayscale import grayscale

# Load a test image
image = Image.open("test_image.jpg")
# Convert it to a PyTorch tensor and move it to the GPU for processing
image_tensor = torch.from_dlpack(np.array(image)).cuda()

# Apply our custom grayscale operation
gray_image = grayscale(image_tensor)

# Convert back to a PIL Image for display or further processing
result = Image.fromarray(gray_image.cpu().numpy())
# Save the result so you can view it (the output filename is arbitrary)
result.save("gray_image.jpg")

You can run the example with the following command:

python main.py

Running the script produces a grayscale version of your test image.

This example demonstrates the complete pipeline: loading an image, converting it to a PyTorch tensor, processing it with our Mojo custom op, and converting the result back to a PIL image. Because the tensor is moved to the GPU, the custom op runs there; thanks to the target parameter, the same kernel can also run on the CPU.
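If you want the script to fall back to the CPU when no GPU is available, you can select the device at runtime. This is a sketch of one way to do it; the target parameter in the Mojo kernel is what allows the same op to run on either device:

import numpy as np
import torch
from PIL import Image

from grayscale import grayscale

# Pick the GPU if one is available, otherwise run the same op on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
image_tensor = torch.from_dlpack(np.array(Image.open("test_image.jpg"))).to(device)
gray_image = grayscale(image_tensor)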

Next steps

Now that you've written a Mojo kernel and called it from PyTorch, check out these other tutorials:
