Mojo function
mha_sm100_dispatch
mha_sm100_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, ScoreModType: ScoreModTrait, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig, group: Int, use_score_mod: Bool, ragged: Bool, sink: Bool, _is_cache_length_accurate: Bool](output: UnsafePointer[Scalar[output_type]], q_arg: UnsafePointer[Scalar[q_type]], k: KVType, v: KVType, num_rows_q: Int, mask: MaskType, score_mod: ScoreModType, valid_length: UnsafePointer[UInt32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[NDBuffer[DType.uint32, 1, MutableAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, sink_weights: OptionalReg[NDBuffer[q_type, 1, MutableAnyOrigin]])