
Python module

group_norm

Group Normalization implementation using the graph API.

GroupNorm

class max.nn.norm.group_norm.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=cpu:0)

Group normalization block.

Divides the channels into num_groups groups and computes the normalization statistics (mean and variance) separately for each group. Follows the implementation pattern of PyTorch's group_norm.
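For reference, the standard group normalization computation, as defined for PyTorch's GroupNorm, can be written as below. It assumes an input x of shape (N, C, *) split into G = num_groups groups of C/G channels each. S_g denotes the set of (channel, spatial) positions in group g, and γ_c, β_c are per-channel affine parameters applied only when affine=True. That this block matches the formula term for term is an assumption based on the description above.

```latex
\begin{aligned}
g(c) &= \left\lfloor \frac{c}{C/G} \right\rfloor \\
\mu_{n,g} &= \frac{1}{|S_g|} \sum_{i \in S_g} x_{n,i} \\
\sigma^2_{n,g} &= \frac{1}{|S_g|} \sum_{i \in S_g} \left( x_{n,i} - \mu_{n,g} \right)^2 \\
y_{n,c,\ldots} &= \frac{x_{n,c,\ldots} - \mu_{n,g(c)}}{\sqrt{\sigma^2_{n,g(c)} + \epsilon}} \, \gamma_c + \beta_c
\end{aligned}
```

The variance here is the biased (population) estimate, and ε corresponds to the eps parameter.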

Parameters:

  • num_groups (int) – Number of groups to separate the channels into.
  • num_channels (int) – Number of input channels.
  • eps (float) – Small constant added to the denominator for numerical stability.
  • affine (bool) – If True, apply learnable affine transform parameters.
  • device (DeviceRef) – Device to use; defaults to cpu:0.
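To show how num_groups, eps, and the affine parameters interact, here is a minimal pure-Python sketch of the computation. It assumes the standard PyTorch-style semantics the description refers to and does not use the MAX graph API. The function name group_norm_reference and its weight/bias arguments are hypothetical: weight and bias stand in for the learnable parameters that affine=True enables. The divisibility check mirrors PyTorch's requirement that num_channels be divisible by num_groups; whether this GroupNorm enforces it the same way is an assumption.

```python
import math
from typing import List, Optional

Tensor3D = List[List[List[float]]]  # nested lists with shape (N, C, L)


def group_norm_reference(
    x: Tensor3D,
    num_groups: int,
    eps: float = 1e-5,
    weight: Optional[List[float]] = None,  # hypothetical per-channel scale (affine=True)
    bias: Optional[List[float]] = None,    # hypothetical per-channel shift (affine=True)
) -> Tensor3D:
    """Group-normalize x of shape (N, C, L). A sketch, not the MAX implementation."""
    out: Tensor3D = []
    for sample in x:
        num_channels = len(sample)
        # Assumed constraint (as in PyTorch): channels split evenly into groups.
        if num_channels % num_groups != 0:
            raise ValueError("num_channels must be divisible by num_groups")
        group_size = num_channels // num_groups
        normed_sample: List[List[float]] = []
        for g in range(num_groups):
            channels = sample[g * group_size:(g + 1) * group_size]
            # Statistics are pooled over every value in every channel of the group.
            values = [v for ch in channels for v in ch]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
            inv_std = 1.0 / math.sqrt(var + eps)
            for offset, ch in enumerate(channels):
                c = g * group_size + offset
                # Without affine parameters, scale is 1 and shift is 0.
                scale = weight[c] if weight is not None else 1.0
                shift = bias[c] if bias is not None else 0.0
                normed_sample.append([(v - mean) * inv_std * scale + shift for v in ch])
        out.append(normed_sample)
    return out
```

For example, group_norm_reference([[[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [20.0, 20.0]]], num_groups=2) normalizes channels 0 and 1 together and channels 2 and 3 together, so each group ends up with roughly zero mean and unit variance regardless of the other group's scale.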