ContinuousHFStaticCache

class max.nn.kv_cache.hf.ContinuousHFStaticCache(config, max_batch_size, max_seq_len, device, dtype=torch.float32, layer_device_map=None)

Parameters:

  • config (PretrainedConfig)
  • max_batch_size (int)
  • max_seq_len (int)
  • device (torch.device)
  • dtype (torch.dtype)
  • layer_device_map (dict[int, str | torch.device | int] | None)
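
A minimal construction sketch. The model name, batch and sequence sizes, and device below are illustrative assumptions; any transformers PretrainedConfig for the model you are serving should work.

```python
import torch
from transformers import AutoConfig

from max.nn.kv_cache.hf import ContinuousHFStaticCache

# Placeholder model config (assumption); use the config of your own model.
config = AutoConfig.from_pretrained("gpt2")

cache = ContinuousHFStaticCache(
    config=config,
    max_batch_size=4,              # maximum number of concurrent sequences
    max_seq_len=2048,              # maximum cached tokens per sequence
    device=torch.device("cuda"),   # assumption: a CUDA device is available
    dtype=torch.float16,           # defaults to torch.float32 if omitted
)
```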

external_claim()

external_claim(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

None

get_attention_mask()

get_attention_mask(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

Tensor

release()

release(seq_id)

Parameters:

seq_id (int)

Return type:

None

reset()

reset()

Resets the cache values while preserving the allocated cache objects.

Return type:

None

set_active_slots()

set_active_slots(seq_ids)

Parameters:

seq_ids (list[int])

Return type:

None
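
Taken together, external_claim(), set_active_slots(), release(), and reset() support a simple continuous-batching loop. The sketch below is an assumption about typical usage; the sequence ids are arbitrary caller-chosen integers and the forward-pass wiring is elided.

```python
# Claim cache slots for two new sequences.
cache.external_claim(seq_ids=[0, 1])

# Mark which sequences take part in the next forward pass.
cache.set_active_slots(seq_ids=[0, 1])

# ... run the model forward pass with `cache` as its KV cache ...

# Free the slot of a finished sequence, or wipe all cached values
# while keeping the underlying cache objects allocated.
cache.release(seq_id=1)
cache.reset()
```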

set_cache_position()

set_cache_position(cache_position)

Parameters:

cache_position (Tensor)

update()

update(key_states, value_states, layer_idx, cache_kwargs=None)

Updates the cache with the new key_states and value_states for layer layer_idx. It is very important to index using a tensor; otherwise you introduce a copy to the device.

Parameters:

  • key_states (torch.Tensor) – The new key states to cache.
  • value_states (torch.Tensor) – The new value states to cache.
  • layer_idx (int) – The index of the layer to cache the states for.
  • cache_kwargs (Dict[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache_position input to know where to write in the cache.

Returns:

A tuple containing the updated key and value states.

Return type:

tuple[Tensor, Tensor]
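
A sketch of one layer's write during decoding. The tensor shapes ([batch, num_heads, new_tokens, head_dim]), the CUDA device, and passing cache_position both through set_cache_position() and cache_kwargs are illustrative assumptions.

```python
import torch

# New key/value states produced by attention layer 0 for this step
# (shapes are illustrative only).
key_states = torch.randn(2, 8, 1, 64, dtype=torch.float16, device="cuda")
value_states = torch.randn(2, 8, 1, 64, dtype=torch.float16, device="cuda")

# Index with a tensor (not a Python int or slice) to avoid a device copy.
cache_position = torch.tensor([5], device="cuda")
cache.set_cache_position(cache_position)

key_cache, value_cache = cache.update(
    key_states=key_states,
    value_states=value_states,
    layer_idx=0,
    cache_kwargs={"cache_position": cache_position},
)
```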

update_attention_pattern()

update_attention_pattern(seq_id, attention_mask)

Parameters:

  • seq_id (int)
  • attention_mask (Tensor)

Return type:

None
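
A small sketch combining update_attention_pattern() and get_attention_mask(). The mask shape and dtype follow the usual padding-mask convention (1 = attend, 0 = masked), which is an assumption here.

```python
import torch

# Record the attention pattern for sequence 0
# (shape assumed: one value per cached position).
prompt_mask = torch.ones(16, dtype=torch.long, device="cuda")
cache.update_attention_pattern(seq_id=0, attention_mask=prompt_mask)

# Gather the combined mask for the sequences currently being batched.
attention_mask = cache.get_attention_mask(seq_ids=[0, 1])
```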