
Python module

pipeline

HF Token Generation Pipeline

KVCacheMixin

class max.pipelines.pipeline.KVCacheMixin(*args, **kwargs)

estimate_kv_cache_size()

abstract classmethod estimate_kv_cache_size(pipeline_config: PipelineConfig, available_cache_memory: int, devices: list[max.driver.driver.Device]) → int

Estimates the size of the kv cache in bytes.
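
A minimal implementation sketch, reusing the documented get_num_layers() and calculate_max_seq_len() hooks. The head count, head dimension, element size, and batch bound below are illustrative assumptions, not values taken from the library.

>>> class MyModel(PipelineModel, KVCacheMixin):
...     @classmethod
...     def estimate_kv_cache_size(
...         cls,
...         pipeline_config: PipelineConfig,
...         available_cache_memory: int,
...         devices: list[Device],
...     ) -> int:
...         n_layers = cls.get_num_layers(pipeline_config)
...         max_seq_len = cls.calculate_max_seq_len(pipeline_config)
...         # Assumed model geometry, for illustration only.
...         n_kv_heads, head_dim, elem_bytes, max_batch = 8, 128, 2, 16
...         # Per layer: a K and a V tensor of shape
...         # [max_batch, max_seq_len, n_kv_heads, head_dim].
...         per_layer = 2 * max_batch * max_seq_len * n_kv_heads * head_dim * elem_bytes
...         return n_layers * per_layer
...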

load_kv_manager()

load_kv_manager(session: InferenceSession, available_cache_memory: int | None) → KVCacheManager

Provided an InferenceSession and the available cache memory, loads the KV manager.

  • Parameters:

    • session – Inference session to compile and init the KV cache.
    • available_cache_memory – Amount of memory available to the KV cache, in bytes.
  • Returns:

    Either a single KV cache manager or a tuple of KV cache managers, one per input modality.

  • Return type:

    KVCacheManager
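
A hypothetical usage sketch; the model instance, session, and the 8 GiB budget are assumptions for illustration.

>>> kv_manager = model.load_kv_manager(
...     session=session,
...     available_cache_memory=8 * 1024**3,  # budget in bytes
... )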

ModelInputs

class max.pipelines.pipeline.ModelInputs

Base class for model inputs. Use this class to encapsulate inputs for your model. You may store any number of dataclass fields.

Example

>>> class ReplitInputs(ModelInputs):
...     tokens: Tensor
...     input_row_offsets: Tensor
...
...     def __init__(self, tokens: Tensor, input_row_offsets: Tensor):
...         self.tokens = tokens
...         self.input_row_offsets = input_row_offsets
...
>>> # Create tensors
>>> tokens = Tensor.zeros((1, 2, 3), DType.int64)
>>> input_row_offsets = Tensor.zeros((1, 1, 1), DType.int64)
>>> # Initialize inputs
>>> inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)
>>> # Access tensors
>>> list(inputs) == [tokens, input_row_offsets]
True

ModelOutputs

class max.pipelines.pipeline.ModelOutputs(next_token_logits: Tensor | None = None, logits: Tensor | None = None)

logits

logits: Tensor | None = None

Logits for the entire token sequence.

next_token_logits

next_token_logits: Tensor | None = None

Logits for just the next token.
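
A minimal construction example, following the Tensor usage from the ModelInputs docs above; the vocabulary size of 32000 is an assumption for illustration.

>>> next_token_logits = Tensor.zeros((1, 32000), DType.float32)
>>> outputs = ModelOutputs(next_token_logits=next_token_logits)
>>> outputs.logits is None  # whole-sequence logits were not requested
True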

PipelineModel

class max.pipelines.pipeline.PipelineModel(pipeline_config: PipelineConfig, session: InferenceSession)

A pipeline model with setup, input preparation and execution methods.

calculate_max_seq_len()

abstract classmethod calculate_max_seq_len(pipeline_config: PipelineConfig) → int

Calculate the optimal max sequence length for the model. Models are expected to implement this method.

Example

>>> class MistralModel(PipelineModel):
...     @classmethod
...     def calculate_max_seq_len(cls, pipeline_config: PipelineConfig) -> int:
...         try:
...             return upper_bounded_default(
...                 upper_bound=pipeline_config.huggingface_config.max_seq_len,
...                 default=pipeline_config.max_length,
...             )
...         except ValueError as e:
...             msg = (
...                 "Unable to infer max_length for Mistral, the provided "
...                 f"max_length ({pipeline_config.max_length}) exceeds the "
...                 f"model's max_seq_len "
...                 f"({pipeline_config.huggingface_config.max_seq_len})."
...             )
...             raise ValueError(msg) from e
...

compute_log_probabilities()

compute_log_probabilities(model_inputs: ModelInputs, model_outputs: ModelOutputs, next_tokens: Tensor, batch_top_n: list[int], batch_echo: list[bool]) → list[max.pipelines.response.LogProbabilities | None] | None

Optional method that can be overridden to compute log probabilities.

  • Parameters:

    • model_inputs – Inputs to the model returned by prepare_*_token_inputs().
    • model_outputs – Outputs returned by execute().
    • next_tokens – Sampled tokens. Should have shape [batch_size].
    • batch_top_n – Number of top log probabilities to return per input in the batch. For any element where top_n == 0, the LogProbabilities is skipped.
    • batch_echo – Whether to include input tokens in the returned log probabilities.
  • Returns:

    List of log probabilities.
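
A skeleton override, shown only to illustrate how the per-request batch_top_n and batch_echo flags are consumed; the actual log-softmax computation and the LogProbabilities fields are left as comments because they are not specified here.

>>> class MyModel(PipelineModel):
...     def compute_log_probabilities(
...         self, model_inputs, model_outputs, next_tokens, batch_top_n, batch_echo
...     ):
...         results = []
...         for i, (top_n, echo) in enumerate(zip(batch_top_n, batch_echo)):
...             if top_n == 0:
...                 # Documented behavior: requests with top_n == 0 are skipped.
...                 results.append(None)
...                 continue
...             # For request i: log-softmax model_outputs.logits (full sequence
...             # when echo is True, otherwise next_token_logits), keep the top_n
...             # entries plus next_tokens[i], and wrap in LogProbabilities.
...             results.append(None)  # placeholder for a LogProbabilities value
...         return results
...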

estimate_weights_size()

classmethod estimate_weights_size(pipeline_config: PipelineConfig) → int

Calculates the estimated memory consumption of the model's weights.

execute()

abstract execute(model_inputs: ModelInputs, kv_cache_inputs: Sequence[Tensor] | None = None) → ModelOutputs

Executes the graph with the given inputs.

  • Parameters:

    • model_inputs – The model inputs to execute, containing tensors and any other required data for model execution.
    • kv_cache_inputs – The KV cache inputs, containing tensors and any other data required for model execution.
  • Returns:

    ModelOutputs containing the pipeline’s output tensors.

This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
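
A sketch of a concrete override, assuming the subclass stored a compiled graph as self.model during setup; that attribute name, the execute call on it, and the output unpacking are assumptions for illustration.

>>> class MyModel(PipelineModel):
...     def execute(self, model_inputs, kv_cache_inputs=None):
...         kv_cache_inputs = kv_cache_inputs or []
...         # ModelInputs is iterable, so splat its tensors followed by the
...         # KV cache tensors into the compiled graph.
...         outputs = self.model.execute(*model_inputs, *kv_cache_inputs)
...         # Return only the next-token logits; whole-sequence logits omitted.
...         return ModelOutputs(next_token_logits=outputs[0])
...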

get_kv_params()

abstract classmethod get_kv_params(pipeline_config: PipelineConfig) → KVCacheParams

Returns the KV cache params for the pipeline model.

get_num_layers()

abstract classmethod get_num_layers(pipeline_config: PipelineConfig) → int

Returns the number of layers for the pipeline model.

infer_optimal_batch_size()

classmethod infer_optimal_batch_size(pipeline_config: PipelineConfig, available_cache_memory: int) → int

Returns the estimated optimal batch size to run the model given current memory constraints.

prepare_initial_token_inputs()

abstract prepare_initial_token_inputs(context_batch: Sequence[T]) → ModelInputs

Prepares the initial inputs to be passed to .execute().

The inputs and functionality of this method can vary per model. For example, the model inputs could include:

  • Encoded tensors
  • A unique ID for each tensor if this model uses a KV cache manager.

This function would batch the encoded tensors, claim a slot in the kv cache if the ID hasn’t been seen before, and return the inputs and caches as a list of tensors.
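
A sketch of preparing the initial batch, reusing the ReplitInputs class from the ModelInputs example above; the context attribute (next_tokens, assumed to be a NumPy array of token ids), the offset arithmetic, and the Tensor.from_numpy conversion are assumptions for illustration.

>>> import numpy as np
>>> class ReplitModel(PipelineModel):
...     def prepare_initial_token_inputs(self, context_batch):
...         # Row offsets mark where each request's tokens begin in the
...         # flattened batch (assumed ragged-batch layout).
...         lengths = [len(ctx.next_tokens) for ctx in context_batch]
...         input_row_offsets = np.cumsum([0] + lengths, dtype=np.uint32)
...         tokens = np.concatenate([ctx.next_tokens for ctx in context_batch])
...         return ReplitInputs(
...             tokens=Tensor.from_numpy(tokens),
...             input_row_offsets=Tensor.from_numpy(input_row_offsets),
...         )
...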

prepare_next_token_inputs()

abstract prepare_next_token_inputs(next_tokens: Tensor, prev_model_inputs: ModelInputs) → ModelInputs

Prepares the secondary inputs to be passed to .execute().

While prepare_initial_token_inputs is responsible for managing the initial inputs, this function is responsible for updating the inputs for each step in a multi-step execution pattern.
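
A companion sketch for the ReplitModel class above; it assumes that after the first step each request contributes exactly one new token, and that the to_numpy()/from_numpy() conversions behave as their names suggest.

>>> class ReplitModel(PipelineModel):
...     def prepare_next_token_inputs(self, next_tokens, prev_model_inputs):
...         batch_size = len(prev_model_inputs.input_row_offsets.to_numpy()) - 1
...         # One token per request: offsets become 0, 1, 2, ..., batch_size.
...         next_offsets = np.arange(batch_size + 1, dtype=np.uint32)
...         return ReplitInputs(
...             tokens=next_tokens,
...             input_row_offsets=Tensor.from_numpy(next_offsets),
...         )
...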

TextGenerationPipeline

class max.pipelines.pipeline.TextGenerationPipeline(pipeline_config: PipelineConfig, pipeline_model: Type[PipelineModel], eos_token_id: int)

Generalized token generator pipeline.

calculate_num_steps()

calculate_num_steps(num_steps: int, context: T) → int

next_token()

next_token(batch: dict[str, T], num_steps: int) → list[dict[str, Any]]

Provided a batch, processes the batch inputs, executes the graph for num_steps in a multi-step scenario, then decodes the tokens holistically and returns the list of decoded tokens.
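
A hypothetical calling sketch; the request ids, context objects, and num_steps value are illustrative, and the assumption is that each returned dict maps request ids to that step's newly generated output.

>>> responses = pipeline.next_token(
...     {"request-0": context_a, "request-1": context_b},
...     num_steps=8,  # run up to 8 decoding steps before returning
... )
>>> for step_output in responses:
...     print(sorted(step_output.keys()))
...
>>> # Once a request has finished, free its KV cache slot.
>>> pipeline.release(context_a)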

prepare_batch()

prepare_batch(batch: list[T], num_steps: int) → tuple[max.pipelines.pipeline.ModelInputs, Any, int]

release()

release(context: T) → None

Marks the context as complete, releasing its cache slot back to the KV manager.

upper_bounded_default()

max.pipelines.pipeline.upper_bounded_default(upper_bound: int, default: int | None) → int

Given an upper bound and an optional default value, returns a final value that cannot exceed the upper bound.

  • Parameters:

    • upper_bound – The upper bound to use.
    • default – The default value to use, or None to use the upper bound.
  • Raises:

    ValueError – If the provided default value exceeds the upper bound.

  • Returns:

    The final value.
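
A short usage sketch based on the behavior described above (the exact ValueError message is elided):

>>> upper_bounded_default(upper_bound=4096, default=None)
4096
>>> upper_bounded_default(upper_bound=4096, default=2048)
2048
>>> upper_bounded_default(upper_bound=4096, default=8192)
Traceback (most recent call last):
    ...
ValueError: ...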