Mojo function
flare_mla_decoding
flare_mla_decoding[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig = MHAConfig(dtype, UInt(Int.__init__[IntTuple](q_layout.shape[(rank - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(rank - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), OptionalReg[UInt](None), TensorMapSwizzle(3)), ragged: Bool = False, decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutableAnyOrigin]] = None, num_partitions: OptionalReg[Int] = None)
MLA (multi-head latent attention) decoding kernel, intended to be called only from the optimized compute graph.
The Q input has shape [seq_len, num_heads, depth], and the K input has shape [seq_len, 1, depth]. V is not passed as a separate argument; it is derived by reusing K, with V = K[:, :, :depth_v].
Specifically, for DeepSeek V2/3, depth = 576 and depth_v = 512.
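To make the shape relationships concrete, here is a minimal pure-Python reference sketch of the math for one sequence. It is not the Mojo kernel. The function name `mla_decode_reference` is hypothetical, and small sizes (depth = 6, depth_v = 4) stand in for 576 and 512. K keeps its size-1 head axis to mirror [seq_len, 1, depth], every query head attends to that single K head, and V is taken as the first depth_v columns of K.

```python
import math


def mla_decode_reference(q, k, scale, depth_v):
    """Naive MLA decode attention for one sequence (hypothetical helper).

    q: [num_heads][depth]   one decode token, so q_max_seq_len = 1
    k: [kv_len][1][depth]   shared cache with a single head
    Returns [num_heads][depth_v].
    """
    # V is not stored separately: it is the first depth_v columns of K.
    v = [[row[0][:depth_v]] for row in k]  # [kv_len][1][depth_v]
    out = []
    for qh in q:
        # Every query head scores against the same single K head.
        scores = [scale * sum(a * b for a, b in zip(qh, row[0])) for row in k]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * vrow[0][d] for w, vrow in zip(weights, v)) / total
            for d in range(depth_v)
        ])
    return out


# Two heads, two cache rows, depth = 6, depth_v = 4.
k = [[[1, 2, 3, 4, 5, 6]], [[3, 4, 5, 6, 7, 8]]]
print(mla_decode_reference([[0.0] * 6, [0.0] * 6], k, 1.0, 4))
```

With all-zero queries the weights are uniform, so each head's output is the mean of the two V rows, [2.0, 3.0, 4.0, 5.0]. Note that the output width is depth_v, not depth.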
Because V is derived from K, this kernel computes attention without having to load V separately. It handles only decoding requests, for which q_max_seq_len = 1.
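One way to see why reusing K avoids a second load for V: with a single decode token, each cache row can be read once and used for both the score and the V contribution. The pure-Python sketch below does this with an online (streaming) softmax that keeps a running max, a running normalizer, and an unnormalized accumulator. It is an illustration of the idea under that assumption, not a description of how the Mojo kernel schedules its loads; `mla_decode_single_pass` is a hypothetical name.

```python
import math


def mla_decode_single_pass(q, k, scale, depth_v):
    """Streaming-softmax MLA decode for one sequence (hypothetical helper).

    q: [num_heads][depth]; k: [kv_len][1][depth]; returns [num_heads][depth_v].
    Each K row is read once per head and supplies both the score
    (all depth columns) and the V contribution (first depth_v columns).
    """
    out = []
    for qh in q:
        m = -math.inf           # running max of scores
        l = 0.0                 # running softmax normalizer
        acc = [0.0] * depth_v   # unnormalized output accumulator
        for row in k:
            key = row[0]  # the single shared KV head
            s = scale * sum(a * b for a, b in zip(qh, key))
            m_new = max(m, s)
            alpha = math.exp(m - m_new)  # rescales old state; 0.0 on the first row
            p = math.exp(s - m_new)
            l = l * alpha + p
            # Reuse the same row as V: only its first depth_v entries.
            acc = [a * alpha + p * key[d] for d, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


k = [[[0, 0, 4, 8, 0, 0]], [[math.log(3), 0, 4, 8, 0, 0]]]
print(mla_decode_single_pass([[1, 0, 0, 0, 0, 0]], k, 1.0, 4))
```

In the usage example the two scores are 0 and ln 3, so the softmax weights are 1/4 and 3/4, and the result matches an ordinary two-pass softmax over the same rows.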
This kernel supports batches whose sequences have different valid lengths (that is, lengths before padding). These lengths are passed in the valid_length argument.
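The role of valid_length can be sketched in pure Python as follows: each batch entry attends only to the first valid_length[b] rows of its padded cache, so padding rows never influence the result. This is a hypothetical reference (`mla_decode_batched` is not part of the library), it assumes every valid_length[b] is at least 1, and it ignores mask_functor and score_mod_functor, which the real kernel also applies.

```python
import math


def mla_decode_batched(q, k_padded, valid_length, scale, depth_v):
    """Batched MLA decode with per-entry valid lengths (hypothetical helper).

    q:            [batch][num_heads][depth]
    k_padded:     [batch][max_kv_len][1][depth], padded to a common length
    valid_length: [batch], count of real (non-padding) rows; assumed >= 1
    Returns [batch][num_heads][depth_v].
    """
    outputs = []
    for b, (q_b, k_b) in enumerate(zip(q, k_padded)):
        rows = k_b[: valid_length[b]]  # drop padding rows before attending
        out_b = []
        for qh in q_b:
            scores = [scale * sum(x * y for x, y in zip(qh, r[0])) for r in rows]
            m = max(scores)
            w = [math.exp(s - m) for s in scores]
            t = sum(w)
            # V is the first depth_v columns of each valid K row.
            out_b.append([sum(wi * r[0][d] for wi, r in zip(w, rows)) / t
                          for d in range(depth_v)])
        outputs.append(out_b)
    return outputs


k_padded = [
    [[[1, 1, 1, 1, 0, 0]], [[1000] * 6]],        # entry 0: 1 valid row + padding
    [[[2, 4, 6, 8, 0, 0]], [[4, 6, 8, 10, 0, 0]]],  # entry 1: 2 valid rows
]
q = [[[0, 0, 0, 0, 1, 0]], [[0.0] * 6]]
print(mla_decode_batched(q, k_padded, [1, 2], 1.0, 4))
```

Entry 0's padding row would dominate the softmax if it were included (its score is 1000), but slicing to valid_length[0] = 1 excludes it, so the output is [[1.0, 1.0, 1.0, 1.0]] for entry 0 and [[3.0, 5.0, 7.0, 9.0]] for entry 1.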
flare_mla_decoding[mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig = MHAConfig(dtype, UInt(Int.__init__[IntTuple](q_layout.shape[2])), UInt(Int.__init__[IntTuple](q_layout.shape[3])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), OptionalReg[UInt](None), TensorMapSwizzle(3)), decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, scale: Float32, ctx: DeviceContext, num_partitions: OptionalReg[Int] = None)