
Python module

ragged_attention

An opaque, KV-cache-optimized attention mechanism, with mask variants provided inside the kernel.

RaggedAttention

class max.nn.attention.ragged_attention.RaggedAttention(*, mask_variant: ~max.nn.kernels.MHAMaskVariant, num_attention_heads: int, num_key_value_heads: int, hidden_size: int, kv_params: ~max.nn.kv_cache.cache_params.KVCacheParams, layer_idx: int, devices: list[max.graph.type.DeviceRef] | None = None, dtype: ~max._core.dtype.DType = DType.float32, linear_cls: ~typing.Callable[[...], ~max.nn.linear.LinearV2] = <class 'max.nn.linear.LinearV2'>, stacked_qkv: bool = False, scale: float | None = None, has_bias: bool = False, clip_qkv: float | None = None)

Layer that computes the self-attention score for ragged (variable-length, unpadded) inputs.

Initializes the attention layer; a construction sketch follows the parameter list below.

  • Parameters:

    • mask_variant – The attention mask variant applied inside the fused kernel.
    • num_attention_heads – The number of attention heads.
    • num_key_value_heads – Number of key/value heads.
    • hidden_size – The dimension of the hidden states.
    • kv_params – The KV cache parameters, including the number of KV heads, the head dimension, and the data type.
    • layer_idx – The layer number associated with this Attention block.
    • dtype – DType of the layer's weights.
    • devices – Device to place the weights and run the computation. If multiple are provided, the first device is used.
    • linear_cls – Linear class to use for the output dense layer.
    • stacked_qkv – Whether the Q, K, and V weights are stacked together into a single weight.
    • scale – Value used to scale the results of the attention output.
    • has_bias – Whether to use an attention bias.
    • clip_qkv – If provided, the QKV weights are clamped to the range [-clip_qkv, clip_qkv].
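
A minimal construction sketch, assuming the import paths below, the KVCacheParams field names (dtype, n_kv_heads, head_dim), the MHAMaskVariant.CAUSAL_MASK member, and DeviceRef.GPU(...) all match the installed MAX version; the hyperparameter values are illustrative only.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.ragged_attention import RaggedAttention
from max.nn.kernels import MHAMaskVariant
from max.nn.kv_cache import KVCacheParams

# Illustrative model dimensions; adjust to match your model config.
num_heads = 32
num_kv_heads = 8
hidden_size = 4096
head_dim = hidden_size // num_heads

# Assumed KVCacheParams fields; check the signature in your MAX version.
kv_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=num_kv_heads,
    head_dim=head_dim,
)

attention = RaggedAttention(
    mask_variant=MHAMaskVariant.CAUSAL_MASK,
    num_attention_heads=num_heads,
    num_key_value_heads=num_kv_heads,
    hidden_size=hidden_size,
    kv_params=kv_params,
    layer_idx=0,
    devices=[DeviceRef.GPU(0)],
    dtype=DType.bfloat16,
    scale=head_dim**-0.5,  # explicit 1 / sqrt(head_dim) scaling
)
```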

wqkv

property wqkv*: TensorValue*

The concatenation of the Q, K, and V projection weights.
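
A conceptual sketch of that concatenation (not the library implementation): with separate per-projection weights, the combined QKV weight stacks the query, key, and value projection matrices along the output dimension, so a single matmul produces all three projections at once. The shapes below are hypothetical and for illustration only.

```python
import numpy as np

# Hypothetical dimensions for illustration.
hidden_size = 64
num_attention_heads, num_key_value_heads, head_dim = 8, 2, 8

wq = np.zeros((num_attention_heads * head_dim, hidden_size))
wk = np.zeros((num_key_value_heads * head_dim, hidden_size))
wv = np.zeros((num_key_value_heads * head_dim, hidden_size))

# Stack along the output dimension: one matmul then yields Q, K, and V,
# which can be split back out by row ranges.
wqkv = np.concatenate([wq, wk, wv], axis=0)
assert wqkv.shape == (
    (num_attention_heads + 2 * num_key_value_heads) * head_dim,
    hidden_size,
)
```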