
Mojo function

flash_attention_ragged

flash_attention_ragged[mask_t: MHAMask, score_mod_t: ScoreModTrait, type: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig = MHAConfig(type, UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), OptionalReg[UInt](None), TensorMapSwizzle(3)), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[type, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: ManagedTensorSlice[IOSpec[True, IO(-1)](), static_spec=StaticTensorSpec.create_unknown[DType.uint32, 1]()], max_prompt_len: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, scale: Float32, ctx: DeviceContext, num_partitions: OptionalReg[Int] = None)
