Mojo function
mha_sm90_dispatch
mha_sm90_dispatch[
    k_t: MHAOperand,
    v_t: MHAOperand,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    output_type: DType,
    max_prompt_len_t: OptionallyStaticInt,
    partition_t: MHAPartitionScheme,
    //,
    config: MHAConfig,
    group: Int,
    use_score_mod: Bool,
    ragged: Bool,
    _is_cache_length_accurate: Bool
](
    output: UnsafePointer[SIMD[output_type, 1]],
    q: UnsafePointer[SIMD[type, 1]],
    k: k_t,
    v: v_t,
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    valid_length: NDBuffer[uint32, 1, origin, shape, strides],
    max_prompt_len_arg: max_prompt_len_t,
    max_cache_valid_length_arg: Int,
    scale: SIMD[float32, 1],
    kv_input_row_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]],
    batch_size_arg: Int,
    partition: partition_t,
    ctx: DeviceContext
)