Python module

cache_params

KVCacheParams

class max.nn.kv_cache.cache_params.KVCacheParams(dtype, n_kv_heads, head_dim, enable_prefix_caching=False, enable_kvcache_swapping_to_host=False, host_kvcache_swap_space_gb=None, cache_strategy=KVCacheStrategy.PAGED, page_size=None, n_devices=1, is_mla=False, data_parallel_degree=1, n_kv_heads_per_device=0)

Configuration parameters for key-value cache management in transformer models.

This class encapsulates all configuration options for managing KV caches during inference, including parallelism settings, memory management, and cache strategy.

Parameters:

  • dtype (DType)
  • n_kv_heads (int)
  • head_dim (int)
  • enable_prefix_caching (bool)
  • enable_kvcache_swapping_to_host (bool)
  • host_kvcache_swap_space_gb (float | None)
  • cache_strategy (KVCacheStrategy)
  • page_size (int | None)
  • n_devices (int)
  • is_mla (bool)
  • data_parallel_degree (int)
  • n_kv_heads_per_device (int)
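
For example, a minimal paged-cache configuration might look like the following sketch. The import path for DType (max.dtype) and the illustrative head counts are assumptions; verify them against your MAX version.

from max.dtype import DType
from max.nn.kv_cache.cache_params import KVCacheParams, KVCacheStrategy

# Paged KV cache for a model with 8 KV heads of dimension 128 (illustrative).
params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    cache_strategy=KVCacheStrategy.PAGED,
    page_size=128,  # required for the paged strategy
    n_devices=1,
)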

cache_strategy

cache_strategy: KVCacheStrategy = 'paged'

Strategy to use for managing the KV cache.

copy_as_dp_1()

copy_as_dp_1()

Creates a copy of the KVCacheParams with data parallelism disabled.

This method copies the current configuration and adjusts it to a tensor-parallel-only setup: data_parallel_degree is set to 1 and n_devices is divided by the current data parallel degree.

Returns:

A new KVCacheParams instance with data_parallel_degree set to 1.

Raises:

ValueError – If n_devices is not evenly divisible by data_parallel_degree.

Return type:

KVCacheParams
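
A minimal sketch of collapsing a data-parallel layout, assuming four devices in a pure data-parallel setup (data_parallel_degree must be 1 or equal to n_devices, per the constraint below):

# Pure data parallelism across four devices (illustrative values).
dp_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    page_size=128,
    n_devices=4,
    data_parallel_degree=4,
)

# Tensor-parallel-only copy: n_devices is divided by the DP degree (4 / 4 = 1).
tp_params = dp_params.copy_as_dp_1()
assert tp_params.data_parallel_degree == 1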

data_parallel_degree

data_parallel_degree: int = 1

Degree of data parallelism. Must be 1 or equal to n_devices (DP+TP not yet supported).

dtype

dtype: DType

Data type for storing key and value tensors in the cache.

dtype_shorthand

property dtype_shorthand: str

Returns a shorthand textual representation of the data type.

Returns:

“bf16” for bfloat16 dtype, “f32” otherwise.
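
A short usage sketch, reusing the params object from the construction example above:

assert params.dtype_shorthand == "bf16"  # DType.bfloat16 maps to "bf16"; any other dtype yields "f32"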

enable_kvcache_swapping_to_host

enable_kvcache_swapping_to_host: bool = False

Whether to enable swapping of KV cache blocks to host memory when device memory is full.

enable_prefix_caching

enable_prefix_caching: bool = False

Whether to enable prefix caching for efficient reuse of common prompt prefixes.

head_dim

head_dim: int

Dimensionality of each attention head.

host_kvcache_swap_space_gb

host_kvcache_swap_space_gb: float | None = None

Amount of host memory (in GB) to reserve for KV cache swapping. Required when swapping is enabled.
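
A hedged sketch of a configuration with host swapping enabled. Prefix caching is enabled alongside it here, though that pairing is an assumption rather than a documented requirement; the values are illustrative.

swap_params = KVCacheParams(
    dtype=DType.bfloat16,
    n_kv_heads=8,
    head_dim=128,
    page_size=128,
    enable_prefix_caching=True,            # paired with swapping here (assumption)
    enable_kvcache_swapping_to_host=True,
    host_kvcache_swap_space_gb=16.0,       # required when swapping is enabled
)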

is_mla

is_mla: bool = False

Whether the model uses Multi-Latent Attention (MLA) architecture.

n_devices

n_devices: int = 1

Total number of devices (GPUs/accelerators) available for inference.

n_kv_heads

n_kv_heads: int

Total number of key-value attention heads across all devices.

n_kv_heads_per_device

n_kv_heads_per_device: int = 0

Number of KV heads allocated to each device. Computed automatically in __post_init__.

page_size

page_size: int | None = None

Size of each page in the paged cache strategy. Required for paged caching.

static_cache_shape

property static_cache_shape: tuple[str, str, str, str, str]

Returns the dimension names for the static cache tensor shape.

Returns:

A tuple of dimension names: (num_layers, batch_size, seq_len, n_kv_heads, head_dim).

Return type:

tuple[str, str, str, str, str]
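
A short sketch of unpacking the symbolic shape, assuming the params object from the construction example above:

layers, batch, seq, heads, hdim = params.static_cache_shape
# Each element is a dimension name (a string), not a concrete size.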

KVCacheStrategy

class max.nn.kv_cache.cache_params.KVCacheStrategy(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Enumeration of supported KV cache strategies for attention mechanisms.

This enum defines the different strategies for managing key-value caches in transformer models during inference.

MODEL_DEFAULT

MODEL_DEFAULT = 'model_default'

Use the model’s default caching strategy.

PAGED

PAGED = 'paged'

Use paged attention for efficient memory management.

kernel_substring()

kernel_substring()

Returns the common substring included in the kernel name for this caching strategy.

Returns:

The string representation of the cache strategy value.

Return type:

str

uses_opaque()

uses_opaque()

Determines if this cache strategy uses opaque cache implementations.

Returns:

True if the strategy uses opaque caching, False otherwise.

Return type:

bool
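
A sketch of inspecting strategy metadata via the two helper methods documented above:

strategy = KVCacheStrategy.PAGED
print(strategy.kernel_substring())  # the strategy's string value, e.g. "paged"
print(strategy.uses_opaque())       # whether kernels receive an opaque cache object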
