
Python module

sampling

rejection_sampler()

max.pipelines.lib.sampling.rejection_sampler(device, *, seed=0)

Parameters:

  • device (DeviceRef)
  • seed (int)

Return type:

Graph

rejection_sampler_with_residuals()

max.pipelines.lib.sampling.rejection_sampler_with_residuals(device, *, seed=0, debug=False)

Rejection sampler with residual sampling for speculative decoding.

Computes the acceptance ratio for each draft token, finds the first rejected position, samples a replacement token from the residual distribution (target - draft) at that position, and generates a bonus token when every draft token is accepted.
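The steps above can be illustrated with a minimal pure-Python sketch of the standard speculative-decoding acceptance rule. This is not the graph that `rejection_sampler_with_residuals()` builds; the function names `speculative_accept` and `_sample`, the list-based inputs, and the clamping of the residual to `max(0, target - draft)` are assumptions made for illustration only.

```python
import random


def _sample(weights, rng):
    """Draw one index with probability proportional to its weight."""
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def speculative_accept(draft_tokens, draft_probs, target_probs, rng):
    """Hypothetical helper: verify draft tokens against the target model.

    draft_tokens: k token ids proposed by the draft model.
    draft_probs:  k distributions (lists over the vocabulary) from the draft model.
    target_probs: k + 1 distributions from the target model; the last one
                  is only used for the bonus token.
    Returns (accepted_tokens, next_token).
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p = target_probs[i][tok]
        q = draft_probs[i][tok]
        # Acceptance ratio min(1, p / q); q == 0 is treated as "accept".
        ratio = min(1.0, p / q) if q > 0.0 else 1.0
        if rng.random() < ratio:
            accepted.append(tok)
            continue
        # First rejection: resample from the residual distribution,
        # assumed here to be max(0, target - draft), renormalized by _sample.
        residual = [
            max(0.0, t - d) for t, d in zip(target_probs[i], draft_probs[i])
        ]
        return accepted, _sample(residual, rng)
    # Every draft token was accepted: draw a bonus token from the target
    # distribution at the position after the last draft token.
    return accepted, _sample(target_probs[len(draft_tokens)], rng)
```

In this sketch a rejection ends verification immediately, so the returned `next_token` is either the residual sample at the first rejected position or the bonus token, never both.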

Parameters:

  • device (DeviceRef)
  • seed (int)
  • debug (bool)

Return type:

Graph

token_sampler()

max.pipelines.lib.sampling.token_sampler(sampling_config, device, return_logits=False)

Parameters:

  • sampling_config (SamplingConfig)
  • device (DeviceRef)
  • return_logits (bool)

Return type:

Graph