Python module
attention.multi_latent_attention
An opaque, KV-cache-optimized attention mechanism with rotary position embeddings (Rope).
DistributedLatentAttentionWithRope
class max.nn.attention.multi_latent_attention.DistributedLatentAttentionWithRope(**kwargs)
Distributed implementation of Latent Attention with Rope. Note that using tensor parallelism for multi-head latent attention (MLA) duplicates the KV cache on every device, which wastes memory.
Initializes the latent attention layer.
Parameters:
- rope – The rope layer to borrow the freqs_cis value from.
- num_attention_heads – The number of attention heads.
- num_key_value_heads – Number of key/value heads.
- hidden_size – The dimension of the hidden states.
- kv_params – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype – DType of the weights. Currently only bfloat16 is supported.
- devices – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls – Linear class to use for the output dense layer.
- scale – Value used to scale the results of the attention output.
- q_lora_rank – Optional LoRA rank for Q projection.
- kv_lora_rank – LoRA rank for KV projections.
- qk_nope_head_dim – Per-head dimension of the query/key part that carries no positional encoding.
- qk_rope_head_dim – Per-head dimension of the query/key part that rope is applied to.
- v_head_dim – Per-head dimension of the values.
- buffer_size – Size of the buffer that stores temporary results during prefill, in units of tokens.
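To make the cost of KV-cache duplication concrete, the sketch below estimates cache memory when every device keeps a full copy. It assumes the usual MLA layout, in which each token caches kv_lora_rank + qk_rope_head_dim elements per layer. The function name, the layer count, and the token count are hypothetical illustration values and are not part of the MAX API.

```python
def mla_cache_bytes(num_tokens, num_layers, kv_lora_rank=512,
                    qk_rope_head_dim=64, bytes_per_elem=2, num_devices=1):
    """Estimate total MLA cache bytes when each device holds a full copy."""
    # Assumption: one compressed latent vector plus one rope key per token,
    # shared by all attention heads.
    per_token_elems = kv_lora_rank + qk_rope_head_dim
    per_device = num_tokens * num_layers * per_token_elems * bytes_per_elem
    # Tensor parallelism replicates the cache instead of splitting it.
    return per_device * num_devices


# Hypothetical workload: 16384 cached tokens, 60 layers, bfloat16 (2 bytes).
single = mla_cache_bytes(num_tokens=16384, num_layers=60)
replicated = mla_cache_bytes(num_tokens=16384, num_layers=60, num_devices=4)
print(f"1 device: {single / 2**30:.2f} GiB, 4 devices: {replicated / 2**30:.2f} GiB")
```

Under these assumptions, total cache memory grows linearly with the number of devices, whereas a cache sharded by head would stay roughly constant.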
LatentAttentionWithRope
class max.nn.attention.multi_latent_attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384)
Implementation of Latent Attention with Rope.
Initializes the latent attention layer.
Parameters:
- rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV Cache Params, including the number of kv heads, the head dim, and data type.
- dtype (DType) – DType of the weights. Currently only bfloat16 is supported.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- scale (float | None) – Value used to scale the results of the attention output.
- q_lora_rank (int | None) – Optional LoRA rank for Q projection.
- kv_lora_rank (int) – LoRA rank for KV projections.
- qk_nope_head_dim (int) – Per-head dimension of the query/key part that carries no positional encoding.
- qk_rope_head_dim (int) – Per-head dimension of the query/key part that rope is applied to.
- v_head_dim (int) – Per-head dimension of the values.
- buffer_size (int) – Size of the buffer that stores temporary results during prefill, in units of tokens.
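The head-dimension parameters above combine as in the following sketch, which uses the default values from the signature. It is plain Python rather than MAX code: the helper name is hypothetical, and the fallback of 1/sqrt(qk_head_dim) when scale is None is an assumption based on common MLA practice, not something this page confirms.

```python
import math


def mla_head_dims(qk_nope_head_dim=128, qk_rope_head_dim=64,
                  v_head_dim=128, scale=None):
    """Derive per-head query/key size and an attention scale (illustrative)."""
    # Each query/key head is the non-positional part followed by the rope part.
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    if scale is None:
        # Assumed default: standard scaled dot-product factor.
        scale = 1.0 / math.sqrt(qk_head_dim)
    return {"qk_head_dim": qk_head_dim, "v_head_dim": v_head_dim, "scale": scale}


dims = mla_head_dims()

# Split one (dummy) query head into its two parts.
q_head = list(range(dims["qk_head_dim"]))
q_nope, q_rope = q_head[:128], q_head[128:]
print(dims["qk_head_dim"], len(q_nope), len(q_rope))
```

Note that the query/key head size (nope plus rope) can differ from v_head_dim, so the attention output per head has the value dimension, not the query/key dimension.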
rope
rope: RotaryEmbedding
shard()
shard(devices)
Creates sharded views of this Module across multiple devices.
sharding_strategy
property sharding_strategy: ShardingStrategy | None
Get the Module sharding strategy.
w_uk_uv
property w_uk_uv: list[TensorValue]
The concatenation of q, k, and v weight vectors.