Mojo function
mha_cross_gpu_naive
mha_cross_gpu_naive[cache_t: KVCacheT, mask_t: MHAMask, type: DType, q_shape: DimList, //, rank: Int](output: NDBuffer[type, rank, MutableAnyOrigin, shape, strides], q: NDBuffer[type, rank, MutableAnyOrigin, q_shape, strides], q_input_row_offsets: NDBuffer[uint32, 1, MutableAnyOrigin, shape, strides], q_max_seq_len: Int, k: cache_t, v: cache_t, kv_input_row_offsets: NDBuffer[uint32, 1, MutableAnyOrigin, shape, strides], mask_functor: mask_t, scale: SIMD[float32, 1], ctx: DeviceContext)
Naive cross attention on GPU.
Note that this assumes ragged tensor inputs and uses a mask functor.
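For context, a ragged batch is commonly packed by concatenating all sequences along the token axis and describing sequence boundaries with a row-offsets vector (a prefix sum of per-sequence lengths). The NumPy sketch below is illustrative only; the exact semantics of `q_input_row_offsets` and `kv_input_row_offsets` are defined by the kernel and its `KVCacheT` inputs, and the shapes and names here are hypothetical.

```python
import numpy as np

# Hypothetical ragged packing: three sequences of lengths 4, 2, and 7 are
# concatenated along the token axis; row offsets mark where each one starts.
seq_lens = np.array([4, 2, 7], dtype=np.uint32)
q_input_row_offsets = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.uint32)
# -> [0, 4, 6, 13]; sequence b occupies rows offsets[b] : offsets[b + 1]

num_heads, depth = 8, 64
q_packed = np.random.randn(int(q_input_row_offsets[-1]), num_heads, depth)

# Recover the slice belonging to batch element b.
b = 1
q_b = q_packed[q_input_row_offsets[b]:q_input_row_offsets[b + 1]]  # shape (2, 8, 64)
```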
Computes the following steps, where B, S, H, and D denote batch size, sequence length, head count, and depth, respectively:

1. Transpose Q: BSHD -> BHSD.
2. Transpose K: BSHD -> BHSD.
3. Transpose V: BSHD -> BHSD.
4. P = Bmm(Q, K) (batched matrix multiply); P is also called the "score".
5. P = P * scale + mask.
6. P = softmax(P).
7. O = Bmm(P, V).
8. Output = Transpose(O).

Steps (1), (2), and (3) happen while loading the data into shared memory; step (8) happens when writing the output to global memory.
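As a plain NumPy sketch of the same math (not the Mojo kernel or its API), steps (1)-(8) on a dense BSHD batch look roughly like this; the shapes, names, and the expanded mask layout are illustrative assumptions.

```python
import numpy as np

def naive_cross_attention(q, k, v, mask, scale):
    """Reference math for steps (1)-(8) on dense BSHD tensors (illustrative only)."""
    # (1)-(3): BSHD -> BHSD so heads become a batch dimension for the matmuls.
    q_t = q.transpose(0, 2, 1, 3)   # [B, H, S_q, D]
    k_t = k.transpose(0, 2, 1, 3)   # [B, H, S_k, D]
    v_t = v.transpose(0, 2, 1, 3)   # [B, H, S_k, D]

    # (4): score P = Q @ K^T, shape [B, H, S_q, S_k].
    p = q_t @ k_t.transpose(0, 1, 3, 2)

    # (5): apply the scale and an additive mask.
    p = p * scale + mask

    # (6): numerically stable softmax over the key dimension.
    p = p - p.max(axis=-1, keepdims=True)
    p = np.exp(p)
    p = p / p.sum(axis=-1, keepdims=True)

    # (7): O = P @ V, shape [B, H, S_q, D].
    o = p @ v_t

    # (8): BHSD -> BSHD for the output.
    return o.transpose(0, 2, 1, 3)

B, S_q, S_k, H, D = 2, 5, 7, 4, 16
q = np.random.randn(B, S_q, H, D).astype(np.float32)
k = np.random.randn(B, S_k, H, D).astype(np.float32)
v = np.random.randn(B, S_k, H, D).astype(np.float32)
# A BSS mask expanded with a singleton head axis so it broadcasts over heads.
mask = np.zeros((B, 1, S_q, S_k), dtype=np.float32)
out = naive_cross_attention(q, k, v, mask, 1.0 / np.sqrt(D))
print(out.shape)  # (2, 5, 4, 16)
```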
All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.
This kernel also handles the grouped-attention optimization, in which case K and V have shape BShD, where h = H / num_groups.
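A minimal sketch of the grouped-attention head mapping, assuming each group of `num_groups` consecutive query heads shares one KV head (the names and the repeat-based expansion are illustrative, not the kernel's implementation):

```python
import numpy as np

H, num_groups, D = 8, 4, 16   # 8 query heads, 4 query heads per KV head
h = H // num_groups           # 2 KV heads, so K and V are BShD with h = 2

B, S_q, S_k = 1, 3, 5
q = np.random.randn(B, S_q, H, D)
k = np.random.randn(B, S_k, h, D)
v = np.random.randn(B, S_k, h, D)

# One way to reuse the dense reference math: expand K/V so every query head
# sees the KV head of its group (query head i maps to KV head i // num_groups).
k_expanded = np.repeat(k, num_groups, axis=2)   # [B, S_k, H, D]
v_expanded = np.repeat(v, num_groups, axis=2)   # [B, S_k, H, D]
# The dense sketch above then applies unchanged to q, k_expanded, v_expanded.
```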