Python module

linear

Multi-layer Perceptron.

ColumnParallelLinear

class max.nn.linear.ColumnParallelLinear(*args, devices: Sequence[DeviceRef], **kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes y_i = xW_i^T + b_i for each device i in [0, …, num_devices - 1]:

+-----+       +-----+ T     +-----+       +-----+
|     |       | W_0 |       | b_0 |       | y_0 |  GPU0
|     |       +-----+       +-----+       +-----+
|     |       | W_1 |       | b_1 |       | y_1 |  GPU1
|  x  |   @   +-----+   +   +-----+   =   +-----+
|     |       | W_2 |       | b_2 |       | y_2 |  GPU2
|     |       +-----+       +-----+       +-----+
|     |       | W_3 |       | b_3 |       | y_3 |  GPU3
+-----+       +-----+       +-----+       +-----+

The values are then collected using an Allgather op, producing the same output tensor y = xW^T + b on each device:

  GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
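The two steps above can be simulated on a single host with NumPy. This is a minimal sketch and not MAX code. The function name `column_parallel_linear` is hypothetical, the "devices" are just list entries, and the Allgather is modeled as a concatenation along the feature axis. It assumes `out_dim` divides evenly by the number of devices; `np.split` raises an error otherwise.

```python
import numpy as np


def column_parallel_linear(x, weight, bias, num_devices):
    """Simulate a column-parallel linear layer split over `num_devices` shards.

    weight has shape (out_dim, in_dim) and bias has shape (out_dim,), so
    splitting along axis 0 distributes the output features across devices.
    """
    # Device i holds one row block W_i of the weight and the matching b_i.
    w_shards = np.split(weight, num_devices, axis=0)
    b_shards = np.split(bias, num_devices, axis=0)

    # Device-local compute: y_i = x @ W_i^T + b_i.
    partial_outputs = [x @ w.T + b for w, b in zip(w_shards, b_shards)]

    # Allgather: every device receives all y_i and concatenates them along
    # the feature axis, which rebuilds the full y = x @ W^T + b.
    gathered = np.concatenate(partial_outputs, axis=-1)
    return [gathered.copy() for _ in range(num_devices)]


rng = np.random.default_rng(0)
in_dim, out_dim, num_devices = 8, 12, 4
x = rng.standard_normal((2, in_dim))
weight = rng.standard_normal((out_dim, in_dim))
bias = rng.standard_normal(out_dim)

outputs = column_parallel_linear(x, weight, bias, num_devices)
reference = x @ weight.T + bias
print(all(np.allclose(y, reference) for y in outputs))  # True
```

Each device only ever multiplies by its own `out_dim / num_devices` rows of W. The concatenation step is what makes the result identical to the unsharded layer.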

Example usage:

from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim, out_dim = 4096, 4096  # illustrative sizes

# Shard the weight and bias across GPUs 0..num_devices-1.
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)

Initializes the linear layer with weights and optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • devices – The devices that the weight and bias are sharded across; device i holds W_i and b_i.
    • name – Base name for weights (appended with .weight and .bias if applicable).
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.

DistributedMLP

class max.nn.linear.DistributedMLP(*args, **kwargs)

A distributed multi-layer perceptron.

This class has the same state keys as the non-distributed MLP layer.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – The last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension that the inputs are projected to.
    • linear_cls – Linear class to use to create the projection layers.
    • devices – Devices to run the MLP layer on. The layer is distributed across all of them.
    • activation_function – Activation function to use. Options are:
      • “silu”
      • “gelu”
      • “gelu_tanh”
      • “relu”
      • “tanh”
      • “sigmoid”
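The source does not describe how DistributedMLP splits its weights. One common tensor-parallel layout is assumed here purely for illustration. The gate and up projections are sharded by output features, in the same way as ColumnParallelLinear. The down projection is sharded by input features, and the partial results are summed with an allreduce. The NumPy sketch also assumes a SiLU-gated MLP without biases. `silu`, `mlp`, and `tensor_parallel_mlp` are hypothetical helpers, not MAX APIs.

```python
import numpy as np


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + np.exp(-v))


def mlp(x, w_gate, w_up, w_down):
    # Non-distributed reference: down(silu(gate(x)) * up(x)), no biases.
    return (silu(x @ w_gate.T) * (x @ w_up.T)) @ w_down.T


def tensor_parallel_mlp(x, w_gate, w_up, w_down, num_devices):
    # gate/up weights are (ffn, hidden): split their rows, i.e. the output
    # features (column-parallel).
    gate_shards = np.split(w_gate, num_devices, axis=0)
    up_shards = np.split(w_up, num_devices, axis=0)
    # down weight is (hidden, ffn): split its columns, i.e. the input
    # features (row-parallel), to match each device's slice of ffn.
    down_shards = np.split(w_down, num_devices, axis=1)

    partials = []
    for g, u, d in zip(gate_shards, up_shards, down_shards):
        # Device-local work on this device's slice of the intermediate activations.
        h_local = silu(x @ g.T) * (x @ u.T)
        partials.append(h_local @ d.T)

    # Allreduce: summing the per-device partial outputs yields the full result.
    return np.sum(partials, axis=0)


rng = np.random.default_rng(0)
hidden_dim, feed_forward_length, num_devices = 8, 16, 4
x = rng.standard_normal((3, hidden_dim))
w_gate = rng.standard_normal((feed_forward_length, hidden_dim))
w_up = rng.standard_normal((feed_forward_length, hidden_dim))
w_down = rng.standard_normal((hidden_dim, feed_forward_length))

distributed = tensor_parallel_mlp(x, w_gate, w_up, w_down, num_devices)
print(np.allclose(distributed, mlp(x, w_gate, w_up, w_down)))  # True
```

This layout needs only one collective per MLP call. The element-wise SiLU gating works on each device's slice independently, so no communication is required between the two sharded stages.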

GPTQLinear

class max.nn.linear.GPTQLinear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A Linear layer for weights in the GPTQ quantization encoding.

perm_idx

perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None

quantization_config

quantization_config: QuantizationConfig | None = None
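In GPTQ's "act-order" variant, input channels are quantized in a data-dependent order, so a permutation has to be stored alongside the weights. This sketch assumes that `perm_idx` plays that role here; the source only lists the attribute. The NumPy code shows why permuting the weight columns and the input features in the same way leaves the output unchanged. Quantization itself is omitted to isolate the reordering, and all values are random placeholders.

```python
import numpy as np

rng = np.random.default_rng(0)
in_dim, out_dim = 6, 4
x = rng.standard_normal((2, in_dim))
weight = rng.standard_normal((out_dim, in_dim))  # (out_dim, in_dim)

# Hypothetical permutation of the input channels, e.g. the order in which
# act-order processed them during quantization.
perm_idx = rng.permutation(in_dim)

# Weights stored with their input columns reordered by perm_idx.
weight_permuted = weight[:, perm_idx]

# Option 1: reorder the activations the same way before the matmul.
y_permuted_inputs = x[:, perm_idx] @ weight_permuted.T

# Option 2: undo the permutation on the weights via the inverse permutation.
inverse_perm = np.argsort(perm_idx)
y_restored_weights = x @ weight_permuted[:, inverse_perm].T

reference = x @ weight.T
print(np.allclose(y_permuted_inputs, reference))   # True
print(np.allclose(y_restored_weights, reference))  # True
```

Either option works mathematically. Permuting the activations avoids rewriting the stored weights, while un-permuting the weights once at load time avoids a per-call gather.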

GPTQLinearV2

class max.nn.linear.GPTQLinearV2(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef | None = None, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None)

A Linear layer for weights in the GPTQ quantization encoding.

Initializes the linear layer with weights, an optional bias, and GPTQ quantization.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved during computation.
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.
    • quantization_encoding – The quantization encoding of the weights.
    • quantization_config – Extra config for the weight quantization.
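GPTQ-style checkpoints store weights as low-bit integer codes with a scale and zero point per group of input channels. The NumPy sketch below shows only that storage format. It uses plain round-to-nearest instead of GPTQ's error-compensating algorithm, so it is not what GPTQLinearV2 does internally. `quantize_groupwise`, `dequantize_groupwise`, `bits`, and `group_size` are illustrative assumptions, not fields of `QuantizationConfig`.

```python
import numpy as np


def quantize_groupwise(weight, bits=4, group_size=4):
    """Asymmetric round-to-nearest quantization in groups along in_dim.

    Returns integer codes plus one scale and zero point per (row, group).
    """
    out_dim, in_dim = weight.shape
    assert in_dim % group_size == 0, "in_dim must be a multiple of group_size"
    qmax = 2**bits - 1
    groups = weight.reshape(out_dim, in_dim // group_size, group_size)

    w_min = groups.min(axis=-1, keepdims=True)
    w_max = groups.max(axis=-1, keepdims=True)
    # Guard against constant groups, where max == min.
    scale = np.maximum(w_max - w_min, 1e-8) / qmax
    zero = np.round(-w_min / scale)

    codes = np.clip(np.round(groups / scale) + zero, 0, qmax).astype(np.uint8)
    return codes, scale, zero


def dequantize_groupwise(codes, scale, zero):
    # Map integer codes back to floats and flatten the groups into rows.
    groups = (codes.astype(np.float64) - zero) * scale
    return groups.reshape(groups.shape[0], -1)


rng = np.random.default_rng(0)
weight = rng.standard_normal((3, 8))
codes, scale, zero = quantize_groupwise(weight, bits=4, group_size=4)
restored = dequantize_groupwise(codes, scale, zero)
print(codes.shape, scale.shape)  # (3, 2, 4) (3, 2, 1)
```

Smaller groups track local weight ranges more closely and lower the rounding error. The cost is storing more scales and zero points.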

Linear

class max.nn.linear.Linear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A unified linear layer that delegates to either a regular or a quantized implementation.

bias

bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None

create()

classmethod create(dtype: DType, quantization_encoding: QuantizationEncoding | None, in_features: int, out_features: int, weights: Weights | Weight, bias: Weights | Weight | None = None, quantization_config: QuantizationConfig | None = None) → Linear

Factory method that creates a Linear layer with the appropriate implementation.
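The kind of dispatch a factory like create() performs can be pictured with the plain-Python sketch below. It models only the factory idea. `DenseLinear`, `QuantizedLinear`, `create_linear`, and the `"example_int4"` encoding string are hypothetical. They do not mirror MAX's classes, the real `QuantizationEncoding` values, or how `create()` loads `weights`.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class DenseLinear:
    """Hypothetical stand-in for an unquantized linear implementation."""
    in_features: int
    out_features: int


@dataclass
class QuantizedLinear:
    """Hypothetical stand-in for a quantized linear implementation."""
    in_features: int
    out_features: int
    encoding: str


def create_linear(
    in_features: int,
    out_features: int,
    quantization_encoding: Optional[str] = None,
) -> Union[DenseLinear, QuantizedLinear]:
    # The caller states what it needs; the factory picks the class
    # based on whether a quantization encoding was given.
    if quantization_encoding is None:
        return DenseLinear(in_features, out_features)
    return QuantizedLinear(in_features, out_features, quantization_encoding)


layer = create_linear(256, 128)
q_layer = create_linear(256, 128, quantization_encoding="example_int4")
print(type(layer).__name__, type(q_layer).__name__)  # DenseLinear QuantizedLinear
```

Keeping the choice inside one factory lets model code build layers the same way whether or not a checkpoint is quantized.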

weight

weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray

LinearV2

class max.nn.linear.LinearV2(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef | None = None, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, name: str | None = None, clip_weight: float | None = None)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer: inputs are multiplied by a weight matrix, and a bias vector is optionally added. Both the weight and the bias initially reside on CPU; the model init phase moves them to the device.

Example:

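from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.linear import LinearV2

linear_layer = LinearV2(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True,
)

# input_tensor must be a TensorValue built inside a max.graph.Graph,
# with a last dimension equal to in_dim (256).
input_tensor: TensorValue
output = linear_layer(input_tensor)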

Initializes the linear layer with weights and optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved during computation.
    • name – Base name for weights (appended with .weight and .bias if applicable).
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.
    • quantization_encoding – The quantization encoding of the weights.

bias

bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.

device

device: DeviceRef

The device where matrix operations are performed.

weight

weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.
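The shapes involved can be checked with a small NumPy computation. This mirrors only the math and does not use MAX. The weight is kept in the (out_dim, in_dim) layout described above and transposed at multiply time, matching the formula y = xW^T + b. It uses NumPy's default float64 rather than the float32 of the earlier example, so the equality check is not affected by rounding.

```python
import numpy as np

in_dim, out_dim, batch = 256, 128, 4
rng = np.random.default_rng(0)

# Stored layout: weight (out_dim, in_dim), bias (out_dim,).
weight = rng.standard_normal((out_dim, in_dim))
bias = rng.standard_normal(out_dim)

x = rng.standard_normal((batch, in_dim))

# y = x W^T + b: (batch, in_dim) @ (in_dim, out_dim) -> (batch, out_dim),
# with the bias broadcast across the batch dimension.
y = x @ weight.T + bias
print(y.shape)  # (4, 128)

# The same result one output feature at a time: y[:, j] = x . W[j] + b[j].
j = 5
print(np.allclose(y[:, j], x @ weight[j] + bias[j]))  # True
```

Storing W as (out_dim, in_dim) means each row holds the weights for one output feature. The second check relies on that.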

MLP

class max.nn.linear.MLP(gate_proj: Linear, down_proj: Linear, up_proj: Linear)

A simple multi-layer perceptron composed of three linear layers. Uses the SiLU activation function.

down_proj

down_proj: Linear

gate_proj

gate_proj: Linear

up_proj

up_proj: Linear
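Three projections named gate, up, and down are usually combined as a gated MLP, down_proj(SiLU(gate_proj(x)) * up_proj(x)), which is the layout used by Llama-style models. The source names the three layers and the SiLU activation but does not spell out the wiring, so treat this NumPy sketch as an assumption. Biases are omitted, and `silu` and `gated_mlp` are hypothetical helpers.

```python
import numpy as np


def silu(v):
    # SiLU(v) = v * sigmoid(v)
    return v / (1.0 + np.exp(-v))


def gated_mlp(x, w_gate, w_up, w_down):
    """down_proj(silu(gate_proj(x)) * up_proj(x)) with bias-free projections."""
    gate = x @ w_gate.T                  # (batch, ffn)
    up = x @ w_up.T                      # (batch, ffn)
    return (silu(gate) * up) @ w_down.T  # (batch, hidden)


rng = np.random.default_rng(0)
hidden_dim, ffn_dim = 8, 32
x = rng.standard_normal((2, hidden_dim))
w_gate = rng.standard_normal((ffn_dim, hidden_dim))
w_up = rng.standard_normal((ffn_dim, hidden_dim))
w_down = rng.standard_normal((hidden_dim, ffn_dim))
print(gated_mlp(x, w_gate, w_up, w_down).shape)  # (2, 8)
```

The gate branch decides, element by element, how much of the up branch passes through before down_proj maps the result back to the hidden size.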

MLPV2

class max.nn.linear.MLPV2(dtype: DType, quantization_encoding: QuantizationEncoding | None, hidden_dim: int, feed_forward_length: int, linear_cls: Callable[..., LinearV2] = LinearV2, devices: Sequence[DeviceRef] = (), activation_function: str = 'silu')

A simple multi-layer perceptron composed of three linear layers. Defaults to the SiLU activation function.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – The last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension that the inputs are projected to.
    • linear_cls – Linear class to use to create the projection layers.
    • devices – Devices to run the MLP layer on. If multiple devices are provided, only the first is used; use DistributedMLP to use all devices.
    • activation_function – Activation function to use. Options are:
      • “silu”
      • “gelu”
      • “gelu_tanh”
      • “relu”
      • “tanh”
      • “sigmoid”
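For reference, the six option strings correspond to standard element-wise functions. This NumPy sketch makes assumptions about how each name is interpreted: "gelu" is taken to be the exact erf form, and "gelu_tanh" the usual tanh approximation. The `ACTIVATIONS` table and the `activation` helper are hypothetical, not MAX APIs.

```python
import math

import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def gelu(v):
    # Exact GELU: v * Phi(v), where Phi is the standard normal CDF.
    erf = np.vectorize(math.erf, otypes=[float])
    return 0.5 * v * (1.0 + erf(v / math.sqrt(2.0)))


def gelu_tanh(v):
    # Common tanh approximation of GELU.
    return 0.5 * v * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (v + 0.044715 * v**3)))


ACTIVATIONS = {
    "silu": lambda v: v * sigmoid(v),
    "gelu": gelu,
    "gelu_tanh": gelu_tanh,
    "relu": lambda v: np.maximum(v, 0.0),
    "tanh": np.tanh,
    "sigmoid": sigmoid,
}


def activation(name, v):
    # Look up the function by its option string and apply it element-wise.
    try:
        fn = ACTIVATIONS[name]
    except KeyError:
        raise ValueError(f"Unsupported activation function: {name!r}") from None
    return fn(np.asarray(v, dtype=np.float64))


v = np.linspace(-3, 3, 7)
for name in ACTIVATIONS:
    print(name, np.round(activation(name, v), 3))
```

Rejecting unknown names up front surfaces a typo in `activation_function` as a clear error instead of a silent fallback.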

QLinear

class max.nn.linear.QLinear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None)

A quantized fully connected layer.

quantization_encoding

quantization_encoding: QuantizationEncoding | None = None
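The source does not say which encodings QLinear supports or how it computes internally. To illustrate what a quantized fully connected layer does in general, this NumPy sketch uses symmetric int8 quantization with one scale per output row, then dequantizes on the fly before applying y = xW^T + b. The scheme and the helper names `quantize_per_row_int8` and `quantized_linear` are assumptions for illustration only.

```python
import numpy as np


def quantize_per_row_int8(weight):
    # Symmetric quantization: one scale per output row, zero point fixed at 0.
    scale = np.abs(weight).max(axis=1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero for all-zero rows
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale


def quantized_linear(x, q, scale, bias=None):
    # Dequantize the int8 codes back to floats, then apply y = x W^T + b.
    w = q.astype(np.float64) * scale
    y = x @ w.T
    return y if bias is None else y + bias


rng = np.random.default_rng(0)
weight = rng.standard_normal((4, 16))
x = rng.standard_normal((2, 16))
q, scale = quantize_per_row_int8(weight)
y_q = quantized_linear(x, q, scale)
y = x @ weight.T
print(q.dtype, np.abs(y_q - y).max())
```

Each dequantized weight is within half a scale step of the original, so the output error is bounded by the input magnitudes times those half-steps.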