
Write hardware-agnostic custom ops for PyTorch
When working with PyTorch models, you might encounter performance bottlenecks in specific operations that could benefit from custom optimization. You might also want to experiment with novel GPU algorithms or implement cutting-edge research ideas that aren't yet available in standard frameworks. Rather than rewriting your entire model or switching frameworks, you can write high-performance kernels in Mojo and integrate them into your existing PyTorch workflows, enabling both optimization and experimentation in your familiar development environment.
This tutorial demonstrates how to enhance a PyTorch model by implementing a custom grayscale image conversion operation in Mojo. You'll keep your familiar PyTorch development experience while unlocking the performance benefits that Mojo provides for compute-intensive operations: you'll convert an RGB image to grayscale by writing a custom op in Mojo and calling it from PyTorch.
Set up
Let's start by creating a Python project and installing the necessary tools.
Choose one of the following package managers and follow its steps: pip, uv, conda, or pixi.

pip
- Create a project folder:
  mkdir pytorch_custom_ops && cd pytorch_custom_ops
- Create and activate a virtual environment:
  python3 -m venv .venv/pytorch_custom_ops \
    && source .venv/pytorch_custom_ops/bin/activate
- Install the modular Python package:
  - Nightly:
    pip install modular \
      --extra-index-url https://download.pytorch.org/whl/cpu \
      --extra-index-url https://dl.modular.com/public/nightly/python/simple/
  - Stable:
    pip install modular \
      --extra-index-url https://download.pytorch.org/whl/cpu \
      --extra-index-url https://modular.gateway.scarf.sh/simple/
uv

- If you don't have it, install uv:
  curl -LsSf https://astral.sh/uv/install.sh | sh
  Then restart your terminal to make uv accessible.
- Create a project:
  uv init pytorch_custom_ops && cd pytorch_custom_ops
- Create and start a virtual environment:
  uv venv && source .venv/bin/activate
- Install the modular Python package:
  - Nightly:
    uv pip install modular \
      --extra-index-url https://download.pytorch.org/whl/cpu \
      --extra-index-url https://dl.modular.com/public/nightly/python/simple/ \
      --index-strategy unsafe-best-match
  - Stable:
    uv pip install modular \
      --extra-index-url https://download.pytorch.org/whl/cpu \
      --extra-index-url https://modular.gateway.scarf.sh/simple/ \
      --index-strategy unsafe-best-match
conda

- If you don't have it, install conda. A common choice is with brew:
  brew install miniconda
- Initialize conda for shell interaction:
  conda init
  If you're on a Mac, use this instead:
  conda init zsh
  Then restart your terminal for the changes to take effect.
- Create a project:
  conda create -n pytorch_custom_ops
- Start the virtual environment:
  conda activate pytorch_custom_ops
- Install the modular conda package:
  - Nightly:
    conda install -c conda-forge -c https://conda.modular.com/max-nightly/ modular
  - Stable:
    conda install -c conda-forge -c https://conda.modular.com/max/ modular
pixi

- If you don't have it, install pixi:
  curl -fsSL https://pixi.sh/install.sh | sh
  Then restart your terminal for the changes to take effect.
- Create a project:
  pixi init pytorch_custom_ops \
    -c https://conda.modular.com/max-nightly/ -c conda-forge \
    && cd pytorch_custom_ops
- Install the modular conda package:
  - Nightly:
    pixi add modular
  - Stable:
    pixi add "modular==25.3"
- Start the virtual environment:
  pixi shell
When you install the modular package, you get access to the max Python APIs and the Mojo compiler, which is everything you need to build high-performance custom operations.
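To confirm the environment is ready, you can optionally run a quick import check. This is a sketch, not part of the tutorial files; it only verifies that PyTorch and the max Python APIs are importable:

# Optional sanity check for the environment
import torch
from max.torch import CustomOpLibrary  # import check only

print("PyTorch version:", torch.__version__)
print("max.torch imported successfully")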
Build the PyTorch interface
Let's start by creating the PyTorch side of our integration. We'll build a simple function that uses our custom operation, but first we need to establish the interface.
Create a new file called grayscale.py
and add the following code:
from pathlib import Path

import torch


@torch.compile
def grayscale(pic):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    # We'll call our custom operation here
    return output
The grayscale
function transforms a color image (with red, green, and blue
channels) into a single-channel grayscale image. While PyTorch has built-in
operations for this, implementing it as a Mojo custom op demonstrates the
integration pattern you can apply to any performance-critical computation. Now we
need to bridge to our Mojo implementation.
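As a point of reference for what that Mojo op should compute, here is a rough plain-PyTorch equivalent. It's a sketch for comparison only (the function name grayscale_reference is hypothetical), using the same channel weights as the Mojo kernel below:

import torch

def grayscale_reference(pic: torch.Tensor) -> torch.Tensor:
    """Weighted sum over the RGB channels of a (height, width, 3) uint8 image."""
    weights = torch.tensor([0.21, 0.71, 0.07], device=pic.device)
    gray = (pic.to(torch.float32) * weights).sum(dim=-1)
    return gray.clamp(max=255).to(torch.uint8)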
Integrate the custom operation
The max.torch
module provides
CustomOpLibrary
, which
allows you to load and use compiled Mojo operations directly in PyTorch.
Update the grayscale.py
file to include the following code:
from pathlib import Path

import torch

from max.torch import CustomOpLibrary

# Load the compiled Mojo package containing our custom operations
mojo_kernels = Path(__file__).parent / "operations"
ops = CustomOpLibrary(mojo_kernels)


@torch.compile
def grayscale(pic):
    output = pic.new_empty(pic.shape[:-1])  # Remove color channel dimension
    ops.grayscale(output, pic)  # Call our Mojo custom op
    return output
The CustomOpLibrary loads operations from a compiled Mojo package (.mojopkg) or, as in this tutorial, from a directory of Mojo source files. Once loaded, you can call these operations just like any other PyTorch function. The operations automatically handle data movement between PyTorch and MAX, and integrate with PyTorch when needed.
The compilation of your Mojo code into a .mojopkg
file is handled when you run
your Python script. You don't need to manually invoke the Mojo compiler or manage
build steps.
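One detail worth noting is that the op uses a destination-passing style: you allocate the output tensor yourself and pass it as the first argument, and the Mojo op writes the result into it. The call inside grayscale() boils down to this fragment (illustrative only; ops and pic as defined above):

# `pic` is a (height, width, 3) uint8 image tensor
out = pic.new_empty(pic.shape[:-1])  # allocate the (height, width) output
ops.grayscale(out, pic)              # the Mojo op fills `out` in place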
Implement the Mojo kernel
Now for the core implementation: our high-performance Mojo kernel. We'll create a Mojo package that defines the grayscale conversion operation.
Create an operations folder in your project, then create a new file called grayscale.mojo inside it and add the following code:
from compiler import register
from max.tensor import InputTensor, OutputTensor, foreach
from runtime.asyncrt import DeviceContextPtr
from utils.index import IndexList


@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        @parameter
        @always_inline
        fn color_to_grayscale[
            simd_width: Int
        ](idx: IndexList[img_out.rank]) -> SIMD[DType.uint8, simd_width]:
            @parameter
            fn load(idx: IndexList[img_in.rank]) -> SIMD[DType.float32, simd_width]:
                return img_in.load[simd_width](idx).cast[DType.float32]()

            var row = idx[0]
            var col = idx[1]

            # Load RGB values
            var r = load(IndexList[3](row, col, 0))
            var g = load(IndexList[3](row, col, 1))
            var b = load(IndexList[3](row, col, 2))

            # Apply standard grayscale conversion formula
            var gray = 0.21 * r + 0.71 * g + 0.07 * b
            return min(gray, 255).cast[DType.uint8]()

        foreach[color_to_grayscale, target=target, simd_width=1](img_out, ctx)
First, the @register("grayscale") decorator makes this operation available to PyTorch under the name grayscale. This is the ops.grayscale(output, pic) function called from the PyTorch code.

Then, the color_to_grayscale function takes an index into the output tensor and returns a SIMD vector containing the grayscale value for that pixel. The foreach primitive automatically parallelizes the operation across the available compute units.

Finally, the target parameter allows the same code to run on both CPU and GPU.
This approach lets Mojo handle the low-level optimization details (memory layout, vectorization, and parallelization) while you focus on the algorithm.
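Conceptually, foreach invokes color_to_grayscale once per output coordinate. In Python terms, the computation is roughly equivalent to the following naive sketch (purely illustrative; the real kernel is vectorized and parallelized by MAX, and grayscale_naive is a hypothetical name):

import numpy as np

def grayscale_naive(img_in: np.ndarray) -> np.ndarray:
    """Naive per-pixel loop mirroring what the Mojo kernel computes."""
    height, width, _ = img_in.shape
    img_out = np.empty((height, width), dtype=np.uint8)
    for row in range(height):
        for col in range(width):
            r, g, b = img_in[row, col].astype(np.float32)
            gray = 0.21 * r + 0.71 * g + 0.07 * b
            img_out[row, col] = np.uint8(min(gray, 255))
    return img_out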
Run the example
Now that we have a PyTorch model and a Mojo kernel, we can test it with a real image.
Create a new file called main.py in the root of your project and add the following code. Replace test_image.jpg with the path to any RGB image you have on hand:
import numpy as np
import torch
from PIL import Image

from grayscale import grayscale

# Open a test image
image = Image.open("test_image.jpg")

# Convert to a PyTorch tensor and move to GPU for processing
image_tensor = torch.from_dlpack(np.array(image)).cuda()

# Apply our custom grayscale operation
gray_image = grayscale(image_tensor)

# Convert back to a PIL Image for display or further processing
result = Image.fromarray(gray_image.cpu().numpy())
You can run the example with the following command:
python main.py
The following image shows the result of the grayscale operation.

This example demonstrates the complete pipeline: loading an image, converting it to a PyTorch tensor, processing it with our Mojo custom op, and converting the result back to a standard Python image format. Because the input tensor is moved to the GPU, the operation takes advantage of GPU acceleration, which can provide significant performance improvements over CPU-only implementations.
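If you don't have a CUDA GPU, you can keep the tensor on the CPU, since the kernel's target parameter is designed to let the same op run on either device. The following variation of main.py (a sketch, assuming the same test_image.jpg and output filename as above) falls back to the CPU when CUDA isn't available and saves the result to disk:

import numpy as np
import torch
from PIL import Image

from grayscale import grayscale

image = Image.open("test_image.jpg")
image_tensor = torch.from_dlpack(np.array(image))

# Move to the GPU only if one is available; otherwise stay on the CPU
if torch.cuda.is_available():
    image_tensor = image_tensor.cuda()

gray_image = grayscale(image_tensor)

# Save the result next to the input for inspection
Image.fromarray(gray_image.cpu().numpy()).save("test_image_gray.jpg")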
Next steps
Now that you've written a Mojo kernel and called it from your PyTorch code, check out these other tutorials: