
Python module

attention_with_rope

An opaque, KV cache-optimized attention mechanism with RoPE.

AttentionWithRope

class max.nn.attention.attention_with_rope.AttentionWithRope(*, rope, sharding_strategy=None, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Implementation of attention that uses Rotary Position Embedding (RoPE).

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • sharding_strategy (ShardingStrategy | None) – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv (bool) – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias (bool) – Whether Q/K/V have bias (stacked_qkv forbids bias).
  • float8_config (Float8Config | None) – Optional Float8 config (dynamic or static).
  • clip_qkv (float | None) – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Value to use for numerical stability in RMSNorm.
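
A minimal construction sketch (not a verbatim example from the library): it assumes a RotaryEmbedding, a KVCacheParams, and a DeviceRef have already been built elsewhere, and the head counts and hidden size are placeholder values.

```python
from max.nn.attention.attention_with_rope import AttentionWithRope

def build_attention(rope, kv_params, device):
    # Sketch only: `rope` (RotaryEmbedding), `kv_params` (KVCacheParams), and
    # `device` (DeviceRef) are assumed to be constructed elsewhere.
    return AttentionWithRope(
        rope=rope,
        num_attention_heads=32,   # placeholder model shape
        num_key_value_heads=8,    # grouped-query attention: fewer KV heads
        hidden_size=4096,         # placeholder hidden dimension
        kv_params=kv_params,
        devices=[device],
    )
```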

qkv_input_scale

property qkv_input_scale: TensorValue | None

The max of q, k, and v scale input vectors.

qkv_weight_scale

property qkv_weight_scale: TensorValue

The max of q, k, and v scale weight vectors.

rope

rope: RotaryEmbedding

shard()

shard(devices)

Create sharded views across devices (tensor-parallel).

Returns one AttentionWithRope per device with appropriately sliced weights.

Parameters:

devices (Iterable[DeviceRef])

Return type:

list[AttentionWithRope]
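
A hedged sketch of how shard() might be used, assuming attn is an AttentionWithRope and tp_devices is a list of DeviceRef objects created elsewhere; per the return type above, one sharded AttentionWithRope comes back per device.

```python
def shard_attention(attn, tp_devices):
    # Sketch only: `attn` is an AttentionWithRope and `tp_devices` a list of
    # DeviceRef objects built elsewhere.
    shards = attn.shard(tp_devices)       # one AttentionWithRope per device
    return dict(zip(tp_devices, shards))  # map each device to its shard
```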

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of q, k, and v bias weight vectors.

AttentionWithRopeNoOpaque

class max.nn.attention.attention_with_rope.AttentionWithRopeNoOpaque(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, scale=None)

Attention with RoPE without opaque KV cache.

Assumes:
  • no float8
  • no stacked qkv
  • no bias
  • no clip_qkv
  • no float8_config

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used. Use TensorParallelAttentionWithRope to use all devices during attention computation.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the outputs dense layer.
  • scale (float | None) – Value used to scale the results of the attention output.

rope

rope: RotaryEmbedding

AttentionWithRopeQKV

class max.nn.attention.attention_with_rope.AttentionWithRopeQKV(n_heads: 'int', kv_params: 'KVCacheParams', wq: 'TensorValueLike', wk: 'TensorValueLike', wv: 'TensorValueLike', wo: 'LinearV1', scale: 'float', rope: 'RotaryEmbedding')

Parameters:

  • n_heads (int)
  • kv_params (KVCacheParams)
  • wq (TensorValueLike)
  • wk (TensorValueLike)
  • wv (TensorValueLike)
  • wo (LinearV1)
  • scale (float)
  • rope (RotaryEmbedding)

rope

rope: RotaryEmbedding

AttentionWithRopeV1

class max.nn.attention.attention_with_rope.AttentionWithRopeV1(n_heads, kv_params, wqkv, wo, scale, rope, bias=None, perm_idx=None, quantization_config=None)

Implementation of attention that uses the rope frequency.

Deprecated: Use AttentionWithRope instead.

Parameters:

  • n_heads
  • kv_params
  • wqkv
  • wo
  • scale
  • rope (RotaryEmbedding)
  • bias (TensorValue | None)
  • perm_idx (TensorValue | None)
  • quantization_config (QuantizationConfig | None)

bias

bias: TensorValue | None = None

perm_idx

perm_idx: TensorValue | None = None

quantization_config

quantization_config: QuantizationConfig | None = None

rope

rope: RotaryEmbedding

DataParallelAttentionWithRope

class max.nn.attention.attention_with_rope.DataParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Data-parallel implementation of Attention with RoPE.

This replicates the attention module across devices and runs each replica on its local inputs (x, kv, freqs_cis, input_row_offsets). No collective ops are required; the KV cache remains local to each device.

Notes:

  • Assumes the caller has already distributed xs, kv_collections, freqs_cis, and input_row_offsets so that index i corresponds to device i, with input_row_offsets[i] rebased to start at 0 (see the sketch below).
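
A small illustration of the rebasing requirement above, using NumPy and a hypothetical helper that is not part of the library: once the global ragged batch is split per device, each device's input_row_offsets is shifted so it starts at 0.

```python
import numpy as np

def rebase_row_offsets(per_device_offsets):
    # Hypothetical helper: each entry is that device's slice of the global
    # row-offsets array; subtracting the first element rebases it to start at 0.
    return [offsets - offsets[0] for offsets in per_device_offsets]

# Two devices, each holding two sequences of the global batch:
rebase_row_offsets([np.array([0, 3, 7]), np.array([7, 9, 12])])
# -> [array([0, 3, 7]), array([0, 2, 5])]
```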

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • sharding_strategy – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv (bool) – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias (bool) – Whether Q/K/V have bias (stacked_qkv forbids bias).
  • float8_config (Float8Config | None) – Optional Float8 config (dynamic or static).
  • clip_qkv (float | None) – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Value to use for numerical stability in RMSNorm.

GGUFQAttentionWithRope

class max.nn.attention.attention_with_rope.GGUFQAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, quantization_encoding, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, has_bias=False, clip_qkv=None)

Implementation of attention with GGUF quantized weights.

Initializes the GGUF attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • layer_idx – The layer number associated with this Attention block.
  • dtype (DType) – DType of the weights; should always be uint8.
  • devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used. Use TensorParallelAttentionWithRope to use all devices during attention computation.
  • quantization_encoding (QuantizationEncoding) – Quantization encoding of the weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the outputs dense layer.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
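
A construction sketch under assumptions: rope, kv_params, and encoding (a QuantizationEncoding) are assumed to be created elsewhere, uint8_dtype stands in for the uint8 DType value, and the model shape numbers are placeholders.

```python
from max.nn.attention.attention_with_rope import GGUFQAttentionWithRope

def build_gguf_attention(rope, kv_params, encoding, uint8_dtype):
    # Sketch only: `encoding` is a QuantizationEncoding and `uint8_dtype` the
    # uint8 DType, both built elsewhere (the weights dtype should always be uint8).
    return GGUFQAttentionWithRope(
        rope=rope,
        num_attention_heads=32,   # placeholder model shape
        num_key_value_heads=8,
        hidden_size=4096,
        kv_params=kv_params,
        dtype=uint8_dtype,
        quantization_encoding=encoding,
    )
```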

rope

rope: RotaryEmbedding

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors.

wqkv_bias

property wqkv_bias: TensorValue | None

The concatenation of q, k, and v bias weight vectors.

GPTQAttentionWithRope

class max.nn.attention.attention_with_rope.GPTQAttentionWithRope(quantization_config, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, scale=None, linear_cls=<class 'max.nn.linear.Linear'>)

Implementation of the GPTQ attention layer.

Initializes the attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • sharding_strategy – Optional initial sharding strategy.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • dtype (DType) – DType of the QKV and output projection weights.
  • devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used for weight placement here.
  • linear_cls (Callable[..., Linear]) – Linear class to use for projections.
  • stacked_qkv – Whether Q/K/V weights are stacked in a single Weight.
  • scale (float | None) – Optional attention scale; defaults to sqrt(1/head_dim).
  • has_bias – Whether Q/K/V have bias (stacked_qkv forbids bias).
  • float8_config – Optional Float8 config (dynamic or static).
  • clip_qkv – If provided, clamp Q/K/V weights to [-clip_qkv, clip_qkv].
  • use_qk_norm – Whether to use RMSNorm on Q/K.
  • rms_norm_eps – Value to use for numerical stability in RMSNorm.
  • quantization_config (QuantizationConfig)

wqkv

property wqkv: TensorValue

The concatenation of q, k, and v weight vectors (packed + scales).

TensorParallelAttentionWithRope

class max.nn.attention.attention_with_rope.TensorParallelAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.linear.Linear'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None, use_qk_norm=False, rms_norm_eps=1e-06)

Tensor-parallel wrapper that delegates sharding to the base module.

Initializes the distributed (tensor parallel) attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV Cache params, including number of kv heads, head dim, and dtype.
  • devices (Sequence[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. Must provide at least 2 devices for tensor parallel attention.
  • dtype (DType) – DType of the QKV and output projection weights.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the outputs dense layer.
  • stacked_qkv (bool) – Whether the weights are stacked together.
  • scale (float | None) – Value used to scale the results of the attention output.
  • has_bias (bool) – Whether to use an attention bias.
  • float8_config (Float8Config | None) – Float8 configuration for quantization.
  • clip_qkv (float | None) – If provided, the QKV weights are clamped between [-clip_qkv, clip_qkv].
  • use_qk_norm (bool) – Whether to use RMSNorm on Q/K.
  • rms_norm_eps (float) – Value to use for numerical stability in RMSNorm.
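
A hedged construction sketch, assuming rope and kv_params are built elsewhere and devices is a sequence of at least two DeviceRef objects; head counts and hidden size are placeholders.

```python
from max.nn.attention.attention_with_rope import TensorParallelAttentionWithRope

def build_tp_attention(rope, kv_params, devices):
    # Sketch only: `devices` must contain at least two DeviceRef objects;
    # `rope` and `kv_params` are assumed to be constructed elsewhere.
    assert len(devices) >= 2, "tensor-parallel attention requires >= 2 devices"
    return TensorParallelAttentionWithRope(
        rope=rope,
        num_attention_heads=32,   # placeholder model shape
        num_key_value_heads=8,
        hidden_size=4096,
        kv_params=kv_params,
        devices=devices,
    )
```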

distribute_value()

max.nn.attention.attention_with_rope.distribute_value(v, devices)

Parameters:

Return type:

list[TensorValue]
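
A usage sketch under assumptions: v is taken to be a TensorValue and devices a list of DeviceRef built elsewhere; per the documented return type, one TensorValue per device is returned.

```python
from max.nn.attention.attention_with_rope import distribute_value

def spread_across_devices(v, devices):
    # Hypothetical wrapper: `v` (TensorValue) and `devices` (list of DeviceRef)
    # are assumed to exist; the result is a list with one TensorValue per device.
    return distribute_value(v, devices)
```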
