Skip to main content

Mojo function

produce

produce[qkv_type: DType, BM: Int, BN: Int, depth: Int, padded_depth: Int, num_heads: Int, group: Int, PartitionType: MHAPartitionScheme, swizzle_mode: TensorMapSwizzle, q_tma_rows: Int, q_tma_cols: Int, MaxSeqLenType: OptionallyStaticInt, SchedulerType: MHATileScheduler, KVLUTType: MHAOperand, MaskType: MHAMask, KVInputRowOffsetsType: OptionalPointer, ValidLengthType: OptionalPointer, //, *, pipeline_stages: Int, ragged: Bool, _is_cache_length_accurate: Bool](q_tma_op: TMATensorTile[qkv_type, tile_layout_k_major[qkv_type, q_tma_rows, q_tma_cols, swizzle_mode](), _tma_desc_tile_layout[qkv_type, 2, IndexList[2, DType.int64](q_tma_rows, q_tma_cols, Tuple[]()), swizzle_mode=swizzle_mode]()], k_tma_op: TMATensorTile[qkv_type, tile_layout_k_major[qkv_type, BN, padded_depth, swizzle_mode](), _tma_desc_tile_layout[qkv_type, 2, IndexList[2, DType.int64](BN, padded_depth, Tuple[]()), swizzle_mode=swizzle_mode]()], v_tma_op: TMATensorTile[qkv_type, tile_layout_mn_major[qkv_type, padded_depth, BN, swizzle_mode](), _tma_desc_tile_layout[qkv_type, 2, IndexList[2, DType.int64](BN, padded_depth, Tuple[]()), False, swizzle_mode](), False], q_smem: UnsafePointer[Scalar[qkv_type], address_space=AddressSpace(3), alignment=128], kv_smem: UnsafePointer[Scalar[qkv_type], address_space=AddressSpace(3), alignment=128], produced_mbar_kv: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8], consumed_mbar_kv: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8], produced_mbar_q: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8], consumed_mbar_q: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8], kv_lut: KVLUTType, initial_position: MHAPosition[BM, BN, depth, padded_depth, num_heads, group, _is_decoding[MaxSeqLenType]()], partition: PartitionType, scheduler: SchedulerType, mask: MaskType, tile_summary: MHATileSummary[ValidLengthType], tile_state_arg: MHATileState, max_seq_len: MaxSeqLenType, num_keys_arg: UInt32, kv_input_row_offsets: KVInputRowOffsetsType)

Was this page helpful?