
Python module

rms_norm

Normalization layer.

DistributedRMSNorm

class max.pipelines.nn.norm.rms_norm.DistributedRMSNorm(rms_norms: list[max.pipelines.nn.norm.rms_norm.RMSNorm], devices: list[max.graph.type.DeviceRef])

devices

devices: list[max.graph.type.DeviceRef]

rms_norms

rms_norms: list[max.pipelines.nn.norm.rms_norm.RMSNorm]
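The constructor pairs a list of RMSNorm layers with a list of devices. A plausible reading, assumed here and not confirmed by this page, is that the i-th layer normalizes the input held on devices[i], so the two lists must have the same length. The NumPy sketch below models each device as one entry in a Python list. `rms_norm` and `distributed_rms_norm` are hypothetical helper names, not MAX APIs, and real device placement is not modeled.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Root-mean-square over the last (hidden) axis, kept for broadcasting.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    # Scale each row to unit RMS, then apply the learned per-feature weight.
    return x / np.sqrt(mean_sq + eps) * weight


def distributed_rms_norm(per_device_inputs, per_device_weights, eps=1e-6):
    # Hypothetical model: entry i stands for the tensor (and norm weight)
    # that lives on device i; each device normalizes only its own input.
    if len(per_device_inputs) != len(per_device_weights):
        raise ValueError("need exactly one norm weight per device input")
    return [
        rms_norm(x, w, eps)
        for x, w in zip(per_device_inputs, per_device_weights)
    ]
```

Keeping one layer per device, rather than one shared layer, lets each replica of the weight stay resident on its own device so no cross-device transfer is needed for the normalization itself.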

RMSNorm

class max.pipelines.nn.norm.rms_norm.RMSNorm(weight: max._mlir._mlir_libs._mlir.ir.Value | max.graph.value.TensorValue | max.graph.type.Shape | max.graph.type.Dim | int | float | numpy.integer | numpy.floating | numpy.ndarray, eps: float = 1e-06)

eps

eps: float = 1e-06

weight

weight: Value | TensorValue | Shape | Dim | int | float | integer | floating | ndarray
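This page lists the `weight` and `eps` attributes but does not spell out the computation. The standard RMSNorm definition is y = x / sqrt(mean(x²) + eps) * weight, with the mean taken over the hidden (last) axis; `eps` keeps the denominator nonzero for all-zero rows. The NumPy sketch below follows that textbook definition under the assumption that MAX normalizes over the last axis. It is illustrative only, not MAX's graph implementation, and `rms_norm` is a hypothetical helper name.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    # Mean of squares over the last axis; keepdims lets it broadcast over x.
    mean_sq = np.mean(np.square(x), axis=-1, keepdims=True)
    # eps (default matches the documented 1e-06) avoids division by zero.
    normed = x / np.sqrt(mean_sq + eps)
    # weight is a per-feature scale broadcast along the last axis.
    return normed * weight


x = np.array([[3.0, 4.0], [1.0, -1.0]])
y = rms_norm(x, weight=np.ones(2))
print(y)
```

Unlike LayerNorm, this definition does not subtract the mean and has no bias term, so only the scale `weight` is learned.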