Python module
log_probabilities
compute_log_probabilities_ragged()
max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged(device, model, *, input_row_offsets, logits, next_token_logits, tokens, sampled_tokens, batch_top_n, batch_echo)
Computes the log probabilities for ragged model outputs.
Parameters:

- device (Device) – Device on which to do the bulk of the log probabilities computation. A small amount of computation still occurs on the host regardless of this setting.
- model (Model) – A compiled version of a graph from the ‘log_probabilities_ragged_graph’ function.
- input_row_offsets (ndarray) – Token offsets into token-indexed buffers, by batch index. Should have 1 more element than there are batches (batch n is token indices [input_row_offsets[n], input_row_offsets[n+1])).
- logits (Tensor | None) – (tokens, vocab_dim) tensor of logits. The token dimension is mapped to batches using input_row_offsets. May be omitted only if all ‘batch_echo’ values are False.
- next_token_logits (Tensor) – (batch_dim, vocab_dim) tensor of logits for the next token in each batch item.
- sampled_tokens (ndarray) – (batch_dim,) tensor of the sampled token per batch item.
- batch_top_n (Sequence[int]) – Number of top log probabilities to return per input in the batch. For any element where top_n == 0, the LogProbabilities is skipped.
- batch_echo (Sequence[bool]) – Whether to include input tokens in the returned log probabilities.
- tokens (ndarray)
Returns:
Computed log probabilities for each item in the batch.

Return type:
list[LogProbabilities | None]
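To make the batch_top_n and sampled_tokens semantics concrete, here is a minimal NumPy sketch of what a per-batch log-probabilities computation looks like. This is illustrative only, not the library implementation; the variable names mirror the parameters above, and the dictionary result stands in for the LogProbabilities type.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

# (batch_dim, vocab_dim) logits for the next token of each batch item.
next_token_logits = np.array([[1.0, 2.0, 3.0],
                              [0.5, 0.5, 0.5]])
sampled_tokens = np.array([2, 0])  # (batch_dim,) sampled token per batch item
batch_top_n = [2, 0]               # top_n == 0 skips that batch element

results = []
for i, top_n in enumerate(batch_top_n):
    if top_n == 0:
        results.append(None)  # mirrors the "skipped" behavior described above
        continue
    logprobs = log_softmax(next_token_logits[i])
    top_ids = np.argsort(logprobs)[::-1][:top_n]
    results.append({
        "sampled": float(logprobs[sampled_tokens[i]]),
        "top": {int(t): float(logprobs[t]) for t in top_ids},
    })
```

The second batch element comes back as None, matching the list[LogProbabilities | None] return type, where a None entry marks an input whose log probabilities were not requested.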
log_probabilities_ragged_graph()
max.pipelines.lib.log_probabilities.log_probabilities_ragged_graph(device, *, levels)
Create a graph to compute log probabilities over ragged inputs.
A model compiled from this graph is a required input to ‘compute_log_probabilities_ragged’.
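Both functions operate on the ragged layout described under input_row_offsets above: token-indexed buffers are flat, and offsets delimit each batch element. A small NumPy sketch of that convention (illustrative only, with made-up token values):

```python
import numpy as np

# Two sequences flattened into one token-indexed buffer.
tokens = np.array([11, 12, 13, 21, 22])
# One more entry than there are batches; batch n spans
# [input_row_offsets[n], input_row_offsets[n + 1]).
input_row_offsets = np.array([0, 3, 5])

batches = [tokens[input_row_offsets[n]:input_row_offsets[n + 1]]
           for n in range(len(input_row_offsets) - 1)]
# batches[0] is the first sequence's tokens, batches[1] the second's.
```

This is why ‘logits’ above has a leading (tokens, …) dimension rather than a (batch_dim, …) one: each batch element's rows are recovered by slicing with its pair of offsets.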