# `manager` \{#max.pipelines.kv_cache.manager}

Python module defining the abstract base class `KVCacheManager` for the KV cache.
## `KVCacheInputSymbols` \{#max.pipelines.kv_cache.manager.KVCacheInputSymbols}

> *class* max.pipelines.kv_cache.manager.KVCacheInputSymbols
Base class for input symbols for KV cache managers.
The derived class is responsible for defining the input symbols for the specific KV cache manager.
For example, here's a derived class for a text KV cache manager:

```pycon
>>> @dataclass
... class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
...     kv_blocks: TensorType
...     cache_lengths: TensorType
...     lookup_table: TensorType
...     max_lengths: TensorType
```
## `KVCacheInputs` \{#max.pipelines.kv_cache.manager.KVCacheInputs}
> *class* max.pipelines.kv_cache.manager.KVCacheInputs
A base class that holds KV cache related (Tensor) inputs.
It is meant to be subclassed by concrete KV cache input types.
### Example
```pycon
>>> @dataclass
... class RaggedKVCacheInputs(KVCacheInputs):
...     blocks: Tensor
...     cache_lengths: Tensor
...     lookup_table: Tensor
...     max_lengths: Tensor
```
## `KVCacheInputsSequence` \{#max.pipelines.kv_cache.manager.KVCacheInputsSequence}

> *class* max.pipelines.kv_cache.manager.KVCacheInputsSequence(kv_cache_inputs: Sequence[KVCacheInputs])
KVCacheInputsSequence is a sequence of KVCacheInputs. It is primarily used in our multistep execution to represent batched KVCacheInputs.
### `kv_cache_inputs`

> kv_cache_inputs*: Sequence[KVCacheInputs]*
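As a rough sketch of how per-step inputs might be batched into a sequence, here is a self-contained toy version of this pattern. The `StepKVCacheInputs` class and its field are hypothetical stand-ins for the real Tensor-backed input types, not part of the library.

```python
from dataclasses import dataclass
from typing import Sequence


# Toy stand-ins for the real Tensor-backed classes; illustration only.
@dataclass
class KVCacheInputs:
    pass


@dataclass
class StepKVCacheInputs(KVCacheInputs):
    # Hypothetical per-step payload; real subclasses hold Tensors.
    step_index: int


@dataclass
class KVCacheInputsSequence(KVCacheInputs):
    kv_cache_inputs: Sequence[KVCacheInputs]


# Group the per-step inputs for a 3-step multistep execution.
batched = KVCacheInputsSequence(
    kv_cache_inputs=[StepKVCacheInputs(step_index=i) for i in range(3)]
)
```

Because `KVCacheInputsSequence` is itself a `KVCacheInputs`, downstream code can accept either a single step's inputs or a whole batch through the same interface.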
## `KVCacheManager` \{#max.pipelines.kv_cache.manager.KVCacheManager}

> *class* max.pipelines.kv_cache.manager.KVCacheManager(params: KVCacheParams, max_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device], session: InferenceSession, is_ragged: bool = False)
### `claim()`

Claims `n` blocks of memory in the cache for incoming requests.

This returns a list of sequence ids, which identify a sequence's location within the cache. These sequence ids can then be passed to `fetch()` to return the `ContinuousBatchingKVCacheCollection` for those sequences.
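The claim/release bookkeeping can be sketched as a fixed pool of sequence slots. This is a minimal illustration with hypothetical names, not the library's internals; in the real manager each id also indexes device-resident cache blocks.

```python
# Minimal sketch of claim/release over a fixed pool of sequence slots.
class SlotPool:
    def __init__(self, max_batch_size: int) -> None:
        # Every slot starts out free.
        self._available = set(range(max_batch_size))

    def claim(self, n: int) -> list[int]:
        # Hand out n free sequence ids; callers later pass these to fetch().
        return [self._available.pop() for _ in range(n)]

    def release(self, seq_id: int) -> None:
        # Return the slot so a new sequence can reuse it.
        self._available.add(seq_id)


pool = SlotPool(max_batch_size=4)
seq_ids = pool.claim(2)  # two distinct ids from the pool of four
```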
### `contains()`
### `estimated_memory_size()`

> *abstract classmethod* estimated_memory_size(params: KVCacheParams, max_batch_size: int, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device], **kwargs: Any) → int

Returns the estimated total memory usage of the KV cache.
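As a back-of-envelope version of such an estimate, assuming a dense (non-paged) layout with one K and one V entry per layer, per token, per KV head. The function and parameter names here are illustrative, not the method's actual signature.

```python
def estimated_kv_cache_bytes(
    max_batch_size: int,
    max_seq_len: int,
    num_layers: int,
    n_kv_heads: int,
    head_dim: int,
    dtype_bytes: int = 2,  # e.g. bfloat16
) -> int:
    # 2x for the K and V tensors per layer, per KV head.
    per_token = 2 * num_layers * n_kv_heads * head_dim * dtype_bytes
    return max_batch_size * max_seq_len * per_token


# A Llama-8B-like config: 32 layers, 8 KV heads, head_dim 128.
total = estimated_kv_cache_bytes(16, 4096, 32, 8, 128)  # 8 GiB
```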
### `external_claim()`

Variant of `claim()` where sequence ids are reserved externally.
### `fetch()`

> *final* fetch(seq_ids_and_prompts: dict[int, numpy.ndarray], num_steps: int = 1) → List[KVCacheInputs]

Returns the blocks and other inputs to the KV cache kernel for the given sequence ids and prompts.
### `increment_cache_lengths()`

> increment_cache_lengths(kv_cache_inputs: List[RaggedKVCacheInputs] | List[PaddedKVCacheInputs], prev_model_inputs: Any) → List[RaggedKVCacheInputs] | List[PaddedKVCacheInputs]
Prepare the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as this would defeat the purpose of multistep execution.
Nor should it update the cache lengths in the manager, since this batch is still considered in-progress.
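The between-step bookkeeping can be sketched as follows: each in-flight sequence's cache length grows by one generated token per step. On device this would be a cheap elementwise add with no host synchronization; NumPy stands in here, and the function name is illustrative.

```python
import numpy as np


def increment_cache_lengths(cache_lengths: np.ndarray) -> np.ndarray:
    # One new token per sequence per decode step.
    return cache_lengths + 1


lengths = np.array([7, 12, 3])  # tokens already cached per sequence
lengths = increment_cache_lengths(lengths)  # now [8, 13, 4]
```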
### `infer_optimal_batch_size()`

> *abstract classmethod* infer_optimal_batch_size(params: KVCacheParams, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device], **kwargs: Any) → int

Returns the estimated optimal batch size for the KV cache.
### `input_symbols()`

> *abstract* input_symbols() → Sequence[KVCacheInputSymbols]

Returns the input symbols for the KV cache manager.
### `max_sequence_length`

> *property* max_sequence_length*: int*

The maximum sequence length in the current cache.
### `num_kv_inputs()`

> num_kv_inputs() → int

Returns the default number of KV cache inputs for KV managers.

Subclasses with a different number of KV cache inputs should override this method and `increment_cache_lengths()`.
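A minimal sketch of that override contract, assuming the default of four inputs (blocks, cache lengths, lookup table, max lengths) as in `RaggedKVCacheInputs`; both class names here are hypothetical.

```python
class BaseManager:
    def num_kv_inputs(self) -> int:
        # Default: blocks, cache_lengths, lookup_table, max_lengths.
        return 4


class ExtendedManager(BaseManager):
    def num_kv_inputs(self) -> int:
        # Four defaults plus one extra input tensor; such a manager
        # would also need a matching increment_cache_lengths override.
        return 5
```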
### `release()`

Releases the `seq_id` provided, marking the sequence as complete. This returns the `seq_id` back to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.
### `slots_remaining`

The cache slots still available.
### `step()`

> step(seq_ids_and_new_tokens: dict[int, numpy.ndarray]) → None

Updates the cache_lengths objects to record that a new KV projection step has occurred and that the underlying memory has been written to. This cache_lengths value is then used downstream in `fetch()` to track which section of memory the kernels should use.
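The effect of this update can be sketched with plain dictionaries: after a generation step, each sequence's recorded cache length grows by the number of tokens just written, so the next fetch reads the right slice of cache memory. This is a toy model of the bookkeeping, not the manager's implementation.

```python
import numpy as np


def step(
    cache_lengths: dict[int, int],
    seq_ids_and_new_tokens: dict[int, np.ndarray],
) -> None:
    # Extend each sequence's cache length by its newly written tokens.
    for seq_id, new_tokens in seq_ids_and_new_tokens.items():
        cache_lengths[seq_id] += len(new_tokens)


cache_lengths = {0: 7, 1: 12}
step(cache_lengths, {0: np.array([42]), 1: np.array([17])})
# cache_lengths is now {0: 8, 1: 13}
```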
## `PaddedKVCacheInputs` \{#max.pipelines.kv_cache.manager.PaddedKVCacheInputs}

> *class* max.pipelines.kv_cache.manager.PaddedKVCacheInputs(k_cache: Tensor, v_cache: Tensor, start_pos: Tensor, null_op: Tensor)

PaddedKVCacheInputs holds the inputs for the KV cache when used with padded tensors.
### `k_cache`

> k_cache*: Tensor*

### `null_op`

> null_op*: Tensor*

### `start_pos`

> start_pos*: Tensor*

### `v_cache`

> v_cache*: Tensor*
## `RaggedKVCacheInputs` \{#max.pipelines.kv_cache.manager.RaggedKVCacheInputs}

> *class* max.pipelines.kv_cache.manager.RaggedKVCacheInputs(blocks: Tensor, cache_lengths: Tensor, lookup_table: Tensor, max_lengths: Tensor)

RaggedKVCacheInputs holds the inputs for the KV cache when used with ragged tensors.
### `blocks`

> blocks*: Tensor*

### `cache_lengths`

> cache_lengths*: Tensor*

### `lookup_table`

> lookup_table*: Tensor*

### `max_lengths`

> max_lengths*: Tensor*