
Python module


Multi-layer Perceptron.


class max.nn.linear.DistributedMLP(*args, **kwargs)

A distributed multi-layer perceptron.

This class has the same state keys as the non-distributed MLP Layer.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – Size of the last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension that the inputs are projected to.
    • linear_cls – Linear class to use to create the projection layers.
    • devices – Devices on which to run the MLP layer. The non-distributed MLP uses only the first device when several are provided; DistributedMLP uses all of them.


class max.nn.linear.GPTQLinear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None, perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A linear layer for GPTQ-encoded weights.


perm_idx: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None


quantization_config: QuantizationConfig | None = None


class max.nn.linear.GPTQLinearV2(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef | None = None, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, quantization_config: QuantizationConfig | None = None)

A linear layer for GPTQ-encoded weights.

Initializes the linear layer with GPTQ-quantized weights and an optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved during computation.
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.
    • quantization_encoding – The quantization encoding of the weights.
    • quantization_config – Extra config for the weight quantization.


class max.nn.linear.Linear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None)

A unified linear layer that delegates to either regular or quantized implementation.


bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None


classmethod create(dtype: DType, quantization_encoding: QuantizationEncoding | None, in_features: int, out_features: int, weights: Weights | Weight, bias: Weights | Weight | None = None, quantization_config: QuantizationConfig | None = None) → Linear

Factory method that creates a Linear layer with the appropriate implementation.
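
The delegation described above can be pictured as a small dispatch on the encoding. The following is only a sketch of that pattern: the stand-in classes, the `create_linear` helper, the string encodings, and the exact dispatch rule are all assumptions made for illustration, not the max.nn.linear implementation.

```python
from dataclasses import dataclass
from typing import Optional


# Stand-in classes for illustration only; these are NOT the max.nn.linear types.
@dataclass
class PlainLinear:
    weight: str


@dataclass
class QuantizedLinear:
    weight: str
    encoding: str


@dataclass
class GPTQStyleLinear:
    weight: str
    encoding: str


def create_linear(weight: str, encoding: Optional[str]):
    """Pick an implementation from the encoding (assumed dispatch rule)."""
    if encoding is None:
        # No quantization requested: use the regular layer.
        return PlainLinear(weight)
    if encoding == "gptq":
        # GPTQ carries extra metadata, so it gets its own layer type.
        return GPTQStyleLinear(weight, encoding)
    # Any other encoding falls back to the generic quantized layer.
    return QuantizedLinear(weight, encoding)


layers = [create_linear("w", enc) for enc in (None, "gptq", "q4_k")]
kinds = [type(layer).__name__ for layer in layers]
```

A factory like this keeps call sites independent of the concrete layer class, which is presumably why the documented entry point is a classmethod rather than three separate constructors.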


weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray


class max.nn.linear.LinearV2(in_dim: int, out_dim: int, dtype: DType, device: DeviceRef | None = None, has_bias: bool = False, quantization_encoding: QuantizationEncoding | None = None, name: str | None = None, clip_weight: float | None = None)

Applies a linear transformation to incoming data: y = xW^T + b.

This layer implements a fully connected layer in which inputs are multiplied by a weight matrix and, optionally, offset by a bias vector. Both the weights and the bias initially reside on CPU; the model init phase moves them to the device.


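from max.dtype import DType
from max.graph import TensorValue
from max.nn.linear import LinearV2

# Example sizes are illustrative; any valid in_dim, out_dim, and dtype work.
linear_layer = LinearV2(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    has_bias=True,
)

# Placeholder for a graph value whose last dimension equals in_dim.
input_tensor: TensorValue
output = linear_layer(input_tensor)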

Initializes the linear layer with weights and optional bias.

  • Parameters:

    • in_dim – The dimensionality of the input space.
    • out_dim – The dimensionality of the output space.
    • dtype – The data type for both weights and bias.
    • device – The target device for computation. Weights remain on CPU until moved during computation.
    • name – Base name for weights (appended with .weight and .bias if applicable).
    • has_bias – When True, adds a bias vector to the layer. Defaults to False.


bias: Weight | None = None

The optional bias vector stored on CPU with shape (out_dim,). Model init moves the bias to device if present.


device: DeviceRef | None

The device where matrix operations are performed.


weight: Weight

The weight matrix stored on CPU with shape (out_dim, in_dim). Model init transposes the weight and moves it to device.
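
To make the shapes concrete, here is a minimal pure-Python sketch of y = xW^T + b with a weight of shape (out_dim, in_dim) and a bias of shape (out_dim,). The `linear_forward` helper and the sample numbers are illustrative assumptions; the real layer operates on graph values rather than Python lists.

```python
def linear_forward(x, weight, bias=None):
    """Compute y = x W^T + b for a batch of row vectors.

    x:      list of rows, each of length in_dim
    weight: out_dim rows, each of length in_dim, i.e. shape (out_dim, in_dim)
    bias:   optional list of length out_dim
    """
    out = []
    for row in x:
        # Each output feature is the dot product of the input row with one weight row.
        y = [sum(xi * wi for xi, wi in zip(row, w_row)) for w_row in weight]
        if bias is not None:
            y = [yi + bi for yi, bi in zip(y, bias)]
        out.append(y)
    return out


# in_dim = 3, out_dim = 2 (sample values chosen for illustration)
weight = [[1.0, 0.0, -1.0],
          [0.5, 0.5, 0.5]]
bias = [0.1, -0.1]
x = [[1.0, 2.0, 3.0]]

y = linear_forward(x, weight, bias)
```

Because the weight is stored as (out_dim, in_dim), the product uses its transpose, which is consistent with the note that model init transposes the weight.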


class max.nn.linear.MLP(gate_proj: Linear, down_proj: Linear, up_proj: Linear)

A simple multi-layer perceptron composed of three linear layers, using the SiLU activation function.
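
The page does not spell out how the three projections are combined. A common arrangement for a gate/up/down MLP with SiLU is down_proj(silu(gate_proj(x)) * up_proj(x)); the sketch below assumes that arrangement, omits biases, and uses hypothetical helper names (`silu`, `matvec`, `gated_mlp`) and made-up weights.

```python
import math


def silu(v):
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def matvec(weight, x):
    """Multiply an (out_dim, in_dim) weight by a vector of length in_dim."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def gated_mlp(x, gate_w, up_w, down_w):
    """Assumed arrangement: down(silu(gate(x)) * up(x)), biases omitted."""
    gate = [silu(g) for g in matvec(gate_w, x)]
    up = matvec(up_w, x)
    hidden = [g * u for g, u in zip(gate, up)]  # elementwise gating
    return matvec(down_w, hidden)


# Input width 2, intermediate width 3 (illustrative values only).
x = [1.0, -1.0]
gate_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
up_w = [[1.0, 1.0], [1.0, -1.0], [2.0, 0.0]]
down_w = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]

out = gated_mlp(x, gate_w, up_w, down_w)
```

The output has the same width as the input, so the block can be dropped into a residual stream without reshaping.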


down_proj: Linear


gate_proj: Linear


up_proj: Linear


class max.nn.linear.MLPV2(dtype: DType, quantization_encoding: QuantizationEncoding | None, hidden_dim: int, feed_forward_length: int, linear_cls: Callable[[...], LinearV2] = LinearV2, devices: Sequence[DeviceRef] = ())

A simple multi-layer perceptron composed of three linear layers, using the SiLU activation function.

  • Parameters:

    • dtype – DType to use for the layer weights, which should match the input dtype.
    • quantization_encoding – Quantization encoding of the layer weights.
    • hidden_dim – Size of the last dimension of the layer input.
    • feed_forward_length – Size of the intermediate dimension that the inputs are projected to.
    • linear_cls – Linear class to use to create the projection layers.
    • devices – Devices on which to run the MLP layer. If multiple devices are provided, only the first is used; use DistributedMLP to use all of them.
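
As a rough illustration of how these parameters relate, the sketch below derives the projection weight shapes from hidden_dim and feed_forward_length, and mirrors the documented first-device rule. The helper names, the (out_dim, in_dim) shape order for each projection, the string device labels, and returning None for an empty device list are all assumptions for illustration.

```python
def mlp_weight_shapes(hidden_dim, feed_forward_length):
    """Assumed weight shapes, in (out_dim, in_dim) order, for the three projections."""
    return {
        "gate_proj": (feed_forward_length, hidden_dim),
        "up_proj": (feed_forward_length, hidden_dim),
        "down_proj": (hidden_dim, feed_forward_length),
    }


def pick_device(devices):
    """Mirror the documented rule: only the first device is used.

    Returning None for an empty sequence is an assumption.
    """
    return devices[0] if devices else None


# Sizes and device labels are placeholders, not values from the library.
shapes = mlp_weight_shapes(hidden_dim=4096, feed_forward_length=11008)
device = pick_device(["gpu:0", "gpu:1"])
```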


class max.nn.linear.QLinear(weight: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, quantization_encoding: QuantizationEncoding | None = None)

A quantized fully connected layer.


quantization_encoding: QuantizationEncoding | None = None