Mojo function
mha_decoding
mha_decoding[
    q_type: DType,
    k_t: MHAOperand,
    v_t: MHAOperand,
    output_type: DType,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    BM: UInt,
    BN: UInt,
    BK: UInt,
    WM: UInt,
    WN: UInt,
    depth: UInt,
    num_heads: UInt,
    num_threads: UInt,
    num_pipeline_stages: UInt,
    group: UInt = UInt(1),
    use_score_mod: Bool = False,
    ragged: Bool = False,
    is_shared_kv: Bool = False,
    sink: Bool = False,
    _use_valid_length: Bool = False,
    _is_cache_length_accurate: Bool = False,
    decoding_warp_split_k: Bool = False
](
    q_ptr: UnsafePointer[Scalar[q_type]],
    k: k_t,
    v: v_t,
    output_ptr: UnsafePointer[Scalar[output_type]],
    exp_sum_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()]],
    qk_max_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()]],
    scale: Float32,
    batch_size: Int,
    num_partitions: Int,
    max_cache_valid_length: Int,
    valid_length: NDBuffer[DType.uint32, 1, MutableAnyOrigin],
    sink_weights: OptionalReg[NDBuffer[q_type, 1, MutableAnyOrigin]],
    mask: mask_t,
    score_mod: score_mod_t
)
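The `num_partitions`, `exp_sum_ptr`, and `qk_max_ptr` arguments suggest a split-K ("flash-decoding") scheme: the KV cache is divided into partitions, and for each one the kernel records the maximum scaled score (`qk_max`) and the sum of exponentials (`exp_sum`) so that partial outputs can later be merged into the exact softmax result. That reading is inferred from the parameter names and is not stated on this page. The minimal Python sketch below illustrates only the math of such a merge for a single query head, ignoring masking, `score_mod`, sink weights, and grouped heads. The helpers `partial_attention`, `combine_partitions`, `decode_split_k`, and `reference_attention` are hypothetical and are not part of the Mojo API; `scale = 1/sqrt(depth)` is the conventional choice, assumed here.

```python
import math
from typing import List, Tuple

Vector = List[float]
Partial = Tuple[Vector, float, float]  # (locally normalized output, qk_max, exp_sum)


def partial_attention(q: Vector, keys: List[Vector], values: List[Vector], scale: float) -> Partial:
    """Softmax attention of one query over one (non-empty) KV partition."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    qk_max = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - qk_max) for s in scores]
    exp_sum = sum(weights)
    depth = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / exp_sum for d in range(depth)]
    return out, qk_max, exp_sum


def combine_partitions(parts: List[Partial]) -> Vector:
    """Merge per-partition results into the exact softmax over all keys."""
    global_max = max(m for _, m, _ in parts)
    # Rescale each partition's exp_sum from its local max to the shared max.
    rescaled = [l * math.exp(m - global_max) for _, m, l in parts]
    total = sum(rescaled)
    depth = len(parts[0][0])
    return [
        sum(r * out[d] for (out, _, _), r in zip(parts, rescaled)) / total
        for d in range(depth)
    ]


def decode_split_k(q: Vector, keys: List[Vector], values: List[Vector],
                   scale: float, num_partitions: int) -> Vector:
    """Split the KV sequence into contiguous chunks, attend to each, then merge."""
    n = len(keys)
    size = math.ceil(n / num_partitions)
    parts = [
        partial_attention(q, keys[i:i + size], values[i:i + size], scale)
        for i in range(0, n, size)
    ]
    return combine_partitions(parts)


def reference_attention(q: Vector, keys: List[Vector], values: List[Vector], scale: float) -> Vector:
    """Single-pass softmax attention, used to check the split-K result."""
    out, _, _ = partial_attention(q, keys, values, scale)
    return out


# Toy decode step: one query token, a five-token KV cache, depth 2.
q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [0.4, 0.9], [-0.5, 0.7], [0.3, 0.3]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [0.5, -1.0], [1.5, 0.5]]
scale = 1.0 / math.sqrt(len(q))

if __name__ == "__main__":
    print(decode_split_k(q, keys, values, scale, num_partitions=3))
    print(reference_attention(q, keys, values, scale))
```

Because each partition's `exp_sum` is rescaled by `exp(qk_max - global_max)` before merging, the split result matches the single-pass result up to floating-point rounding, whatever the number of partitions.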