Python module

attention.multi_latent_attention

An opaque, KV-cache-optimized attention mechanism with RoPE.

DistributedLatentAttentionWithRope

class max.nn.attention.multi_latent_attention.DistributedLatentAttentionWithRope(**kwargs)

Distributed implementation of Latent Attention with RoPE. Note that using tensor parallelism for MLA causes the KV cache to be duplicated across devices, which is not efficient.

Initializes the latent attention layer.

Parameters:

  • rope – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads – The number of attention heads.
  • num_key_value_heads – Number of key/value heads.
  • hidden_size – The dimension of the hidden states.
  • kv_params – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype – DType of the weights; currently only bfloat16 is supported.
  • devices – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • linear_cls – Linear class to use for the output dense layer.
  • scale – Value used to scale the results of the attention output.
  • q_lora_rank – Optional LoRA rank for Q projection.
  • kv_lora_rank – LoRA rank for KV projections.
  • qk_nope_head_dim – Head dimension for the non-positional-encoding part of the query/key.
  • qk_rope_head_dim – Head dimension for the RoPE part of the query/key.
  • v_head_dim – Head dimension for the value projection.
  • buffer_size – Buffer size for storing temporary results during prefill, in units of tokens.
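
A minimal construction sketch for the distributed variant, assuming two GPUs and DeepSeek-V2-style dimensions. The RotaryEmbedding and KVCacheParams arguments (and their import paths) are illustrative assumptions and may not match the exact constructor signatures; the keyword arguments mirror the parameter list above.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.multi_latent_attention import DistributedLatentAttentionWithRope
from max.nn.kv_cache import KVCacheParams            # import path assumed
from max.nn.rotary_embedding import RotaryEmbedding  # import path assumed

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# Illustrative helper objects; the exact constructor arguments may differ.
rope = RotaryEmbedding(dim=64, n_heads=128, theta=10000.0, max_seq_len=4096)
kv_params = KVCacheParams(dtype=DType.bfloat16, n_kv_heads=1, head_dim=576)

# DeepSeek-V2-style hyperparameters (illustrative). Remember that tensor
# parallelism duplicates the KV cache on every device listed above.
attn = DistributedLatentAttentionWithRope(
    rope=rope,
    num_attention_heads=128,
    num_key_value_heads=128,
    hidden_size=5120,
    kv_params=kv_params,
    dtype=DType.bfloat16,
    devices=devices,
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
)
```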

LatentAttentionWithRope

class max.nn.attention.multi_latent_attention.LatentAttentionWithRope(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, dtype, devices=None, linear_cls=<class 'max.nn.linear.Linear'>, scale=None, q_lora_rank=None, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384)

Implementation of Latent Attention with RoPE.

Initializes the latent attention layer.

Parameters:

  • rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
  • num_attention_heads (int) – The number of attention heads.
  • num_key_value_heads (int) – Number of key/value heads.
  • hidden_size (int) – The dimension of the hidden states.
  • kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
  • dtype (DType) – DType of the weights; currently only bfloat16 is supported.
  • devices (list[DeviceRef] | None) – Device(s) on which to place the weights and run the computation. If multiple are provided, the first device is used.
  • linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
  • scale (float | None) – Value used to scale the results of the attention output.
  • q_lora_rank (int | None) – Optional LoRA rank for Q projection.
  • kv_lora_rank (int) – LoRA rank for KV projections.
  • qk_nope_head_dim (int) – Head dimension for the non-positional-encoding part of the query/key.
  • qk_rope_head_dim (int) – Head dimension for the RoPE part of the query/key.
  • v_head_dim (int) – Head dimension for the value projection.
  • buffer_size (int) – Buffer size for storing temporary results during prefill, in units of tokens.
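
A single-device construction sketch relying on the documented defaults (kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, buffer_size=16384). As above, the RotaryEmbedding and KVCacheParams arguments and import paths are assumptions, not verified signatures.

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.attention.multi_latent_attention import LatentAttentionWithRope
from max.nn.kv_cache import KVCacheParams            # import path assumed
from max.nn.rotary_embedding import RotaryEmbedding  # import path assumed

# Illustrative helper objects; the exact constructor arguments may differ.
rope = RotaryEmbedding(dim=64, n_heads=16, theta=10000.0, max_seq_len=4096)
kv_params = KVCacheParams(dtype=DType.bfloat16, n_kv_heads=1, head_dim=576)

# Only the required keyword arguments are passed; the LoRA ranks and head
# dimensions fall back to the defaults shown in the signature above.
attn = LatentAttentionWithRope(
    rope=rope,
    num_attention_heads=16,
    num_key_value_heads=16,
    hidden_size=2048,
    kv_params=kv_params,
    dtype=DType.bfloat16,
    devices=[DeviceRef.GPU(0)],
)
```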

rope

rope: RotaryEmbedding

shard()

shard(devices)

Creates sharded views of this Module across multiple devices.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded LatentAttentionWithRope instances, one for each device.

Return type:

list[LatentAttentionWithRope]
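
A usage sketch for shard(), continuing from a LatentAttentionWithRope instance named attn as built in the construction example above. Whether a sharding strategy must be assigned before calling shard(), and where ShardingStrategy would be imported from, is not specified on this page, so that step appears only as a hedged comment.

```python
from max.graph import DeviceRef

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# `attn` is the LatentAttentionWithRope built in the earlier sketch.
# Depending on the module, a sharding strategy may need to be assigned first,
# e.g. attn.sharding_strategy = <some ShardingStrategy>  (assumption).
shards = attn.shard(devices)

# One sharded LatentAttentionWithRope view per device, as documented above.
assert len(shards) == len(devices)

# The module's current strategy (or None) is exposed by the property below.
print(attn.sharding_strategy)
```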

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the Module sharding strategy.

w_uk_uv

property w_uk_uv: list[TensorValue]

The concatenation of the k and v up-projection weights (w_uk and w_uv).
