Python module
hf
ContinuousHFStaticCache
class max.nn.kv_cache.hf.ContinuousHFStaticCache(config, max_batch_size, max_seq_len, device, dtype=torch.float32, layer_device_map=None)

Parameters:

- config
- max_batch_size
- max_seq_len
- device
- dtype – Defaults to torch.float32.
- layer_device_map – Optional. Defaults to None.
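A minimal construction sketch. The argument values below are illustrative assumptions: the signature alone does not pin down the expected types, so this follows the conventions of the Hugging Face StaticCache that the class name suggests, with config loaded via transformers' AutoConfig and device given as a torch.device.

```python
import torch
from transformers import AutoConfig

from max.nn.kv_cache.hf import ContinuousHFStaticCache

# Illustrative model; any config with the standard attention fields
# (num_hidden_layers, num_attention_heads, ...) is assumed to work.
config = AutoConfig.from_pretrained("gpt2")

cache = ContinuousHFStaticCache(
    config=config,
    max_batch_size=4,            # assumed: max number of concurrent sequences
    max_seq_len=1024,            # assumed: max tokens cached per sequence
    device=torch.device("cpu"),
    dtype=torch.float32,         # matches the signature's default
)
```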
external_claim()
external_claim(seq_ids)
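By analogy with other KV cache managers, seq_ids is presumably a list of integer sequence IDs whose cache slots should be reserved; a hedged sketch, continuing the construction example above:

```python
# Claim cache slots for two sequences managed by an external scheduler.
# Assumes seq_ids is a list[int]; not confirmed by the signature alone.
cache.external_claim([0, 1])
```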
get_attention_mask()
get_attention_mask(seq_ids)
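A usage sketch under the same assumption about seq_ids; the Tensor return type is an assumption by analogy with update_attention_pattern below:

```python
# Fetch the attention mask covering the given sequences.
mask = cache.get_attention_mask([0, 1])
```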
release()
release(seq_id)
Parameters:

- seq_id (int)

Return type:

None
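Usage sketch, continuing the example above:

```python
# Free the cache slot held by sequence 0 so its ID can be claimed again.
cache.release(0)
```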
reset()
reset()
Resets the cache values while preserving the allocated cache objects.

Return type:

None
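Usage sketch:

```python
# Zero the cached values for all slots without reallocating the tensors.
cache.reset()
```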
set_active_slots()
set_active_slots(seq_ids)
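Presumably this selects which claimed sequences participate in the next forward pass; a hedged sketch, continuing the example above:

```python
# Assumes seq_ids is a list of previously claimed integer sequence IDs.
cache.set_active_slots([0, 1])
```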
set_cache_position()
set_cache_position(cache_position)
Parameters:

- cache_position (Tensor)
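A sketch of setting the positions for an incoming chunk of tokens. Passing a tensor (rather than a Python int or list) follows the indexing advice in update() below:

```python
import torch

# Positions 0..15 for a 16-token prompt chunk.
cache.set_cache_position(torch.arange(16))
```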
update()
update(key_states, value_states, layer_idx, cache_kwargs=None)
Updates the cache with the new key_states and value_states for the layer layer_idx. It is very important to index using a tensor; otherwise, you introduce a copy to the device.

Parameters:

- key_states (torch.Tensor) – The new key states to cache.
- value_states (torch.Tensor) – The new value states to cache.
- layer_idx (int) – The index of the layer to cache the states for.
- cache_kwargs (Dict[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache_position input to know where to write in the cache.
Returns:

A tuple containing the updated key and value states.

Return type:

tuple[Tensor, Tensor]
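A sketch of how a decoder layer might call update(), assuming cache_position is threaded through cache_kwargs as in transformers' StaticCache; the shapes are illustrative:

```python
import torch

# Illustrative shapes: (batch, num_kv_heads, seq_len, head_dim).
key_states = torch.zeros(1, 8, 16, 64)
value_states = torch.zeros(1, 8, 16, 64)

# Index with a tensor, per the note above; a Python int or list here could
# introduce an extra copy to the device.
cache_position = torch.arange(16)

keys, values = cache.update(
    key_states,
    value_states,
    layer_idx=0,
    cache_kwargs={"cache_position": cache_position},
)
```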
update_attention_pattern()
update_attention_pattern(seq_id, attention_mask)
Parameters:

- seq_id (int)
- attention_mask (Tensor)

Return type:

None
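A hedged sketch; the expected mask shape is an assumption (here, one slot's full max_seq_len from the construction example):

```python
import torch

# Overwrite the stored attention pattern for sequence 0.
cache.update_attention_pattern(0, torch.ones(1024, dtype=torch.long))
```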