Build an MLP block as a module

Multilayer perceptrons (MLPs) are a fundamental building block of many neural networks. An MLP consists of a sequence of linear (fully connected) layers interspersed with non-linear activation functions, and it forms the backbone of many deep learning architectures. While you can build MLPs by manually composing layers, wrapping an MLP block in a dedicated, reusable Module offers better modularity and code organization, especially in larger projects built with the MAX framework.

In this tutorial, you'll create a flexible MLP block as a custom Module. This block will allow you to specify input and output dimensions, the number and size of hidden layers, and the activation function. By the end, you'll have a reusable component you can easily integrate into various graphs written in MAX.

Throughout this tutorial, you'll learn to define a class structure that inherits from Module, implement initialization and computation methods, create a helpful string representation for debugging, and configure your module with various parameters. These skills form the foundation for building custom, reusable components in MAX's neural network framework.

Set up

Create a Python project to install our APIs and CLI tools:

  1. Create a project folder:
    mkdir max-mlp-block && cd max-mlp-block
  2. Create and activate a virtual environment:
    python3 -m venv .venv/max-mlp-block \
    && source .venv/max-mlp-block/bin/activate
  3. Install the modular Python package:
    pip install modular \
    --extra-index-url https://download.pytorch.org/whl/cpu \
    --extra-index-url https://dl.modular.com/public/nightly/python/simple/

When you install the modular package, you'll get access to the max Python APIs.

Define the MLPBlock class structure

First, you'll define the basic structure of your custom module using the Module class. To create the MLPBlock module, you need to implement two methods:

  1. __init__(self, ...): The constructor, where you define the sub-layers (like Linear) and parameters the module will use.
  2. __call__(self, x): The method that defines the computation graph when data x (a TensorValue) is passed through the module.

Start by creating a file named mlp.py and import the necessary libraries. Then, define the MLPBlock class structure:

mlp.py
from __future__ import annotations

from typing import Callable, List, Optional

from max import nn
from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops


class MLPBlock(nn.Module):

    def __init__(
        self,
        # TODO: Add parameters
    ) -> None:
        super().__init__()

After importing the required components from max.dtype, max.graph, and max.nn (neural network), you create your MLPBlock class, which inherits from nn.Module. For now the __init__ method is just a stub that calls super().__init__(); you'll add the configuration parameters in the next step.

Implement the __init__ method

Next, implement the constructor (__init__). This method takes the configuration parameters and creates the necessary linear layers and activation functions for your MLP. You'll collect them in an nn.Sequential container, which applies the layers in order when the computation graph is built.

Modify the __init__ method in your MLPBlock class as follows:

mlp.py
class MLPBlock(nn.Module):

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Optional[List[int]] = None,
        activation: Optional[Callable[[TensorValue], TensorValue]] = None,
    ) -> None:
        super().__init__()

        # Use empty list if no hidden features provided
        hidden_features = hidden_features or []

        # Default to ReLU activation if none provided
        activation = activation or ops.relu

        # Build the sequence of layers
        layers = []
        current_dim = in_features

        # Add hidden layers with their activations
        for i, h_dim in enumerate(hidden_features):
            layers.append(
                nn.Linear(
                    in_dim=current_dim,
                    out_dim=h_dim,
                    dtype=DType.float32,
                    device=DeviceRef.CPU(),
                    has_bias=True,
                    name=f"hidden_{i}",
                )
            )
            layers.append(activation)
            current_dim = h_dim

        # Add the final output layer
        layers.append(
            nn.Linear(
                in_dim=current_dim,
                out_dim=out_features,
                dtype=DType.float32,
                device=DeviceRef.CPU(),
                has_bias=True,
                name="output",
            )
        )

        # Create Sequential module with the layers
        self.model = nn.Sequential(layers)

Here's how this implementation works:

You initialize an empty list layers to build your network structure. The code iterates through hidden_features, appending a Linear layer instance and the provided activation function for each hidden dimension, updating current_dim as you go.

After processing all hidden layers, you add the output Linear layer. Finally, you create a Sequential module with these layers, which handles the sequential application of operations.

Note that we use default values for dtype and device to simplify the interface. In a production environment, you might want to expose these parameters to allow for different data types and devices.
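As a minimal sketch of what that could look like, the variant below accepts dtype and device arguments and forwards them to each linear layer. The class name ConfigurableMLPBlock is made up for illustration, the snippet assumes the same imports as mlp.py, and the defaults preserve the behavior of the original block:

class ConfigurableMLPBlock(nn.Module):
    """Sketch: same as MLPBlock, but dtype and device are constructor arguments."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Optional[List[int]] = None,
        activation: Optional[Callable[[TensorValue], TensorValue]] = None,
        dtype: DType = DType.float32,
        device: DeviceRef = DeviceRef.CPU(),
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or []
        activation = activation or ops.relu

        layers = []
        current_dim = in_features
        for i, h_dim in enumerate(hidden_features):
            # Forward the caller-supplied dtype and device instead of hard-coding them
            layers.append(
                nn.Linear(
                    in_dim=current_dim,
                    out_dim=h_dim,
                    dtype=dtype,
                    device=device,
                    has_bias=True,
                    name=f"hidden_{i}",
                )
            )
            layers.append(activation)
            current_dim = h_dim
        layers.append(
            nn.Linear(
                in_dim=current_dim,
                out_dim=out_features,
                dtype=dtype,
                device=device,
                has_bias=True,
                name="output",
            )
        )
        self.model = nn.Sequential(layers)

    def __call__(self, x: TensorValue) -> TensorValue:
        return self.model(x)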

Implement the __call__ method

In MAX, the __call__ method is a special Python method that gets invoked when you use a module instance as if it were a function. For example, when you write output = mlp_block(input_tensor) in your code, Python automatically calls the __call__ method. This is a key part of how MAX builds computation graphs—when you call a module with an input tensor, you're adding that module's operations to the computation graph that will eventually be compiled and executed.

The __call__ method defines how an input TensorValue flows through the module, building the computation graph. Since you're using Sequential, your implementation requires only two lines of code:

mlp.py
    def __call__(self, x: TensorValue) -> TensorValue:
        return self.model(x)

This method takes an input TensorValue x and passes it through the self.model sequential container, which automatically applies each layer in sequence.
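Conceptually, the sequential container behaves like a simple loop over its layers. The helper below is purely illustrative (it isn't part of mlp.py and isn't how nn.Sequential is implemented internally), but it captures the idea:

def apply_in_sequence(layers, x):
    # Apply each layer (or activation) to the running value, in order
    for layer in layers:
        x = layer(x)
    return x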

Implement a custom string representation

To better understand the structure of your MLP blocks when they're printed, you'll implement the __repr__ method to display useful information about the number of layers in each block:

mlp.py
    def __repr__(self) -> str:
        layers = list(self.model)

        linear_count = sum(
            1 for layer in layers if isinstance(layer, nn.Linear)
        )
        activation_count = len(layers) - linear_count

        return f"MLPBlock({linear_count} linear layers, {activation_count} activations)"

This method counts the linear layers and activation functions separately, making it clear how many of each type exist in your MLP block.

Run the MLPBlock module

Now, try out the MLPBlock by creating instances with different configurations.

Note that MAX uses a static graph representation that gets compiled before execution, which differs from frameworks like PyTorch where tensors flow dynamically through the network. The examples below show how to instantiate MLP blocks with various configurations. To actually execute these blocks with data, you would integrate them into a larger MAX graph execution context.
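As a rough illustration of that integration, the sketch below builds a graph around a single MLPBlock. It is not part of this tutorial's files, the file and graph names are made up, and actually running the compiled graph would additionally require supplying concrete values for the Linear layers' weights, which is beyond the scope of this tutorial:

# sketch.py (illustrative only)
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType

from mlp import MLPBlock

# Build the block outside the graph; its layers hold symbolic weights.
mlp = MLPBlock(in_features=10, out_features=5, hidden_features=[32])

# Describe the symbolic input: a batch of one 10-dimensional float32 vector on CPU.
input_type = TensorType(dtype=DType.float32, shape=(1, 10), device=DeviceRef.CPU())

with Graph("mlp_block_demo", input_types=(input_type,)) as graph:
    x = graph.inputs[0].tensor  # the graph's symbolic input as a TensorValue
    graph.output(mlp(x))        # calling the block adds its ops to the graph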

Create a new file called main.py with the following code:

main.py
from max.graph import ops

from mlp import MLPBlock

if __name__ == "__main__":
    print("--- Simple MLP Block ---")
    # 1. Simple MLP (no hidden layers)
    simple_mlp = MLPBlock(
        in_features=10,
        out_features=20,
        hidden_features=[],
        activation=ops.relu,
    )
    print(simple_mlp)
    print("-" * 30)

    # 2. MLP with one hidden layer
    print("--- MLP Block (1 Hidden Layer) ---")
    mlp_one_hidden = MLPBlock(
        in_features=10,
        out_features=5,
        hidden_features=[32],
        activation=ops.relu,
    )
    print(mlp_one_hidden)
    print("-" * 30)

    # 3. Deeper MLP with multiple hidden layers and GELU
    print("--- Deeper MLP Block (3 Hidden Layers, GELU) ---")
    deep_mlp = MLPBlock(
        in_features=64,
        out_features=10,
        hidden_features=[128, 64, 32],
        activation=ops.gelu,
    )
    print(deep_mlp)
    print("-" * 30)

Execute the main.py file. This instantiates MLPBlock with various configurations and prints each block's representation, showing the layers defined within.
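Assuming the virtual environment from the setup step is still active, run the script from your project folder:

python main.py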

The following is the expected output:

--- Simple MLP Block ---
MLPBlock(1 linear layers, 0 activations)
------------------------------
--- MLP Block (1 Hidden Layer) ---
MLPBlock(2 linear layers, 1 activations)
------------------------------
--- Deeper MLP Block (3 Hidden Layers, GELU) ---
MLPBlock(4 linear layers, 3 activations)
------------------------------
  • The simple MLP has 1 linear layer and 0 activations (since there are no hidden layers)
  • The MLP with one hidden layer has 2 linear layers (input→hidden, hidden→output) and 1 activation
  • The deeper MLP has 4 linear layers (input→hidden1, hidden1→hidden2, hidden2→hidden3, hidden3→output) and 3 activations

Conclusion

In this tutorial, you successfully created a reusable MLPBlock module in MAX.

You learned how to:

  1. Define the class structure inheriting from Module
  2. Implement the __init__ method to dynamically create a sequence of linear layers (nn.Linear) and activation functions, then wrap them in nn.Sequential
  3. Implement a simple __call__ method that leverages the sequential container
  4. Create a custom string representation for debugging
  5. Instantiate and inspect the custom module with various configurations

This MLPBlock provides a clean and modular way to incorporate standard MLP structures into your MAX projects. You can now extend it further, perhaps by adding layer normalization, or by experimenting with different activation functions from max.graph.ops. This pattern of creating reusable modules is fundamental to building complex and maintainable models in MAX.
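For example, because the activation parameter accepts any callable that maps a TensorValue to a TensorValue, you can pass a different element-wise op. The snippet below assumes that ops.tanh is available in max.graph.ops; other activations would follow the same pattern:

from max.graph import ops

from mlp import MLPBlock

# Same idea as the deeper example, but with tanh activations (illustrative).
tanh_mlp = MLPBlock(
    in_features=64,
    out_features=10,
    hidden_features=[128, 64],
    activation=ops.tanh,
)
print(tanh_mlp)  # Expected: MLPBlock(3 linear layers, 2 activations)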
