Python module
kernels
Helper functions for wrapping custom kv cache/attention related ops.
AttentionMaskVariant
class max.nn.kernels.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_MASK
CAUSAL_MASK = 'causal_mask'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = 'chunked_causal_mask'
NULL_MASK
NULL_MASK = 'null_mask'
SLIDING_WINDOW_MASK
SLIDING_WINDOW_MASK = 'sliding_window_mask'
TENSOR_MASK
TENSOR_MASK = 'tensor_mask'
MHAMaskConfig
class max.nn.kernels.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')
attention_mask_variant
attention_mask_variant: AttentionMaskVariant
positional_encoding_variant
positional_encoding_variant: PositionalEncodingVariant
MHAMaskVariant
class max.nn.kernels.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_ALIBI_MASK
CAUSAL_ALIBI_MASK = '1'
CAUSAL_MASK
CAUSAL_MASK = '0'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = '3'
NULL_MASK
NULL_MASK = '2'
SLIDING_WINDOW_MASK
SLIDING_WINDOW_MASK = '4'
PositionalEncodingVariant
class max.nn.kernels.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALIBI_POS
ALIBI_POS = 'alibi_pos'
NO_POS
NO_POS = 'no_pos'
cross_attention_ragged()
max.nn.kernels.cross_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, kv_input_row_offsets: TensorValue, q_max_seq_len: TensorValue, scale: float) → TensorValue
Computes cross attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Unlike self attention, kv_input_row_offsets represents the KV sequence length.
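As an illustration of the ragged layout, here is a small NumPy sketch (the variable names are hypothetical and this is not the kernel itself):

```python
import numpy as np

# Three sequences of lengths 2, 3, and 1 packed along one ragged token axis.
seq_lens = np.array([2, 3, 1])
input_row_offsets = np.concatenate(([0], np.cumsum(seq_lens)))  # [0, 2, 5, 6]

# Tokens of batch element i live in input[input_row_offsets[i]:input_row_offsets[i + 1]].
total_tokens = int(input_row_offsets[-1])                    # 6
batch_1 = slice(input_row_offsets[1], input_row_offsets[2])  # rows 2..4 belong to sequence 1
print(input_row_offsets, total_tokens, batch_1)
```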
flare_mla_decode_ragged()
max.nn.kernels.flare_mla_decode_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, qk_rope_dim: int = 64) → TensorValue
Computes flash (self) attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
flare_mla_decompress_k_cache()
max.nn.kernels.flare_mla_decompress_k_cache(kv_params: KVCacheParams, buffer_row_offsets_1d: TensorValue, cache_offsets_1d: TensorValue, buffer_length: TensorValue, weight: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, buffer_size: int) → TensorValue
This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix.
The process involves:
1. Copying buffer_length latent vectors from the key cache into a contiguous buffer (k_latent)
2. Computing k = k_latent @ weight.T to obtain the decompressed keys
Returns:
A tensor of shape [buffer_size, weight.shape[0]] containing the decompressed keys. Note that only the first buffer_length tokens are valid.
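A rough NumPy sketch of the shapes involved (sizes are hypothetical; this only illustrates the up-projection described above, not the kernel itself):

```python
import numpy as np

# Hypothetical sizes for illustration only.
buffer_size, buffer_length, latent_dim, kv_dim = 8, 5, 64, 128

k_latent = np.zeros((buffer_size, latent_dim), dtype=np.float32)  # latent vectors copied from the key cache
weight = np.zeros((kv_dim, latent_dim), dtype=np.float32)         # up-projection weight

k = k_latent @ weight.T      # shape [buffer_size, weight.shape[0]] == (8, 128)
valid_k = k[:buffer_length]  # only the first buffer_length rows hold valid keys
print(k.shape, valid_k.shape)
```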
flare_mla_prefill_plan()
max.nn.kernels.flare_mla_prefill_plan(kv_params: KVCacheParams, input_row_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, buffer_size: int, max_chunks: int = 16) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue]
This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer.
Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer_size.
For each chunk (iteration), it calculates:
1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration
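For example, assuming the chunking is over the combined token count of the batch, the number of planning iterations could be estimated as in this hypothetical sketch:

```python
import math

# Hypothetical batch: cached + new tokens per sequence, and a fixed-size buffer.
tokens_per_seq = [100, 50]
buffer_size = 64

total_tokens = sum(tokens_per_seq)                      # 150
num_iterations = math.ceil(total_tokens / buffer_size)  # 3 chunks of at most 64 tokens each
print(total_tokens, num_iterations)
```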
flare_mla_prefill_ragged()
max.nn.kernels.flare_mla_prefill_ragged(kv_params: KVCacheParams, input: TensorValue, k: TensorValue, v: TensorValue, input_row_offsets: TensorValue, buffer_row_offsets: TensorValue, cache_offsets: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, qk_rope_dim: int = 64) → TensorValue
Performs MLA prefill.
flash_attention()
max.nn.kernels.flash_attention(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, attention_mask: TensorValue, valid_lengths: TensorValue, scale: float) → TensorValue
Computes flash attention provided the mo.opaque KV Cache.
flash_attention_ragged()
max.nn.kernels.flash_attention_ragged(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection | PagedKVCacheCollectionFA3Fallback, layer_idx: TensorValue, mask_variant: MHAMaskVariant, scale: float, context_lengths: TensorValue | None = None, local_window_size: int = 8192) → TensorValue
Computes flash (self) attention provided the !mo.opaque KV Cache.
Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel. input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross_attention_ragged.
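A minimal call sketch, assuming kv_params, x, row_offsets, kv_collection, and layer_idx are values that have already been constructed in the graph (they are not defined here):

```python
from max.nn.kernels import MHAMaskVariant, flash_attention_ragged

head_dim = 128  # hypothetical

attn_out = flash_attention_ragged(
    kv_params=kv_params,            # KVCacheParams for the model
    input=x,                        # ragged activations; assumed [total_tokens, n_heads, head_dim]
    input_row_offsets=row_offsets,  # batch start/end offsets into `input`
    kv_collection=kv_collection,
    layer_idx=layer_idx,
    mask_variant=MHAMaskVariant.CAUSAL_MASK,
    scale=head_dim**-0.5,
)
```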
flash_attention_ragged_paged_fa3_fallback()
max.nn.kernels.flash_attention_ragged_paged_fa3_fallback(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: PagedKVCacheCollectionFA3Fallback, context_lengths: TensorValue, layer_idx: TensorValue) → TensorValue
Computes flash attention provided the !mo.opaque KV Cache, using the FA3 fallback kernel.
flash_attention_with_causal_mask()
max.nn.kernels.flash_attention_with_causal_mask(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, valid_lengths: TensorValue, scale: float) → TensorValue
Computes flash attention provided the mo.opaque KV Cache. Notably, it materializes the causal mask within the kernel.
fused_qk_ragged_rope()
max.nn.kernels.fused_qk_ragged_rope(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, freqs_cis: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Computes fused query-key attention with rotary positional encodings and ragged inputs.
Parameters:
- input – [batch_size * seq_len, n_heads, head_dim]
- input_row_offsets –
- freqs_cis – tensor of shape (max_seq_len * 2, head_dim)
- layer_idx –
- interleaved –
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
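A minimal call sketch, assuming the values below (kv_params, x, row_offsets, kv_collection, freqs_cis, layer_idx) already exist in the graph:

```python
from max.nn.kernels import fused_qk_ragged_rope

q_roped = fused_qk_ragged_rope(
    kv_params=kv_params,
    input=x,                        # [batch_size * seq_len, n_heads, head_dim]
    input_row_offsets=row_offsets,  # batch start/end offsets into `input`
    kv_collection=kv_collection,
    freqs_cis=freqs_cis,            # (max_seq_len * 2, head_dim)
    layer_idx=layer_idx,
    interleaved=True,
)
```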
fused_qk_rope()
max.nn.kernels.fused_qk_rope(kv_params: KVCacheParams, input: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, freqs_cis_2d: TensorValue, layer_idx: TensorValue, interleaved: bool = True) → TensorValue
Computes fused query-key attention with rotary positional encodings.
fused_qkv_matmul()
max.nn.kernels.fused_qkv_matmul(kv_params: KVCacheParams, input: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: TensorValue, n_heads: int) → TensorValue
Computes fused query, key and value projections.
fused_qkv_ragged_matmul()
max.nn.kernels.fused_qkv_ragged_matmul(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input.
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
fused_qkv_ragged_matmul_quantized()
max.nn.kernels.fused_qkv_ragged_matmul_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, wqkv: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue, n_heads: int, quantization_config: QuantizationConfig, perm_idx: TensorValue | None = None, bias: TensorValue | None = None) → TensorValue
Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization_config must be provided.
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.
grouped_matmul_ragged()
max.nn.kernels.grouped_matmul_ragged(hidden_states: TensorValue, weight: TensorValue, expert_start_indices: TensorValue, expert_ids: TensorValue, expert_usage_stats_host: TensorValue) → TensorValue
Grouped matmul used in MoE layer.
hidden_states and expert_start_indices are used together to implement the ragged tensor. expert_start_indices indicates where each group starts and ends in hidden_states.
expert_ids contains the ID of the expert for each group in hidden_states.
expert_usage_stats_host is the maximum number of tokens assigned to any expert, and the number of active experts.
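A rough NumPy illustration of the grouping semantics described above (not the kernel itself; the weight layout and the transpose are assumptions):

```python
import numpy as np

hidden_states = np.random.rand(6, 16).astype(np.float32)  # 6 tokens, hidden size 16
weight = np.random.rand(4, 32, 16).astype(np.float32)     # assumed [num_experts, out_dim, in_dim]

expert_start_indices = np.array([0, 2, 5, 6])  # groups: rows [0:2], [2:5], [5:6]
expert_ids = np.array([0, 2, 3])               # expert assigned to each group

out = np.empty((6, 32), dtype=np.float32)
for g, expert in enumerate(expert_ids):
    start, end = expert_start_indices[g], expert_start_indices[g + 1]
    out[start:end] = hidden_states[start:end] @ weight[expert].T
```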
kv_cache_get_max_seq_len()
max.nn.kernels.kv_cache_get_max_seq_len(kv_collection: PagedKVCacheCollection) → TensorValue
This kernel returns the maximum sequence length.
matmul_k_cache_ragged()
max.nn.kernels.matmul_k_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: PagedKVCacheCollection, layer_idx: int | integer) → None
Computes key projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
matmul_kv_cache_ragged()
max.nn.kernels.matmul_kv_cache_ragged(kv_params: KVCacheParams, hidden_states: TensorValue, input_row_offsets: TensorValue, weight: TensorValue, kv_collection: ContinuousBatchingKVCacheCollection, layer_idx: int | integer) → None
Computes key and value projections with ragged input.
hidden_states and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
moe_create_indices()
max.nn.kernels.moe_create_indices(topk_ids: TensorValue, num_local_experts: int) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue, max.graph.value.TensorValue]
Creates indices for the MoE layer.
Parameters:
- topk_ids – The expert assignments for each token from the router.
- num_local_experts – The number of experts on this device.
Returns:
- token_expert_order: The reordered token indices, grouped by assigned expert.
- expert_start_indices: The starting index for each expert’s token group in the reordered sequence.
- restore_token_order: The indices to restore the original token ordering after expert computation.
- expert_ids: The IDs of the active experts selected for the tokens.
- expert_usage_stats: The maximum number of tokens assigned to any expert, and the number of active experts.
Return type:
A tuple of five tensors
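A rough NumPy illustration of the returned values for a toy top-1 routing (this only mimics the documented semantics, not the kernel itself):

```python
import numpy as np

topk_ids = np.array([2, 0, 2, 1])  # expert assignment for each of 4 tokens
num_local_experts = 4

token_expert_order = np.argsort(topk_ids, kind="stable")        # [1, 3, 0, 2]
sorted_ids = topk_ids[token_expert_order]                        # [0, 1, 2, 2]
expert_ids, counts = np.unique(sorted_ids, return_counts=True)   # active experts, tokens per expert
expert_start_indices = np.concatenate(([0], np.cumsum(counts)))  # [0, 1, 2, 4]
restore_token_order = np.argsort(token_expert_order)             # undoes the reordering
expert_usage_stats = (counts.max(), len(expert_ids))             # (2 tokens max, 3 active experts)
```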
rms_norm_key_cache()
max.nn.kernels.rms_norm_key_cache(kv_params: KVCacheParams, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, gamma: TensorValue, epsilon: float | floating, layer_idx: int | integer, total_seq_len: Dim, input_row_offsets: TensorValue, rms_norm_cols: int | None = None) → None
Computes RMSNorm on the _new_ entries in the KVCache.
This function applies RMSNorm to either all dimensions or a subset of dimensions in each head of the key cache. The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match head_dim, rms_norm_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions.
Currently, the KVCacheT class itself isn’t aware of the new cache entries until the cache length is incremented, which happens after the model forward pass, so input_row_offsets is used to do this bookkeeping.
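A minimal call sketch, assuming kv_params, kv_collection, gamma, layer_idx, total_seq_len, and row_offsets already exist (none are defined here):

```python
from max.nn.kernels import rms_norm_key_cache

rms_norm_key_cache(
    kv_params=kv_params,
    kv_collection=kv_collection,
    gamma=gamma,                    # size head_dim, or fewer columns together with rms_norm_cols
    epsilon=1e-6,
    layer_idx=layer_idx,
    total_seq_len=total_seq_len,
    input_row_offsets=row_offsets,  # marks the new cache entries to normalize
    rms_norm_cols=None,             # set explicitly when gamma covers only a subset of head_dim
)
```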
swish_glu()
max.nn.kernels.swish_glu(a: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b0: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, b1: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
unfused_qkv_ragged_matmul_gguf_quantized()
max.nn.kernels.unfused_qkv_ragged_matmul_gguf_quantized(kv_params: KVCacheParams, input: TensorValue, input_row_offsets: TensorValue, n_heads: int, q_weight: TensorValue, k_weight: TensorValue, v_weight: TensorValue, quantization_encoding_q: QuantizationEncoding, quantization_encoding_k: QuantizationEncoding, quantization_encoding_v: QuantizationEncoding, kv_collection: ContinuousBatchingKVCacheCollection | PagedKVCacheCollection, layer_idx: TensorValue) → TensorValue
Computes query, key, and value projections with ragged input and GGUF-quantized weight matrices. A quantization encoding must be provided for each of the query, key, and value weights.
input and input_row_offsets are used together to implement the ragged tensor. input_row_offsets indicates where each batch starts and ends in input.
Raises:
ValueError – on input shapes/dtypes that are invalid for the kernel.