Python module
manager
Abstract base class for KV cache managers.
KVCacheManager
class max.pipelines.kv_cache.manager.KVCacheManager(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, devices: List[Device], session: InferenceSession, is_ragged: bool = False)
claim()
Claims n blocks of memory in the cache for incoming requests.
This returns a list of sequence ids, each identifying a sequence's location within the cache. These sequence ids can then be passed to the fetch function to retrieve the ContinuousBatchingKVCacheCollection for those sequences.
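For example, a minimal sketch of the claim-then-fetch flow; `manager` is assumed to be a concrete `KVCacheManager` subclass, and `claim()` is assumed to take the number of sequences to reserve (the exact signature is not shown above):

```python
import numpy as np

# Reserve two cache slots; claim() is assumed to return the new
# sequence ids, e.g. [0, 1].
seq_ids = manager.claim(2)

# Map each claimed sequence id to its prompt token array.
seq_ids_and_prompts = {
    seq_ids[0]: np.array([101, 2009, 2003], dtype=np.int64),
    seq_ids[1]: np.array([101, 7592], dtype=np.int64),
}

# fetch() returns the KV cache input tensors for these sequences.
kv_cache_inputs = manager.fetch(seq_ids_and_prompts)
```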
contains()
estimated_memory_size()
abstract classmethod estimated_memory_size(params: KVCacheParams, max_cache_batch_size: int, max_seq_len: int, num_layers: int, available_cache_memory: int, devices: List[Device]) → int
Returns the estimated total memory usage of the KV cache.
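As a sketch, the estimate can be queried on a concrete subclass before constructing a manager; `MyKVCacheManager`, `params`, and `devices` below are placeholders for an actual subclass and an existing pipeline configuration:

```python
# Hypothetical: MyKVCacheManager stands in for a concrete subclass that
# implements estimated_memory_size(); params and devices come from an
# existing pipeline configuration.
estimated_bytes = MyKVCacheManager.estimated_memory_size(
    params=params,
    max_cache_batch_size=16,
    max_seq_len=2048,
    num_layers=32,
    available_cache_memory=8 * 1024**3,  # 8 GiB budget
    devices=devices,
)
```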
external_claim()
Variant of claim() where sequence ids are reserved externally.
fetch()
abstract fetch(seq_ids_and_prompts: dict[int, numpy.ndarray], num_steps: int = 1) → List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]]
increment_cache_lengths()
increment_cache_lengths(kv_cache_inputs: List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]], prev_model_inputs: tuple[max.driver.tensor.Tensor, ...]) → List[tuple[max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor, max.driver.tensor.Tensor]]
Prepare the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as this would defeat the purpose of multistep execution.
This should also not update the cache lengths in our manager; this batch is still considered in-progress.
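A sketch of that contract in a multistep loop; `manager` reflects the API above, while `model`, `build_initial_inputs`, and `build_next_inputs` are hypothetical stand-ins for a compiled model and the caller's input handling:

```python
num_steps = 4

# Fetch KV cache inputs once, sized for the full multistep run.
kv_cache_inputs = manager.fetch(seq_ids_and_prompts, num_steps=num_steps)

model_inputs = build_initial_inputs(seq_ids_and_prompts)  # hypothetical helper
for _ in range(num_steps):
    outputs = model.execute(*model_inputs, *kv_cache_inputs)  # hypothetical
    # Advance the cache lengths for the next step without a device sync;
    # the manager's own bookkeeping stays untouched while the batch runs.
    kv_cache_inputs = manager.increment_cache_lengths(
        kv_cache_inputs, model_inputs
    )
    model_inputs = build_next_inputs(outputs)  # hypothetical helper
```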
input_symbols()
abstract input_symbols() → Sequence[tuple[max.graph.type.Type, max.graph.type.Type, max.graph.type.TensorType, max.graph.type.TensorType]]
max_sequence_length
property max_sequence_length: int
The maximum sequence length in the current cache.
release()
Releases the provided seq_id, marking the sequence as complete. This returns the seq_id to the available pool of cache memory, allowing it to be reused when a new sequence is claimed.
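For illustration, assuming `manager` and `seq_id` from the claim example above:

```python
# Once a sequence has finished generating, hand its slot back so a
# future claim() can reuse it.
manager.release(seq_id)
```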
slots_remaining
The cache slots that are still available to be claimed.
step()
step(seq_ids_and_prompts: dict[int, numpy.ndarray], num_steps: int = 1) → None
Updates the cache_lengths objects to note that a new KV projection step has occurred, and that the underlying memory has been written to. This cache_lengths value is then used downstream in fetch to track which section of memory should be used in the kernels.
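A sketch of where step() sits in the loop, reusing the hypothetical names from the multistep example above:

```python
# After executing the model for num_steps steps against the fetched KV
# cache inputs, record the newly written projections; the next fetch()
# then sees the updated cache_lengths for these sequences.
manager.step(seq_ids_and_prompts, num_steps=num_steps)
```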