
Python module

attention_with_rope

An attention mechanism with rotary position embeddings (RoPE), optimized for the opaque KV cache.
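For readers new to RoPE, the following is a minimal pure-Python sketch of the rotation applied to a single query or key vector. It assumes the common formulation with base 10000 and adjacent-pair rotation (some implementations rotate the first and second halves of the vector instead). `apply_rope` and `dot` are hypothetical helpers, not MAX APIs, and this is not the MAX kernel.

```python
import math
from typing import List


def apply_rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each adjacent pair (x[2i], x[2i+1]) by angle pos * base**(-2i/d).

    Assumption: adjacent-pair RoPE with base 10000; MAX may use another layout.
    """
    d = len(x)
    assert d % 2 == 0, "head dimension must be even"
    out = [0.0] * d
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.1, 0.2, 0.3, 0.4]
k = [0.5, -0.1, 0.2, 0.3]
# The attention score depends only on the relative offset between positions (3 in both cases).
s1 = dot(apply_rope(q, 5), apply_rope(k, 2))
s2 = dot(apply_rope(q, 13), apply_rope(k, 10))
print(abs(s1 - s2) < 1e-9)  # True
```

Because each pair is rotated rather than shifted, vector norms are preserved and relative position is encoded directly in the query-key dot product.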

AttentionWithRope

class max.pipelines.nn.attention.attention_with_rope.AttentionWithRope(n_heads: int, kv_params: KVCacheParams, layer_idx: TensorValue, wqkv: TensorValue, wo: Linear, rope: OptimizedRotaryEmbedding, bias: Optional[TensorValue] = None, perm_idx: Optional[TensorValue] = None, quantization_config: Optional[QuantizationConfig] = None)

bias

bias: TensorValue | None = None

perm_idx

perm_idx: TensorValue | None = None

quantization_config

quantization_config: QuantizationConfig | None = None

rope

rope: OptimizedRotaryEmbedding

AttentionWithRopeQKV

class max.pipelines.nn.attention.attention_with_rope.AttentionWithRopeQKV(n_heads: int, kv_params: KVCacheParams, layer_idx: int, wq: TensorValueLike, wk: TensorValueLike, wv: TensorValueLike, wo: Linear, rope: OptimizedRotaryEmbedding)

rope

rope: OptimizedRotaryEmbedding

AttentionWithRopeV2

class max.pipelines.nn.attention.attention_with_rope.AttentionWithRopeV2(*, rope: OptimizedRotaryEmbedding, num_attention_heads: int, num_key_value_heads: int, hidden_size: int, kv_params: KVCacheParams, layer_idx: int, dtype: DType = DType.float32, device: DeviceRef = cpu:0, linear_cls: Callable[..., LinearV2] = LinearV2, stacked_qkv: bool = False)

Implementation of attention that applies rotary position embeddings (RoPE), using the frequencies from the rope layer.

AttentionWithRopeV2 will replace AttentionWithRope as we roll out the new Layer API.
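To make "attention with RoPE over a KV cache" concrete, here is a minimal single-head, single-token decode step in pure Python. It assumes a plain list-based cache rather than MAX's opaque KV cache, and `rope` and `decode_step` are hypothetical helpers; the real layer operates on batched graph tensors and fused kernels.

```python
import math
from typing import List, Tuple

Vec = List[float]


def rope(x: Vec, pos: int, base: float = 10000.0) -> Vec:
    # Assumption: adjacent-pair RoPE variant; rotates each pair by a position-dependent angle.
    out = list(x)
    d = len(x)
    for i in range(d // 2):
        t = pos * base ** (-2 * i / d)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * math.cos(t) - b * math.sin(t)
        out[2 * i + 1] = a * math.sin(t) + b * math.cos(t)
    return out


def decode_step(q: Vec, k: Vec, v: Vec, cache: List[Tuple[Vec, Vec]]) -> Vec:
    """Append this token's rotated key and its value to the cache, then attend over it."""
    pos = len(cache)  # the new token's position is the current cache length
    cache.append((rope(k, pos), v))  # keys are cached after rotation; values are not rotated
    q_rot = rope(q, pos)
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q_rot, kc)) for kc, _ in cache]
    m = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * vc[j] for w, (_, vc) in zip(weights, cache)) / total for j in range(len(v))]


cache: List[Tuple[Vec, Vec]] = []
out1 = decode_step([1.0, 0.0], [1.0, 0.0], [2.0, 4.0], cache)
print(out1)  # the first token attends only to itself: [2.0, 4.0]
out2 = decode_step([0.0, 1.0], [0.0, 1.0], [6.0, 8.0], cache)
print(len(cache))  # 2
```

Caching keys after rotation means each key is rotated exactly once, when its token is first processed.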

Initializes the attention layer.

  • Parameters:

    • rope – The rope layer to borrow the freq_cis value from.
    • num_attention_heads – The number of attention heads.
    • num_key_value_heads – Number of key/value heads.
    • hidden_size – The dimension of the hidden states.
    • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
    • layer_idx – The layer number associated with this Attention block.
    • dtype – DType of the weights.
    • device – Device on which to place the weights and run the computation.
    • linear_cls – Linear class to use for the output dense layer.
    • stacked_qkv – Whether the Q, K, and V weights are stacked together in a single tensor.
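When num_key_value_heads is smaller than num_attention_heads, the head-count parameters describe grouped-query attention (GQA). The sketch below shows the shape bookkeeping under common assumptions that are not confirmed by this page: head_dim = hidden_size // num_attention_heads, weights stored in (out_features, in_features) layout, and query heads grouped contiguously onto KV heads. `attention_shapes` is a hypothetical helper, not part of MAX.

```python
from typing import Dict


def attention_shapes(num_attention_heads: int, num_key_value_heads: int, hidden_size: int) -> Dict[str, object]:
    # Assumption: every head has the same width, hidden_size / num_attention_heads.
    if hidden_size % num_attention_heads:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if num_attention_heads % num_key_value_heads:
        raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
    head_dim = hidden_size // num_attention_heads
    group = num_attention_heads // num_key_value_heads  # query heads sharing one KV head
    return {
        "head_dim": head_dim,
        # Assumption: (out_features, in_features) weight layout.
        "q_proj": (num_attention_heads * head_dim, hidden_size),
        "k_proj": (num_key_value_heads * head_dim, hidden_size),
        "v_proj": (num_key_value_heads * head_dim, hidden_size),
        # Query head h reads keys and values from KV head h // group.
        "kv_head_for_q_head": [h // group for h in range(num_attention_heads)],
    }


shapes = attention_shapes(num_attention_heads=8, num_key_value_heads=2, hidden_size=512)
print(shapes["head_dim"])            # 64
print(shapes["k_proj"])              # (128, 512)
print(shapes["kv_head_for_q_head"])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

With fewer KV heads, the K and V projections (and the KV cache) shrink proportionally, which is the usual reason to choose GQA.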

rope

rope: OptimizedRotaryEmbedding

wqkv

property wqkv: TensorValue

The concatenation of the Q, K, and V weight tensors.
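As a rough illustration of this concatenation, the sketch below stacks separate Q, K, and V weight matrices along the output dimension and then splits the fused matrix back apart. The matrices are plain lists of rows in an assumed (out_features, in_features) layout, and `fuse_qkv`, `split_qkv`, and `matvec` are hypothetical helpers; the actual wqkv property returns a graph TensorValue.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def fuse_qkv(wq: Matrix, wk: Matrix, wv: Matrix) -> Matrix:
    # Concatenate along the output (row) dimension: Q rows, then K rows, then V rows.
    return wq + wk + wv


def split_qkv(wqkv: Matrix, q_rows: int, kv_rows: int) -> Tuple[Matrix, Matrix, Matrix]:
    # Undo fuse_qkv, given how many rows belong to Q and to each of K and V.
    assert len(wqkv) == q_rows + 2 * kv_rows
    return wqkv[:q_rows], wqkv[q_rows:q_rows + kv_rows], wqkv[q_rows + kv_rows:]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    return [sum(a * b for a, b in zip(row, x)) for row in w]


wq = [[1.0, 0.0], [0.0, 1.0]]
wk = [[2.0, 0.0]]
wv = [[0.0, 3.0]]
wqkv = fuse_qkv(wq, wk, wv)
x = [1.0, 1.0]
# One fused matmul yields the q, k, and v outputs back to back.
print(matvec(wqkv, x))  # [1.0, 1.0, 2.0, 3.0]
q, k, v = split_qkv(wqkv, q_rows=2, kv_rows=1)
print(k == wk and v == wv)  # True
```

Fusing lets a single matrix multiplication produce all three projections, which is the usual motivation for a stacked layout.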

DistributedAttentionWithRope

class max.pipelines.nn.attention.attention_with_rope.DistributedAttentionWithRope(list_of_attentions: List[AttentionWithRope], devices: list[DeviceRef])

devices

devices: list[max.graph.type.DeviceRef]

list_of_attentions

list_of_attentions: List[AttentionWithRope]
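DistributedAttentionWithRope holds one attention module per device. One common way to build such a list, assumed here rather than taken from MAX, is tensor parallelism over heads: each device owns a contiguous slice of query heads and the matching KV heads. In that scheme, each device's slice of the output projection typically yields a partial result, and the partials are summed across devices. `shard_heads` is a hypothetical helper.

```python
from typing import List, Tuple


def shard_heads(n_heads: int, n_kv_heads: int, num_devices: int) -> List[Tuple[range, range]]:
    """Return (query-head range, KV-head range) for each device (assumed even split)."""
    if n_heads % num_devices or n_kv_heads % num_devices:
        raise ValueError("head counts must divide evenly across devices")
    q_per, kv_per = n_heads // num_devices, n_kv_heads // num_devices
    return [
        (range(d * q_per, (d + 1) * q_per), range(d * kv_per, (d + 1) * kv_per))
        for d in range(num_devices)
    ]


shards = shard_heads(n_heads=8, n_kv_heads=4, num_devices=2)
for dev, (q_heads, kv_heads) in enumerate(shards):
    print(dev, list(q_heads), list(kv_heads))
# 0 [0, 1, 2, 3] [0, 1]
# 1 [4, 5, 6, 7] [2, 3]
```

With contiguous grouping, query heads 0 to 3 use KV heads 0 and 1, so every query head finds its KV head on the same device.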

GPTQAttentionWithRope

class max.pipelines.nn.attention.attention_with_rope.GPTQAttentionWithRope(quantization_config: QuantizationConfig, rope: OptimizedRotaryEmbedding, num_attention_heads: int, num_key_value_heads: int, hidden_size: int, kv_params: KVCacheParams, layer_idx: int, dtype: DType = DType.float32, device: DeviceRef | None = None, linear_cls: Callable[..., LinearV2] = LinearV2)

Implementation of the attention layer with RoPE for GPTQ-quantized weights.

Initializes the attention layer.

  • Parameters:

    • quantization_config – The GPTQ quantization configuration for the layer's weights.
    • rope – The rope layer to borrow the freq_cis value from.
    • num_attention_heads – The number of attention heads.
    • num_key_value_heads – Number of key/value heads.
    • hidden_size – The dimension of the hidden states.
    • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
    • layer_idx – The layer number associated with this Attention block.
    • dtype – DType of the weights.
    • device – Device on which to place the weights and run the computation.
    • linear_cls – Linear class to use for the output dense layer.
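GPTQ stores weights as low-bit integers with per-group scales and zero points. The sketch below dequantizes one packed 32-bit word that holds eight 4-bit values, using w = scale * (q - zero). It assumes lowest-nibble-first packing and a single quantization group; real checkpoints vary in packing order, group size, and zero-point storage, and none of this is taken from the MAX implementation. `unpack_int4` and `dequantize_group` are hypothetical helpers.

```python
from typing import List


def unpack_int4(word: int) -> List[int]:
    """Split a 32-bit word into eight unsigned 4-bit values, lowest nibble first (assumed order)."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def dequantize_group(word: int, scale: float, zero: int) -> List[float]:
    # Asymmetric dequantization: w = scale * (q - zero).
    return [scale * (q - zero) for q in unpack_int4(word)]


packed = 0x76543210  # nibbles 0, 1, ..., 7 from lowest to highest
print(unpack_int4(packed))  # [0, 1, 2, 3, 4, 5, 6, 7]
print(dequantize_group(packed, scale=0.5, zero=8))
# [-4.0, -3.5, -3.0, -2.5, -2.0, -1.5, -1.0, -0.5]
```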

wqkv

property wqkv: TensorValue

The concatenation of the Q, K, and V weight tensors.

distribute_value()

max.pipelines.nn.attention.attention_with_rope.distribute_value(v: TensorValue, devices: List[DeviceRef]) → List[TensorValue]
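The signature indicates that distribute_value takes one value and returns a list of values, one per device. A plausible reading, assumed here and not confirmed by this page, is that it places a copy of v on each device in order. The pure-Python sketch below models that behavior with a hypothetical `FakeTensor` type whose `to()` method records the target device; `distribute_value_sketch` is likewise hypothetical and uses no MAX types.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a graph value: data plus a device label."""
    data: tuple
    device: str

    def to(self, device: str) -> "FakeTensor":
        # Model a device transfer: same data, new placement.
        return FakeTensor(self.data, device)


def distribute_value_sketch(v: FakeTensor, devices: List[str]) -> List[FakeTensor]:
    # Assumed behavior: one copy of v per device, in the same order as `devices`.
    return [v.to(d) for d in devices]


copies = distribute_value_sketch(FakeTensor((1, 2, 3), "cpu:0"), ["gpu:0", "gpu:1"])
print([c.device for c in copies])  # ['gpu:0', 'gpu:1']
```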