
Build an MLP block as a module
Multilayer Perceptrons (MLPs) are a fundamental component of many neural
networks. An MLP consists of a sequence of linear (fully connected) layers
interspersed with non-linear activation functions, and it forms the backbone
of many deep learning architectures. While you can build an MLP by manually
composing layers, packaging it as a dedicated, reusable Module offers better
modularity and code organization, especially in larger projects using the MAX
framework.
In this tutorial, you'll create a flexible MLP block as a custom Module. This
block will allow you to specify input and output dimensions, the number and
size of hidden layers, and the activation function. By the end, you'll have a
reusable component you can easily integrate into various graphs written in MAX.
Throughout this tutorial, you'll learn to define a class structure that
inherits from Module, implement initialization and computation methods, create
a helpful string representation for debugging, and configure your module with
various parameters. These skills form the foundation for building custom,
reusable components in MAX's neural network framework.
Set up
Create a Python project to install our APIs and CLI tools:
Use one of the following package managers: pip, uv, conda, or pixi.

pip:

- Create a project folder:
mkdir max-mlp-block && cd max-mlp-block
- Create and activate a virtual environment:
python3 -m venv .venv/max-mlp-block \
  && source .venv/max-mlp-block/bin/activate
- Install the modular Python package:
Nightly:
pip install modular \
  --extra-index-url https://download.pytorch.org/whl/cpu \
  --extra-index-url https://dl.modular.com/public/nightly/python/simple/
Stable:
pip install modular \
  --extra-index-url https://download.pytorch.org/whl/cpu \
  --extra-index-url https://modular.gateway.scarf.sh/simple/
uv:

- If you don't have it, install uv:
curl -LsSf https://astral.sh/uv/install.sh | sh
Then restart your terminal to make uv accessible.
- Create a project:
uv init max-mlp-block && cd max-mlp-block
- Create and start a virtual environment:
uv venv && source .venv/bin/activate
- Install the modular Python package:
Nightly:
uv pip install modular \
  --extra-index-url https://download.pytorch.org/whl/cpu \
  --extra-index-url https://dl.modular.com/public/nightly/python/simple/ \
  --index-strategy unsafe-best-match
Stable:
uv pip install modular \
  --extra-index-url https://download.pytorch.org/whl/cpu \
  --extra-index-url https://modular.gateway.scarf.sh/simple/ \
  --index-strategy unsafe-best-match
conda:

- If you don't have it, install conda. A common choice is with brew:
brew install miniconda
- Initialize conda for shell interaction:
conda init
If you're on a Mac, instead use:
conda init zsh
Then restart your terminal for the changes to take effect.
- Create a project:
conda create -n max-mlp-block
- Start the virtual environment:
conda activate max-mlp-block
- Install the modular conda package:
Nightly:
conda install -c conda-forge -c https://conda.modular.com/max-nightly/ modular
Stable:
conda install -c conda-forge -c https://conda.modular.com/max/ modular
pixi:

- If you don't have it, install pixi:
curl -fsSL https://pixi.sh/install.sh | sh
Then restart your terminal for the changes to take effect.
- Create a project:
pixi init max-mlp-block \
  -c https://conda.modular.com/max-nightly/ -c conda-forge \
  && cd max-mlp-block
- Install the modular conda package:
Nightly:
pixi add modular
Stable:
pixi add "modular==25.3"
- Start the virtual environment:
pixi shell
When you install the modular package, you'll get access to the max Python APIs.
Define the MLPBlock class structure
First, you'll define the basic structure of your custom module using the
Module class. To create the MLPBlock module, you need to implement two methods:
- __init__(self, ...): The constructor, where you define the sub-layers (like Linear) and parameters the module will use.
- __call__(self, x): The method that defines the computation graph when data x (a TensorValue) is passed through the module.
Start by creating a file named mlp.py and import the necessary libraries.
Then, define the MLPBlock class structure:
from __future__ import annotations

from typing import Callable, List, Optional

from max import nn
from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops


class MLPBlock(nn.Module):
    def __init__(
        self,
        # TODO: Add parameters
    ) -> None:
        super().__init__()
After importing the required components from max.dtype, max.graph, and nn
(MAX's neural network package), you create your MLPBlock class, which inherits
from Module. For now, its __init__ method only makes the super().__init__()
call; you'll add the configuration parameters next.
Implement the __init__ method
Next, implement the constructor (__init__). This method takes the configuration
parameters and creates the necessary layers and activation functions for your MLP.
You'll use Sequential as a container that applies your layers in order when
building the computation graph.
Modify the __init__ method in your MLPBlock class as follows:
class MLPBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Optional[List[int]] = None,
        activation: Optional[Callable[[TensorValue], TensorValue]] = None,
    ) -> None:
        super().__init__()
        # Use empty list if no hidden features provided
        hidden_features = hidden_features or []
        # Default to ReLU activation if none provided
        activation = activation or ops.relu
        # Build the sequence of layers
        layers = []
        current_dim = in_features
        # Add hidden layers with their activations
        for i, h_dim in enumerate(hidden_features):
            layers.append(
                nn.Linear(
                    in_dim=current_dim,
                    out_dim=h_dim,
                    dtype=DType.float32,
                    device=DeviceRef.CPU(),
                    has_bias=True,
                    name=f"hidden_{i}",
                )
            )
            layers.append(activation)
            current_dim = h_dim
        # Add the final output layer
        layers.append(
            nn.Linear(
                in_dim=current_dim,
                out_dim=out_features,
                dtype=DType.float32,
                device=DeviceRef.CPU(),
                has_bias=True,
                name="output",
            )
        )
        # Create Sequential module with the layers
        self.model = nn.Sequential(layers)
Here's how this implementation works:
You initialize an empty list, layers, to build your network structure. The code
iterates through hidden_features, appending a Linear layer and the provided
activation function for each hidden dimension, updating current_dim as you go.
After processing all hidden layers, you add the output Linear layer. Finally,
you create a Sequential module with these layers, which handles the sequential
application of operations.
Note that we use default values for dtype and device to simplify the interface.
In a production environment, you might want to expose these parameters to allow
for different data types and devices.
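As an illustration, the sketch below shows one way the constructor's signature
could expose those parameters. The ConfigurableMLPBlock name and the extra
dtype and device arguments are assumptions made for this sketch (they aren't
part of the tutorial's final mlp.py); everything else mirrors the __init__ you
just wrote, with the caller-supplied values forwarded to each nn.Linear.

```python
# Sketch only (assumes the same imports as mlp.py): an MLP block variant
# that exposes dtype and device instead of hard-coding float32 on CPU.
class ConfigurableMLPBlock(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Optional[List[int]] = None,
        activation: Optional[Callable[[TensorValue], TensorValue]] = None,
        dtype: DType = DType.float32,
        device: DeviceRef = DeviceRef.CPU(),
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or []
        activation = activation or ops.relu

        layers = []
        current_dim = in_features
        for i, h_dim in enumerate(hidden_features):
            # Forward the caller-supplied dtype and device to each layer
            layers.append(
                nn.Linear(
                    in_dim=current_dim,
                    out_dim=h_dim,
                    dtype=dtype,
                    device=device,
                    has_bias=True,
                    name=f"hidden_{i}",
                )
            )
            layers.append(activation)
            current_dim = h_dim
        layers.append(
            nn.Linear(
                in_dim=current_dim,
                out_dim=out_features,
                dtype=dtype,
                device=device,
                has_bias=True,
                name="output",
            )
        )
        self.model = nn.Sequential(layers)
```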
Implement the __call__ method
In MAX, the __call__ method is a special Python method that gets invoked when
you use a module instance as if it were a function. For example, when you write
output = mlp_block(input_tensor) in your code, Python automatically calls the
__call__ method. This is a key part of how MAX builds computation graphs: when
you call a module with an input tensor, you're adding that module's operations
to the computation graph that will eventually be compiled and executed.
The __call__ method defines how an input TensorValue flows through the module,
building the computation graph. Since you're using Sequential, your
implementation requires only two lines of code:
    def __call__(self, x: TensorValue) -> TensorValue:
        return self.model(x)
This method takes an input TensorValue x and passes it through the self.model
sequential container, which automatically applies each layer in sequence.
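For illustration only, the call self.model(x) is doing the equivalent of the
hand-written loop below. The sketch assumes nn.Sequential can be iterated over
(the __repr__ method in the next section relies on the same behavior); it isn't
code you need to add to mlp.py.

```python
    # Sketch only: what self.model(x) effectively does.
    def __call__(self, x: TensorValue) -> TensorValue:
        out = x
        for layer in self.model:
            out = layer(out)  # each step is an nn.Linear or the activation
        return out
```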
Implement a custom string representation
To better understand the structure of your MLP blocks when they're printed,
implement the __repr__ method to display useful information about the number of
layers in each block:
    def __repr__(self) -> str:
        layers = list(self.model)
        linear_count = sum(
            1 for layer in layers if isinstance(layer, nn.Linear)
        )
        activation_count = len(layers) - linear_count
        return f"MLPBlock({linear_count} linear layers, {activation_count} activations)"
This method counts the linear layers and activation functions separately, making it clear how many of each type exist in your MLP block.
Run the MLPBlock module
Now, run the MLPBlock by creating instances with different configurations.
Note that MAX uses a static graph representation that gets compiled before execution, which differs from frameworks like PyTorch where tensors flow dynamically through the network. The examples below show how to instantiate MLP blocks with various configurations. To actually execute these blocks with data, you would integrate them into a larger MAX graph execution context (see the sketch after the expected output below).
Create a new file called main.py with the following code:
from max.graph import ops

from mlp import MLPBlock

if __name__ == "__main__":
    print("--- Simple MLP Block ---")
    # 1. Simple MLP (no hidden layers)
    simple_mlp = MLPBlock(
        in_features=10,
        out_features=20,
        hidden_features=[],
        activation=ops.relu,
    )
    print(simple_mlp)
    print("-" * 30)

    # 2. MLP with one hidden layer
    print("--- MLP Block (1 Hidden Layer) ---")
    mlp_one_hidden = MLPBlock(
        in_features=10,
        out_features=5,
        hidden_features=[32],
        activation=ops.relu,
    )
    print(mlp_one_hidden)
    print("-" * 30)

    # 3. Deeper MLP with multiple hidden layers and GELU
    print("--- Deeper MLP Block (3 Hidden Layers, GELU) ---")
    deep_mlp = MLPBlock(
        in_features=64,
        out_features=10,
        hidden_features=[128, 64, 32],
        activation=ops.gelu,
    )
    print(deep_mlp)
    print("-" * 30)
Execute the main.py file (for example, with python main.py). This instantiates
MLPBlock with various configurations and prints their representations, showing
the layers defined within.
The following is the expected output:
--- Simple MLP Block ---
MLPBlock(1 linear layers, 0 activations)
----------------------------------------
--- MLP Block (1 Hidden Layer) ---
MLPBlock(2 linear layers, 1 activations)
----------------------------------------
--- Deeper MLP Block (3 Hidden Layers, GELU) ---
MLPBlock(4 linear layers, 3 activations)
----------------------------------------
- The simple MLP has 1 linear layer and 0 activations (since there are no hidden layers)
- The MLP with one hidden layer has 2 linear layers (input→hidden, hidden→output) and 1 activation
- The deeper MLP has 4 linear layers (input→hidden1, hidden1→hidden2, hidden2→hidden3, hidden3→output) and 3 activations
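As noted above, these instances only describe the block's structure; nothing
runs until the block is staged inside a MAX graph and compiled. The following
is a rough sketch of what that wiring might look like, assuming the max.graph
Graph and TensorType APIs and a fixed (1, 10) float32 input. Weight
initialization and execution through an inference session are omitted because
the exact steps vary across MAX versions, so treat this as an outline rather
than finished code.

```python
# Rough sketch only: staging an MLPBlock inside a MAX graph.
# Loading weights and executing the compiled graph are intentionally omitted.
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType

from mlp import MLPBlock

mlp = MLPBlock(in_features=10, out_features=5, hidden_features=[32])

# Symbolic description of the input: one row of 10 float32 values on CPU.
input_type = TensorType(DType.float32, shape=(1, 10), device=DeviceRef.CPU())

with Graph("mlp_block_graph", input_types=(input_type,)) as graph:
    x = graph.inputs[0]   # the graph's symbolic input
    graph.output(mlp(x))  # calling the block adds its ops to the graph
```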
Conclusion
In this tutorial, you successfully created a reusable MLPBlock module in MAX.
You learned how to:
- Define the class structure inheriting from Module
- Implement the __init__ method to dynamically create a sequence of linear layers (nn.Linear) and activation functions, then wrap them in nn.Sequential
- Implement a simple __call__ method that leverages the sequential container
- Create a custom string representation for debugging
- Instantiate and inspect the custom module with various configurations
This MLPBlock provides a clean and modular way to incorporate standard MLP
structures into your MAX projects. You can now easily modify it further,
perhaps by adding layer normalization, or experiment with different activation
functions from max.graph.ops. This pattern of creating reusable modules is
fundamental to building complex and maintainable models in MAX.