
Mojo function

mha_cross_gpu_naive

mha_cross_gpu_naive[cache_t: KVCacheT, mask_t: MHAMask, type: DType, q_shape: DimList, //, rank: Int](output: NDBuffer[type, rank, MutableAnyOrigin, shape, strides], q: NDBuffer[type, rank, MutableAnyOrigin, q_shape, strides], q_input_row_offsets: NDBuffer[uint32, 1, MutableAnyOrigin, shape, strides], q_max_seq_len: Int, k: cache_t, v: cache_t, kv_input_row_offsets: NDBuffer[uint32, 1, MutableAnyOrigin, shape, strides], mask_functor: mask_t, scale: SIMD[float32, 1], ctx: DeviceContext)

Naive cross attention on GPU.

Note that this assumes ragged tensor inputs and uses a mask functor.
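For intuition about the ragged layout: all sequences of a batch are packed along a single token dimension, and a row-offsets vector (here `q_input_row_offsets` / `kv_input_row_offsets` in the signature) records where each sequence starts and ends. The following is a minimal NumPy sketch of that indexing; the names, shapes, and values are illustrative assumptions, not the kernel's actual API.

```python
import numpy as np

# Hypothetical ragged batch: 3 sequences of lengths 4, 2, and 5,
# packed along one token axis. Row offsets are cumulative lengths.
num_heads, depth = 8, 64
seq_lens = [4, 2, 5]
row_offsets = np.cumsum([0] + seq_lens).astype(np.uint32)  # [0, 4, 6, 11]

# Packed query tensor: (total_tokens, H, D) instead of (B, S, H, D).
q_packed = np.random.randn(row_offsets[-1], num_heads, depth).astype(np.float32)

# Recover the tokens of batch element b by slicing with its offsets.
for b in range(len(seq_lens)):
    start, end = row_offsets[b], row_offsets[b + 1]
    q_b = q_packed[start:end]          # shape: (seq_lens[b], H, D)
    assert q_b.shape[0] == seq_lens[b]
```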

Computes:

(1) Transpose Q: BSHD -> BHSD.
(2) Transpose K: BSHD -> BHSD.
(3) Transpose V: BSHD -> BHSD.
(4) P = Bmm(Q, K); P is also called the "score".
(5) P = P * scale + mask.
(6) P = softmax(P).
(7) O = Bmm(P, V).
(8) Output = Transpose(O).

B, S, H, D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory. Step (8) happens when writing the output to global memory.
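As a point of reference, the dense (non-ragged) version of steps (1)-(8) can be written as a CPU sketch in NumPy. This is only a numerical illustration of the documented math, not the GPU kernel; it assumes the score in step (4) is the usual Q times K-transposed batched matmul, with the mask already expanded to BHSS.

```python
import numpy as np

def cross_attention_reference(q, k, v, mask, scale):
    """Dense reference for steps (1)-(8): q, k, v are BSHD; mask is BHSS."""
    # (1)-(3) Transpose BSHD -> BHSD.
    qt = q.transpose(0, 2, 1, 3)             # (B, H, S_q, D)
    kt = k.transpose(0, 2, 1, 3)             # (B, H, S_k, D)
    vt = v.transpose(0, 2, 1, 3)             # (B, H, S_k, D)
    # (4) P = Bmm(Q, K): score matrix, with K transposed on its last two axes.
    p = qt @ kt.transpose(0, 1, 3, 2)        # (B, H, S_q, S_k)
    # (5) Scale and add the mask.
    p = p * scale + mask
    # (6) Softmax over the key dimension.
    p = np.exp(p - p.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # (7) O = Bmm(P, V).
    o = p @ vt                               # (B, H, S_q, D)
    # (8) Transpose back to BSHD for the output.
    return o.transpose(0, 2, 1, 3)

B, S_q, S_k, H, D = 2, 3, 5, 4, 16
q = np.random.randn(B, S_q, H, D).astype(np.float32)
k = np.random.randn(B, S_k, H, D).astype(np.float32)
v = np.random.randn(B, S_k, H, D).astype(np.float32)
mask = np.zeros((B, H, S_q, S_k), dtype=np.float32)
out = cross_attention_reference(q, k, v, mask, scale=1.0 / np.sqrt(D))
print(out.shape)  # (2, 3, 4, 16), i.e. BSHD
```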

All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.

This kernel also handles the grouped attention optimization. In that case the shapes of K and V are BShD, where h = H / num_groups. A small indexing sketch follows.
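The sketch below shows one way the H query heads can map onto the smaller set of h KV heads implied by h = H / num_groups (each KV head serving num_groups query heads). The mapping `q_head // num_groups` is an assumption for illustration, not necessarily the kernel's internal layout.

```python
import numpy as np

# Hypothetical grouped-attention shapes: H query heads share h = H // num_groups KV heads.
H, num_groups, D = 8, 4, 16
h = H // num_groups                          # number of KV heads

B, S_q, S_k = 2, 3, 5
q = np.random.randn(B, S_q, H, D).astype(np.float32)   # BSHD
k = np.random.randn(B, S_k, h, D).astype(np.float32)   # BShD
v = np.random.randn(B, S_k, h, D).astype(np.float32)   # BShD

# Each query head attends against KV head q_head // num_groups (assumed mapping).
for q_head in range(H):
    kv_head = q_head // num_groups
    scores = q[:, :, q_head, :] @ k[:, :, kv_head, :].transpose(0, 2, 1)  # (B, S_q, S_k)
    assert scores.shape == (B, S_q, S_k)
```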