
Python module

manager

Abstract base class for KV cache managers (KVCacheManager).

KVCacheInputSymbols

class max.pipelines.kv_cache.manager.KVCacheInputSymbols

Base class for input symbols for KV cache managers.

The derived class is responsible for defining the input symbols for the specific KV cache manager.

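For example, here’s a derived class for a text KV cache manager:

```pycon
>>> from dataclasses import dataclass
>>> from max.graph import TensorType
>>> from max.pipelines.kv_cache.manager import KVCacheInputSymbols
>>> @dataclass
... class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
...     kv_blocks: TensorType
...     cache_lengths: TensorType
...     lookup_table: TensorType
...     max_lengths: TensorType
```
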

KVCacheInputs

class max.pipelines.kv_cache.manager.KVCacheInputs

A base class that holds KV cache-related Tensor inputs.

It is meant to be subclassed by concrete KV cache input types.

Example

```pycon
>>> from dataclasses import dataclass
>>> from max.driver import Tensor
>>> from max.pipelines.kv_cache.manager import KVCacheInputs
>>> @dataclass
... class RaggedKVCacheInputs(KVCacheInputs):
...     blocks: Tensor
...     cache_lengths: Tensor
...     lookup_table: Tensor
...     max_lengths: Tensor
```


KVCacheInputsSequence

class max.pipelines.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs: Sequence[KVCacheInputs])

KVCacheInputsSequence is a sequence of KVCacheInputs. It is primarily used in our multistep execution to represent batched KVCacheInputs.

kv_cache_inputs

kv_cache_inputs: Sequence[KVCacheInputs]

KVCacheManager

class max.pipelines.kv_cache.manager.KVCacheManager(params: KVCacheParams, max_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device], session: InferenceSession, is_ragged: bool = False)

claim()

claim(n: int) → List[int]

Claims n blocks of memory in the cache for incoming requests.

This returns a list of sequence ids, which identify each sequence’s location within the cache. These sequence ids can then be passed to fetch() to obtain the KV cache inputs for those sequences.
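The claim/contains/release lifecycle can be pictured as bookkeeping over a fixed pool of slot ids. The sketch below is a minimal, hypothetical stand-in (ToySlotPool is not part of the library); it assumes each slot id doubles as the sequence id and that free slots are handed out lowest-first, which the real manager does not necessarily do.

```python
class ToySlotPool:
    """Hypothetical slot bookkeeping; not the library's KVCacheManager."""

    def __init__(self, max_batch_size: int) -> None:
        # Every slot starts out free; a slot id doubles as a sequence id.
        self.available: set[int] = set(range(max_batch_size))
        self.active: set[int] = set()

    def claim(self, n: int) -> list[int]:
        if n > len(self.available):
            raise ValueError(f"requested {n} slots but only {len(self.available)} are free")
        # Hand out the lowest-numbered free slots (an arbitrary choice).
        seq_ids = sorted(self.available)[:n]
        self.available.difference_update(seq_ids)
        self.active.update(seq_ids)
        return seq_ids

    def contains(self, seq_id: int) -> bool:
        return seq_id in self.active

    def release(self, seq_id: int) -> None:
        # Return the slot to the free pool so a later claim() can reuse it.
        self.active.remove(seq_id)
        self.available.add(seq_id)


pool = ToySlotPool(max_batch_size=4)
seq_ids = pool.claim(2)  # [0, 1]
pool.release(seq_ids[0])
print(pool.claim(1))     # [0]: the released slot is reused
```

Using a set for the free pool keeps claim, release, and membership checks cheap, and it makes the "remaining slots" view a simple copy of that set.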

contains()

contains(seq_id: int) → bool

estimated_memory_size()

abstract classmethod estimated_memory_size(params: KVCacheParams, max_batch_size: int, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device], **kwargs: Any) → int

Returns the estimated total memory usage of the KV cache.
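For a contiguous (non-paged) cache, a common back-of-the-envelope estimate multiplies the size of one token's keys and values across all layers by the number of token slots reserved. The function below is a sketch under that assumption; its parameter names (n_kv_heads, head_dim, dtype_bytes) are illustrative and do not mirror KVCacheParams, and a real estimated_memory_size() may add paging or alignment overhead on top.

```python
def estimate_kv_cache_bytes(
    num_layers: int,
    max_batch_size: int,
    max_seq_len: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int,
) -> int:
    """Rough size of a contiguous KV cache (hypothetical helper)."""
    # Keys and values (factor of 2) for every KV head in every layer.
    bytes_per_token = 2 * num_layers * n_kv_heads * head_dim * dtype_bytes
    # Reserve room for max_seq_len tokens in each of max_batch_size sequences.
    return bytes_per_token * max_seq_len * max_batch_size


# Illustrative configuration: 32 layers, 8 KV heads, head_dim 128, 2-byte dtype.
size = estimate_kv_cache_bytes(32, 1, 8192, 8, 128, 2)
print(size / 2**30)  # 1.0 (GiB)
```

With these numbers each token costs 128 KiB, so a single 8192-token sequence needs 1 GiB, and the total scales linearly with the batch size.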

external_claim()

external_claim(seq_ids: List[int]) → None

Variant of claim() in which the sequence ids are reserved externally: the caller supplies seq_ids rather than having the manager allocate them.
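A minimal sketch of what reserving caller-chosen ids involves, assuming the manager tracks free and active slots as two sets. The free function below is hypothetical, not the library's method; it validates the whole request before mutating anything, so a rejected call leaves the pool unchanged.

```python
def external_claim(available: set[int], active: set[int], seq_ids: list[int]) -> None:
    """Hypothetical: reserve caller-chosen seq_ids instead of allocating them."""
    # Validate everything first so a bad request leaves the pool untouched.
    if len(set(seq_ids)) != len(seq_ids):
        raise ValueError(f"duplicate seq_ids in {seq_ids}")
    taken = [s for s in seq_ids if s not in available]
    if taken:
        raise ValueError(f"seq_ids not available: {taken}")
    # Commit: move the ids from the free pool to the active set.
    available.difference_update(seq_ids)
    active.update(seq_ids)


available, active = {0, 1, 2, 3}, set()
external_claim(available, active, [3, 1])
print(sorted(available), sorted(active))  # [0, 2] [1, 3]
```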

fetch()

final fetch(seq_ids_and_prompts: dict[int, numpy.ndarray], num_steps: int = 1) → List[KVCacheInputs]

Returns the blocks and other inputs to the KV cache kernel for the given sequence ids and prompts.
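As a rough illustration of the per-batch metadata such kernel inputs can carry, the sketch below gathers, for each batch entry, the cache slot it reads from, how many tokens are already cached, and how many new tokens it brings. It assumes, loosely following the field names of RaggedKVCacheInputs, that lookup_table maps batch positions to slots and cache_lengths holds per-entry cached lengths. The helper name is hypothetical, plain lists stand in for numpy arrays and device tensors, and the cache blocks and max_lengths that fetch() also returns are not modeled.

```python
def gather_batch_metadata(
    seq_ids_and_prompts: dict[int, list[int]],
    cache_lengths: dict[int, int],
) -> tuple[list[int], list[int], list[int]]:
    """Hypothetical helper: per-entry slot ids, cached lengths, new-token counts."""
    lookup_table: list[int] = []         # batch position -> cache slot (seq_id)
    batch_cache_lengths: list[int] = []  # tokens already cached per entry
    prompt_lengths: list[int] = []       # new tokens to process this step
    # dicts preserve insertion order, so that order defines the batch order.
    for seq_id, prompt in seq_ids_and_prompts.items():
        lookup_table.append(seq_id)
        batch_cache_lengths.append(cache_lengths[seq_id])
        prompt_lengths.append(len(prompt))
    return lookup_table, batch_cache_lengths, prompt_lengths


lookup, cached, new = gather_batch_metadata({2: [11, 12, 13], 0: [7]}, {0: 5, 2: 0})
print(lookup, cached, new)  # [2, 0] [0, 5] [3, 1]
```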

increment_cache_lengths()

increment_cache_lengths(kv_cache_inputs: List[RaggedKVCacheInputs] | List[PaddedKVCacheInputs], prev_model_inputs: Any) → List[RaggedKVCacheInputs] | List[PaddedKVCacheInputs]

Prepares the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as that would defeat the purpose of multistep execution.

This should also not update the cache lengths tracked by the manager, because the batch is still considered in progress.
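In multistep execution, each step's inputs must reflect the tokens written by the previous step without committing anything to the manager. The sketch below assumes plain Python lists and a per-entry count of tokens processed in the previous step (next_step_cache_lengths and tokens_just_processed are hypothetical names); a real implementation would perform the same addition on device tensors so that no host-device synchronization is needed.

```python
from typing import Sequence


def next_step_cache_lengths(
    cache_lengths: Sequence[int], tokens_just_processed: Sequence[int]
) -> list[int]:
    """Hypothetical: cache lengths to feed the next step of a multistep run."""
    if len(cache_lengths) != len(tokens_just_processed):
        raise ValueError("one entry per batch element is required")
    # Build a new list instead of mutating the input: the manager's own
    # bookkeeping must not change while the batch is still in progress.
    return [c + t for c, t in zip(cache_lengths, tokens_just_processed)]


committed = [0, 10]                                 # lengths known to the manager
step1 = next_step_cache_lengths(committed, [5, 3])  # prompts of 5 and 3 tokens
step2 = next_step_cache_lengths(step1, [1, 1])      # one decoded token each
print(committed, step1, step2)  # [0, 10] [5, 13] [6, 14]
```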

infer_optimal_batch_size()

abstract classmethod infer_optimal_batch_size(params: KVCacheParams, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device], **kwargs: Any) → int

Returns the estimated optimal batch size for the KV cache.
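Under the same contiguous-cache assumption commonly used for memory estimates, a batch size can be inferred by dividing the memory budget by the cost of one full-length sequence. This is a hypothetical sketch, not the library's heuristic, which may reserve headroom or account for paging; the parameter names n_kv_heads, head_dim, and dtype_bytes are illustrative.

```python
def infer_max_batch_size(
    available_cache_memory: int,
    max_seq_len: int,
    num_layers: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int,
) -> int:
    """Hypothetical: how many full-length sequences fit in the memory budget."""
    # Cost of one sequence that fills max_seq_len tokens (keys + values).
    bytes_per_sequence = 2 * num_layers * n_kv_heads * head_dim * dtype_bytes * max_seq_len
    # Floor division: a partial sequence does not fit; may be 0 if the budget is too small.
    return available_cache_memory // bytes_per_sequence


print(infer_max_batch_size(10 * 2**30, 8192, 32, 8, 128, 2))  # 10
```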

input_symbols()

abstract input_symbols() → Sequence[KVCacheInputSymbols]

Returns the input symbols for the KV cache manager.

max_sequence_length

property max_sequence_length: int

The maximum sequence length in the current cache.

num_kv_inputs()

num_kv_inputs() → int

Returns the default number of KV cache inputs for KV managers.

Subclasses with a different number of KV cache inputs should override this method and increment_cache_lengths.

release()

release(seq_id: int) → None

Releases the provided seq_id, marking the sequence as complete. This returns the seq_id to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.

slots_remaining

property slots_remaining: set[int]

The set of cache slots that are currently available.

step()

step(seq_ids_and_new_tokens: dict[int, numpy.ndarray]) → None

Updates cache_lengths to record that a new KV projection step has occurred and that the underlying memory has been written to. These cache_lengths values are then used downstream in fetch() to determine which section of memory the kernels should use.
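The split between fetch() reading lengths and step() committing them can be sketched with a toy tracker. Everything below is hypothetical: ToyLengthTracker is not a library class, and it assumes the number of positions written in a step equals the length of the prompt passed to fetch(). The toy's step() accepts the same mapping shape as the real step() but only uses its keys.

```python
class ToyLengthTracker:
    """Hypothetical: fetch() reads committed lengths, step() commits new ones."""

    def __init__(self) -> None:
        self.cache_lengths: dict[int, int] = {}
        self._pending: dict[int, int] = {}  # positions written by the in-flight step

    def fetch(self, seq_ids_and_prompts: dict[int, list[int]]) -> list[int]:
        starts = []
        for seq_id, prompt in seq_ids_and_prompts.items():
            # New sequences start with an empty cache.
            starts.append(self.cache_lengths.setdefault(seq_id, 0))
            self._pending[seq_id] = len(prompt)
        return starts

    def step(self, seq_ids_and_new_tokens: dict[int, list[int]]) -> None:
        # The toy only uses the keys; it commits the lengths recorded at fetch time.
        for seq_id in seq_ids_and_new_tokens:
            self.cache_lengths[seq_id] += self._pending.pop(seq_id)


tracker = ToyLengthTracker()
print(tracker.fetch({7: [101, 102, 103]}))  # [0]: nothing cached yet
tracker.step({7: [104]})                    # 3 prompt positions now written
print(tracker.fetch({7: [104]}))            # [3]
```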

PaddedKVCacheInputs

class max.pipelines.kv_cache.manager.PaddedKVCacheInputs(k_cache: Tensor, v_cache: Tensor, start_pos: Tensor, null_op: Tensor)

PaddedKVCacheInputs is a class that holds the inputs for KV cache when used together with padded tensors.

k_cache

k_cache: Tensor

null_op

null_op: Tensor

start_pos

start_pos: Tensor

v_cache

v_cache: Tensor

RaggedKVCacheInputs

class max.pipelines.kv_cache.manager.RaggedKVCacheInputs(blocks: Tensor, cache_lengths: Tensor, lookup_table: Tensor, max_lengths: Tensor)

RaggedKVCacheInputs is a class that holds the inputs for KV cache when used together with ragged tensors.

blocks

blocks: Tensor

cache_lengths

cache_lengths: Tensor

lookup_table

lookup_table: Tensor

max_lengths

max_lengths: Tensor
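The padded and ragged variants differ in how a batch of variable-length sequences is laid out. The sketch below illustrates the two layouts on plain Python lists; it is a general illustration of padding versus offset-based ragged batching, assumed here to be what the class names refer to, and it does not model the actual tensors (k_cache, start_pos, lookup_table, and so on). The helper names are hypothetical.

```python
from typing import Sequence


def to_padded(seqs: Sequence[Sequence[int]], pad: int = 0) -> list[list[int]]:
    # Every row is stretched to the longest sequence; short rows waste slots.
    width = max((len(s) for s in seqs), default=0)
    return [list(s) + [pad] * (width - len(s)) for s in seqs]


def to_ragged(seqs: Sequence[Sequence[int]]) -> tuple[list[int], list[int]]:
    # Rows are concatenated; flat[offsets[i]:offsets[i + 1]] recovers row i.
    flat: list[int] = []
    offsets = [0]
    for s in seqs:
        flat.extend(s)
        offsets.append(len(flat))
    return flat, offsets


batch = [[1, 2, 3], [4], [5, 6]]
print(to_padded(batch))  # [[1, 2, 3], [4, 0, 0], [5, 6, 0]]
print(to_ragged(batch))  # ([1, 2, 3, 4, 5, 6], [0, 3, 4, 6])
```

Here the padded layout spends 9 slots on 6 real tokens, while the ragged layout stores exactly 6 plus one offset per row, which is why ragged inputs tend to pay off when sequence lengths in a batch vary widely.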