
Mojo function

flare_mla_prefill_dispatch

flare_mla_prefill_dispatch[
    rank: Int,
    k_t: MHAOperand,
    v_t: MHAOperand,
    k_rope_t: MHAOperand,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    dtype: DType,
    output_type: DType,
    softmax_type: DType,
    q_layout: Layout, //,
    kv_num_heads: Int,
    use_score_mod: Bool = False,
    write_softmax_info: Bool = False,
    use_cascade_attention: Bool = False,
    q_depth: Int = 192,
    cache_depth: Int = 576,
    config: MHAConfig = MHAConfig(dtype, UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), OptionalReg[UInt](None), TensorMapSwizzle(3)),
    _ndbuffer_mha_operand: Bool = False
](
    output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    k: k_t,
    v: v_t,
    k_rope: k_rope_t,
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    max_prompt_len: Int,
    scale: Float32,
    ctx: DeviceContext,
    softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutableAnyOrigin]] = None,
    cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutableAnyOrigin]] = None,
    prev_output: OptionalReg[LayoutTensor[output_type, Layout.row_major[rank](), MutableAnyOrigin]] = None,
    prev_softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutableAnyOrigin]] = None
)
