
Python module

norm

GroupNorm

class max.nn.norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=gpu:0)

Group normalization block.

Divides the channels into groups and computes the normalization statistics (mean and variance) separately for each group. Follows the implementation pattern of PyTorch's group_norm.

Parameters:

  • num_groups (int) – Number of groups to divide the channels into.
  • num_channels (int) – Number of input channels.
  • eps (float) – Small constant added to the denominator for numerical stability.
  • affine (bool) – If True, apply learnable per-channel affine transform parameters.
  • device (DeviceRef)
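The per-group computation can be sketched in plain Python for a single sample. This is not the MAX implementation. It assumes the input is laid out as `x[channel][position]` (one batch element, spatial dimensions flattened), assumes num_channels is divisible by num_groups (as PyTorch requires), and uses biased variance as PyTorch's group_norm does. The names `group_norm`, `weight`, and `bias` are hypothetical stand-ins for the block's learnable affine parameters.

```python
import math
from typing import List, Optional, Sequence


def group_norm(
    x: Sequence[Sequence[float]],              # one sample, laid out as x[channel][position]
    num_groups: int,
    eps: float = 1e-5,
    weight: Optional[Sequence[float]] = None,  # hypothetical per-channel scale (affine=True)
    bias: Optional[Sequence[float]] = None,    # hypothetical per-channel shift (affine=True)
) -> List[List[float]]:
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    group_size = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        channels = x[g * group_size : (g + 1) * group_size]
        # Pool every value from every channel in this group.
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        # Biased variance (divide by N), as in PyTorch's group_norm.
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        for offset, ch in enumerate(channels):
            c = g * group_size + offset
            scale = weight[c] if weight is not None else 1.0
            shift = bias[c] if bias is not None else 0.0
            out.append([(v - mean) * inv_std * scale + shift for v in ch])
    return out
```

With `num_groups=1`, all channels share one set of statistics; with `num_groups` equal to the channel count, each channel is normalized on its own (instance-norm-like behavior).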

LayerNorm

class max.nn.norm.LayerNorm(dims, device, dtype, eps=1e-05, use_bias=True)

Layer normalization block.
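The computation a layer normalization block performs can be sketched in plain Python for one vector. This illustrates standard layer normalization rather than MAX code: the normalized dimension is assumed to be the last one, the vector length stands in for `dims`, and the function name and the `gamma`/`beta` arguments are hypothetical, with `beta=None` standing in for `use_bias=False`.

```python
import math
from typing import List, Optional, Sequence


def layer_norm(
    x: Sequence[float],
    gamma: Sequence[float],                  # hypothetical learned scale, one per element
    beta: Optional[Sequence[float]] = None,  # hypothetical learned shift; None ~ use_bias=False
    eps: float = 1e-5,
) -> List[float]:
    n = len(x)
    mean = sum(x) / n
    # Biased variance (divide by n), as in standard layer normalization.
    var = sum((v - mean) ** 2 for v in x) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    out = []
    for i, v in enumerate(x):
        shift = beta[i] if beta is not None else 0.0
        out.append((v - mean) * inv_std * gamma[i] + shift)
    return out
```

Unlike GroupNorm, the statistics here come from a single vector, so each token or sample is normalized independently of the others in the batch.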

Parameters:

LayerNormV1

class max.nn.norm.LayerNormV1(weight, bias=None, eps=1e-06)

Layer normalization block.

Deprecated: Use LayerNorm instead.

Parameters:

bias

bias: TensorValue | None = None

eps

eps: float = 1e-06

weight

weight: TensorValue

RMSNorm

class max.nn.norm.RMSNorm(dim, dtype, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)

Computes the Root Mean Square normalization on inputs.

Parameters:

  • dim (int) – Size of the last dimension of the expected input.
  • eps (float) – Value added to the denominator for numerical stability.
  • weight_offset (float) – Constant offset added to the learned weights at runtime. Set this to 1.0 for Gemma-style RMSNorm.
  • multiply_before_cast (bool) – If True, multiply by the learned weights before casting the result to the input dtype (Gemma3-style). If False, cast the normalized values to the input dtype first, then multiply by the learned weights (Llama-style).
  • dtype (DType)
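The effect of `eps`, `weight_offset`, and `multiply_before_cast` can be illustrated with a plain-Python sketch over a single vector. It is not MAX code: Python's 64-bit float stands in for the higher-precision compute type, IEEE float16 (via the `struct` module's `"e"` format) stands in for the input dtype, and the helper names `rms_norm` and `to_float16` are hypothetical.

```python
import math
import struct
from typing import List, Sequence


def to_float16(v: float) -> float:
    """Round v to the nearest IEEE 754 half-precision value (stand-in for the input dtype)."""
    return struct.unpack("<e", struct.pack("<e", v))[0]


def rms_norm(
    x: Sequence[float],
    weight: Sequence[float],
    eps: float = 1e-6,
    weight_offset: float = 0.0,
    multiply_before_cast: bool = True,
) -> List[float]:
    # Statistics are computed in full precision (Python floats are 64-bit).
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    normed = [v * inv_rms for v in x]
    # weight_offset is added to the learned weights at runtime (1.0 for Gemma-style).
    scale = [w + weight_offset for w in weight]
    if multiply_before_cast:
        # Gemma3-style: multiply in full precision, then cast the result once.
        return [to_float16(n * s) for n, s in zip(normed, scale)]
    # Llama-style: cast the normalized values first, then multiply in float16.
    return [to_float16(to_float16(n) * to_float16(s)) for n, s in zip(normed, scale)]
```

In this sketch the two branches differ only in where rounding to the lower-precision type happens; whether this matches MAX's exact cast points is an assumption.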

shard()

shard(devices)

Creates sharded views of this RMSNorm across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded RMSNorm instances, one for each device.

Return type:

Sequence[RMSNorm]

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the RMSNorm sharding strategy.

RMSNormV1

class max.nn.norm.RMSNormV1(weight, eps=1e-06, weight_offset=0.0, multiply_before_cast=True)

Computes the Root Mean Square normalization on inputs.

Deprecated: Use RMSNorm instead.

Parameters:

eps

eps: float = 1e-06

multiply_before_cast

multiply_before_cast: bool = True

weight

weight: Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | DLPackArray

weight_offset

weight_offset: float = 0.0
