Python module
cache_params
KVCacheParams
class max.nn.kv_cache.cache_params.KVCacheParams(dtype, n_kv_heads, head_dim, enable_prefix_caching=False, enable_kvcache_swapping_to_host=False, host_kvcache_swap_space_gb=None, cache_strategy=KVCacheStrategy.PAGED, page_size=None, n_devices=1, is_mla=False, data_parallel_degree=1, n_kv_heads_per_device=0)
Configuration parameters for key-value cache management in transformer models.
This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.
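A minimal construction sketch is shown below; the DType import path is assumed from the broader max Python API, and all parameter values are illustrative rather than recommended settings.

```python
from max.dtype import DType  # assumed import path for DType
from max.nn.kv_cache.cache_params import KVCacheParams, KVCacheStrategy

# Illustrative configuration: a paged KV cache in bfloat16 across two devices.
params = KVCacheParams(
    dtype=DType.bfloat16,                  # storage dtype for K/V tensors
    n_kv_heads=8,                          # total KV heads across all devices
    head_dim=128,                          # per-head dimensionality
    cache_strategy=KVCacheStrategy.PAGED,  # default strategy
    page_size=128,                         # required for paged caching
    enable_prefix_caching=True,
    n_devices=2,
)
```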
Parameters:
cache_strategy
cache_strategy: KVCacheStrategy = 'paged'
Strategy to use for managing the KV cache.
copy_as_dp_1()
copy_as_dp_1()
Creates a copy of the KVCacheParams with data parallelism disabled.
This method creates a new instance of the current configuration and adjusts the device count to reflect a tensor-parallel-only setup (data_parallel_degree=1). The number of devices is divided by the current data parallel degree.
Returns:
A new KVCacheParams instance with data_parallel_degree set to 1.
Raises:
ValueError – If n_devices is not evenly divisible by data_parallel_degree.
Return type:
KVCacheParams
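A hedged sketch of the behavior described above (values are illustrative, and the DType import path is assumed):

```python
from max.dtype import DType  # assumed import path
from max.nn.kv_cache.cache_params import KVCacheParams

# Illustrative: a 4-device data-parallel configuration collapsed to a
# tensor-parallel-only one. Per the description above, n_devices is divided
# by the original data_parallel_degree.
dp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    page_size=128,
    n_devices=4,
    data_parallel_degree=4,
)
tp_params = dp_params.copy_as_dp_1()
assert tp_params.data_parallel_degree == 1
assert tp_params.n_devices == 1  # 4 // 4
```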
data_parallel_degree
data_parallel_degree: int = 1
Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).
dtype
dtype: DType
Data type for storing key and value tensors in the cache.
dtype_shorthand
property dtype_shorthand: str
Returns a shorthand textual representation of the data type.
Returns:
“bf16” for bfloat16 dtype, “f32” otherwise.
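A small sketch of this mapping (illustrative values; the DType import path is assumed):

```python
from max.dtype import DType  # assumed import path
from max.nn.kv_cache.cache_params import KVCacheParams

# Per the mapping above: bfloat16 yields "bf16"; any other dtype yields "f32".
bf16 = KVCacheParams(dtype=DType.bfloat16, n_kv_heads=8, head_dim=128, page_size=128)
f32 = KVCacheParams(dtype=DType.float32, n_kv_heads=8, head_dim=128, page_size=128)
print(bf16.dtype_shorthand)  # "bf16"
print(f32.dtype_shorthand)   # "f32"
```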
enable_kvcache_swapping_to_host
enable_kvcache_swapping_to_host: bool = False
Whether to enable swapping of KV cache blocks to host memory when device memory is full.
enable_prefix_caching
enable_prefix_caching: bool = False
Whether to enable prefix caching for efficient reuse of common prompt prefixes.
head_dim
head_dim: int
Dimensionality of each attention head.
host_kvcache_swap_space_gb
host_kvcache_swap_space_gb: float | None = None
Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.
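A hedged configuration sketch that pairs the swapping flag with a host swap-space budget (the 16 GB figure is purely illustrative, and the DType import path is assumed):

```python
from max.dtype import DType  # assumed import path
from max.nn.kv_cache.cache_params import KVCacheParams

# Illustrative: allow KV cache blocks to spill to 16 GB of reserved host
# memory when device memory fills up. host_kvcache_swap_space_gb must be
# set whenever enable_kvcache_swapping_to_host is True.
params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    page_size=128,
    enable_prefix_caching=True,
    enable_kvcache_swapping_to_host=True,
    host_kvcache_swap_space_gb=16.0,
)
```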
is_mla
is_mla: bool = False
Whether the model uses Multi-Latent Attention (MLA) architecture.
n_devices
n_devices: int = 1
Total number of devices (GPUs/accelerators) available for inference.
n_kv_heads
n_kv_heads: int
Total number of key-value attention heads across all devices.
n_kv_heads_per_device
n_kv_heads_per_device: int = 0
Number of KV heads allocated to each device. Computed automatically in __post_init__.
page_size
page_size: int | None = None
Size of each page in the paged cache strategy. Required for paged caching.
static_cache_shape
Returns the dimension names for the static cache tensor shape.
Returns:
(num_layers, batch_size, seq_len, n_kv_heads, head_dim).
Return type:
A tuple of dimension names.
KVCacheStrategy
class max.nn.kv_cache.cache_params.KVCacheStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of supported KV cache strategies for attention mechanisms.
This enum defines the different strategies for managing key-value caches in transformer models during inference.
MODEL_DEFAULT
MODEL_DEFAULT = 'model_default'
Use the model’s default caching strategy.
PAGED
PAGED = 'paged'
Use paged attention for efficient memory management.
kernel_substring()
kernel_substring()
Returns the common substring included in the kernel name for this caching strategy.
Returns:
The string representation of the cache strategy value.
Return type:
str
uses_opaque()
uses_opaque()
Determines if this cache strategy uses opaque cache implementations.
Returns:
True if the strategy uses opaque caching, False otherwise.
Return type:
bool
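A brief sketch of the two strategy helpers (the printed values are not asserted here, only described in the comments):

```python
from max.nn.kv_cache.cache_params import KVCacheStrategy

strategy = KVCacheStrategy.PAGED
print(strategy.kernel_substring())  # the strategy's string value, e.g. "paged"
print(strategy.uses_opaque())       # whether this strategy uses opaque caches
```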