Mojo function

topk_wrapper

topk_wrapper[T: DType, out_idx_type: DType, is_top_p: Bool, largest: Bool = True, _test_sort: Bool = False](K: Int, num_elements: Int, num_blocks_per_input: Int, in_buffer: UnsafePointer[SIMD[T, 1]], local_topk_vals: UnsafePointer[SIMD[T, 1]], local_topk_idxs: UnsafePointer[SIMD[out_idx_type, 1]], p_threshold: UnsafePointer[SIMD[T, 1]], skip_sort: UnsafePointer[SIMD[bool, 1]])

A copy of Kernels/mojo/nn/topk.mojo:_topk_stage1 with added max_vals and p_threshold arguments, which are used to determine whether sorting is needed for top-p/min-p sampling.

Arguments:

  • K (Int): Number of top elements to select per block.
  • num_elements (Int): Size of the last dimension of the input buffer (vocab size).
  • num_blocks_per_input (Int): Number of blocks used to process the input data.
  • in_buffer (UnsafePointer[Scalar[T]]): Input buffer containing the elements to process.
  • local_topk_vals (UnsafePointer[Scalar[T]]): Output buffer that stores the local top-K values.
  • local_topk_idxs (UnsafePointer[Scalar[out_idx_type]]): Output buffer that stores the indices of the local top-K elements.
  • p_threshold (UnsafePointer[Scalar[T]]): Threshold for top-p sampling if is_top_p is True; otherwise the min-p coefficient.
  • skip_sort (UnsafePointer[Scalar[DType.bool]]): Output buffer that stores whether sorting is needed.
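
The block-local selection described by K, num_blocks_per_input, local_topk_vals and local_topk_idxs can be modelled on the CPU with a few lines of Python. This is only a reference sketch, not the Mojo kernel: the function name `local_topk`, the contiguous chunking, and the lack of padding for short chunks are all assumptions made for illustration, and the real kernel may distribute elements across threads differently.

```python
def local_topk(values, k, num_blocks, largest=True):
    """Reference model (hypothetical helper, not part of the Mojo API).

    Splits `values` into `num_blocks` contiguous chunks and collects the
    top-k values and their global indices from each chunk.
    """
    n = len(values)
    chunk = (n + num_blocks - 1) // num_blocks  # ceiling division
    local_vals, local_idxs = [], []
    for b in range(num_blocks):
        start, stop = b * chunk, min((b + 1) * chunk, n)
        # Rank the indices of this chunk by value; keep the first k.
        ranked = sorted(range(start, stop), key=lambda i: values[i], reverse=largest)
        for i in ranked[:k]:
            local_vals.append(values[i])
            local_idxs.append(i)
    return local_vals, local_idxs


vals, idxs = local_topk([0.1, 0.5, 0.2, 0.05, 0.1, 0.05], k=2, num_blocks=2)
```

A later stage would then merge the `num_blocks * k` candidates into the final top-K; only the first stage is modelled here.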

Parameters:

  • T (DType): DType - The data type of the elements.
  • out_idx_type (DType): DType - The data type of the output indices.
  • is_top_p (Bool): Bool - Whether this is for top-p sampling (True) or min-p sampling (False).
  • largest (Bool): Bool - Whether to select the largest (True) or smallest (False) values.
  • _test_sort (Bool): Bool - An internal test flag that prevents the sort from being skipped during testing.
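
One plausible reading of how skip_sort could be derived in the top-p case is sketched below in Python. The criterion is an assumption, not something this page states: if the largest probability alone already reaches p_threshold, the top-p set is a single token and a full sort adds nothing. The function name `can_skip_sort_top_p` and the `test_sort` argument (mirroring `_test_sort`) are hypothetical, and the min-p case is not modelled.

```python
def can_skip_sort_top_p(probs, p_threshold, test_sort=False):
    """Hypothetical model of the skip_sort decision for top-p sampling.

    Assumes `probs` is a non-empty sequence of probabilities.
    """
    if test_sort:
        # Mirrors the documented intent of _test_sort: never skip when testing.
        return False
    # Assumed criterion: the single largest probability covers the threshold.
    return max(probs) >= p_threshold


peaked = can_skip_sort_top_p([0.9, 0.05, 0.05], p_threshold=0.8)
flat = can_skip_sort_top_p([0.4, 0.3, 0.3], p_threshold=0.8)
```

Under this assumption a sharply peaked distribution skips the sort, while a flatter one still needs it.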