Mojo function
topk_wrapper
topk_wrapper[T: DType, out_idx_type: DType, is_top_p: Bool, largest: Bool = True, _test_sort: Bool = False](K: Int, num_elements: Int, num_blocks_per_input: Int, in_buffer: UnsafePointer[SIMD[T, 1]], local_topk_vals: UnsafePointer[SIMD[T, 1]], local_topk_idxs: UnsafePointer[SIMD[out_idx_type, 1]], p_threshold: UnsafePointer[SIMD[T, 1]], skip_sort: UnsafePointer[SIMD[DType.bool, 1]])
Copy of Kernels/mojo/nn/topk.mojo:_topk_stage1, with the addition of max_vals and p_threshold arguments to determine whether sorting is needed for top-p/min-p sampling.
Arguments:
- K (Int): Number of top elements to select per block.
- num_elements (Int): Size of the last dimension of the input buffer (the vocabulary size).
- num_blocks_per_input (Int): Number of blocks used to process the input data.
- in_buffer (UnsafePointer[Scalar[T]]): Input buffer containing the elements to process.
- local_topk_vals (UnsafePointer[Scalar[T]]): Output buffer that stores the local top-K values.
- local_topk_idxs (UnsafePointer[Scalar[out_idx_type]]): Output buffer that stores the indices of the local top-K elements.
- p_threshold (UnsafePointer[Scalar[T]]): The top-p threshold if is_top_p is True; otherwise, the min-p coefficient.
- skip_sort (UnsafePointer[Scalar[DType.bool]]): Output buffer that stores whether sorting is needed.
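To make the roles of K, num_blocks_per_input, local_topk_vals and local_topk_idxs concrete, here is a plain-Python reference model of a per-block (stage-1) top-K followed by a merge step. It is a sketch under assumptions, not the kernel: the helper names `topk_stage1_reference` and `merge_stage2_reference` are hypothetical, and the contiguous block split, the `[b*K, (b+1)*K)` output slot layout, and the padding of short blocks with ±inf and index -1 are all assumed rather than taken from topk.mojo.

```python
import math
from typing import List, Tuple


def topk_stage1_reference(
    data: List[float],
    K: int,
    num_blocks_per_input: int,
    largest: bool = True,
) -> Tuple[List[float], List[int]]:
    """Hypothetical model of per-block (local) top-K for one input row.

    Assumption: the row is split into num_blocks_per_input contiguous chunks,
    and block b writes its K results to slots [b*K, (b+1)*K).
    """
    num_elements = len(data)
    chunk = math.ceil(num_elements / num_blocks_per_input) if num_elements else 0
    # Assumed padding value for blocks that hold fewer than K elements.
    pad_val = -math.inf if largest else math.inf
    vals: List[float] = []
    idxs: List[int] = []
    for b in range(num_blocks_per_input):
        start = b * chunk
        end = min(start + chunk, num_elements)
        # Indices of this block's elements, best first (max or min per `largest`).
        ranked = sorted(range(start, end), key=lambda i: data[i], reverse=largest)
        best = ranked[:K]
        vals.extend(data[i] for i in best)
        idxs.extend(best)
        # Pad so that every block owns exactly K output slots.
        missing = K - len(best)
        vals.extend([pad_val] * missing)
        idxs.extend([-1] * missing)
    return vals, idxs


def merge_stage2_reference(
    vals: List[float], idxs: List[int], K: int, largest: bool = True
) -> List[Tuple[float, int]]:
    """Hypothetical second stage: pick the global top-K from all local candidates."""
    pairs = sorted(zip(vals, idxs), key=lambda p: p[0], reverse=largest)
    return pairs[:K]


if __name__ == "__main__":
    row = [0.1, 0.7, 0.3, 0.9, 0.2, 0.5]
    local_vals, local_idxs = topk_stage1_reference(row, K=2, num_blocks_per_input=3)
    print(local_idxs)  # [1, 0, 3, 2, 5, 4]
    print(merge_stage2_reference(local_vals, local_idxs, K=2))  # [(0.9, 3), (0.7, 1)]
```

Splitting the work this way means each block only ever ranks its own slice, so the final merge only has to look at num_blocks_per_input × K candidates instead of all num_elements values.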
Parameters:
- T (DType): The data type of the elements.
- out_idx_type (DType): The data type of the output indices.
- is_top_p (Bool): Whether this is for top-p sampling (True) or min-p sampling (False).
- largest (Bool): Whether to find the maximum or minimum values.
- _test_sort (Bool): An internal test flag that prevents the sort from being skipped during testing.
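The page does not spell out how skip_sort is computed, so the following is only a hypothetical criterion that illustrates why a sort can sometimes be avoided: if the filtering step would keep only the argmax token, no ordering of the remaining candidates is needed. The function name `should_skip_sort_reference` is hypothetical, the inputs are assumed to already be probabilities, and the rules below are the standard top-p and min-p definitions rather than the kernel's actual check. It also models _test_sort by never skipping the sort when that flag is set.

```python
import math
from typing import List


def should_skip_sort_reference(
    candidate_probs: List[float],
    p_threshold: float,
    is_top_p: bool,
    _test_sort: bool = False,
) -> bool:
    """Hypothetical rule: skip sorting when only the argmax survives filtering."""
    if _test_sort:
        # Testing: always exercise the sort path.
        return False
    ordered = sorted(candidate_probs, reverse=True)
    p_max = ordered[0]
    if is_top_p:
        # Top-p keeps the smallest set whose cumulative probability reaches p.
        # If the largest probability alone reaches p, that set is just {argmax}.
        return p_max >= p_threshold
    # Min-p keeps tokens with prob >= min_p * p_max (p_threshold is min_p here).
    cutoff = p_threshold * p_max
    second = ordered[1] if len(ordered) > 1 else -math.inf
    # Only the argmax survives if the runner-up falls below the cutoff.
    return second < cutoff


if __name__ == "__main__":
    print(should_skip_sort_reference([0.95, 0.03, 0.02], 0.9, is_top_p=True))   # True
    print(should_skip_sort_reference([0.5, 0.3, 0.2], 0.9, is_top_p=True))      # False
    print(should_skip_sort_reference([0.95, 0.03, 0.02], 0.1, is_top_p=False))  # True
```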