Mojo function

mha_single_batch_pipelined

mha_single_batch_pipelined[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, config: MHAConfig, group: Int = 1, use_score_mod: Bool = False](q_ptr: UnsafePointer[SIMD[q_type, 1]], k: k_t, v: v_t, output_ptr: UnsafePointer[SIMD[output_type, 1]], scale: SIMD[float32, 1], seq_len: Int, max_seq_len: Int, start_pos: SIMD[uint32, 1], num_keys: Int, mask_tensor_col: Int, mask: mask_t, score_mod: score_mod_t, batch_idx: Int)

MHA for token generation, where seq_len = 1 and num_keys >= 1.

The general data layout and steps follow flash attention, with two exceptions:

1. Partition across B (batch), H (heads), and num_keys (TODO). The last is a split-K partition and will need a separate reduction kernel at the end.

2. The first bmm becomes a gemv and the second bmm becomes a gevm. TODO: use more optimized kernels for them.
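
To make the two exceptions concrete, here is a minimal reference sketch in plain Python, not the Mojo kernel itself. It assumes a single head with a single query vector (seq_len = 1) and omits the mask and score_mod arguments. The function names `decode_attention`, `partial_attention`, and `split_k_attention` are hypothetical and exist only for illustration. The split-K part shows one way a reduction over num_keys partitions could work, assuming the usual running-max and running-sum merge from flash attention; the documentation marks that partitioning as TODO, so this is not a description of what the kernel does.

```python
import math


def decode_attention(q, keys, values, scale):
    """Single-query attention: q is [depth], keys and values are [num_keys][depth]."""
    # With seq_len = 1 the first "bmm" (Q @ K^T) is a gemv: one dot product per key.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    # Numerically stable softmax over the keys.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    # The second "bmm" (P @ V) is a gevm: a row vector of weights times the V matrix.
    depth = len(values[0])
    return [sum(e * v[d] for e, v in zip(exps, values)) / denom for d in range(depth)]


def partial_attention(q, keys, values, scale):
    """Attention over one key partition, left unnormalized so partitions can be merged."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)                      # partition-local maximum score
    exps = [math.exp(s - m) for s in scores]
    depth = len(values[0])
    acc = [sum(e * v[d] for e, v in zip(exps, values)) for d in range(depth)]
    return acc, m, sum(exps)             # unnormalized output, max, sum of exponentials


def split_k_attention(q, keys, values, scale, num_partitions):
    """Split the keys into chunks, then combine the partial results in a reduction step."""
    n = len(keys)
    chunk = (n + num_partitions - 1) // num_partitions
    partials = [
        partial_attention(q, keys[i:i + chunk], values[i:i + chunk], scale)
        for i in range(0, n, chunk)
    ]
    # Reduction: rescale every partition to the global maximum before summing.
    m_all = max(m for _, m, _ in partials)
    depth = len(values[0])
    out = [0.0] * depth
    total = 0.0
    for acc, m, l in partials:
        w = math.exp(m - m_all)
        total += w * l
        for d in range(depth):
            out[d] += w * acc[d]
    return [o / total for o in out]
```

Under these assumptions, `split_k_attention` returns the same result as `decode_attention` for any partition count, up to floating-point rounding. The per-partition maximum and sum are what a separate reduction kernel would need as inputs.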