
Mojo function

flash_attention_kv_cache

flash_attention_kv_cache[type: DType, cache_t: KVCacheT, //](q: NDBuffer[type, 4, origin, shape, strides], k: cache_t, v: cache_t, mask: NDBuffer[type, rank, origin, shape, strides], scale: SIMD[float32, 1], output: NDBuffer[type, 4, origin, shape, strides])

flash_attention_kv_cache[type: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: NDBuffer[type, 4, origin, shape, strides], k: cache_t, v: cache_t, mask: mask_t, scale: SIMD[float32, 1], output: NDBuffer[type, 4, origin, shape, strides])
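The two rank-4 overloads differ only in how the mask is supplied. One takes a materialized `NDBuffer`; the other takes a value whose type conforms to `MHAMask`. The pure-Python sketch below is a rough mental model only, not a description of the Mojo kernel. It computes masked scaled dot-product attention for one (batch, head) pair and rests on two assumptions:

- The tensor mask holds additive biases (0 to keep, -inf to drop).
- A plain Python callable can stand in for an `MHAMask` value.

`masked_attention`, `dense_mask`, and `causal_mask` are hypothetical helper names.

```python
import math
from typing import Callable, List

Matrix = List[List[float]]
NEG_INF = float("-inf")


def masked_attention(
    q: Matrix,  # [q_len][head_dim] for one (batch, head) pair
    k: Matrix,  # [kv_len][head_dim]
    v: Matrix,  # [kv_len][head_dim]
    scale: float,
    mask_at: Callable[[int, int], float],  # assumed additive bias for (query i, key j)
) -> Matrix:
    """Reference softmax(scale * q @ k^T + mask) @ v, one query row at a time.

    Assumes every query row keeps at least one unmasked key.
    """
    out: Matrix = []
    for i, qi in enumerate(q):
        scores = [
            scale * sum(a * b for a, b in zip(qi, kj)) + mask_at(i, j)
            for j, kj in enumerate(k)
        ]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * vj[d] for w, vj in zip(weights, v)) / total for d in range(len(v[0]))]
        )
    return out


def dense_mask(mask: Matrix) -> Callable[[int, int], float]:
    """Stand-in for the NDBuffer-mask overload: biases are read from a tensor."""
    return lambda i, j: mask[i][j]


def causal_mask(q_len: int, kv_len: int) -> Callable[[int, int], float]:
    """Stand-in for the mask_t overload: the bias is computed from positions.

    Query i is aligned with the last q_len keys, so it may attend to key j
    when j <= i + (kv_len - q_len).
    """
    offset = kv_len - q_len
    return lambda i, j: 0.0 if j <= i + offset else NEG_INF
```

With a 2x2 causal pattern, both forms give identical results. One plausible reason to offer the `mask_t` overload is that a computed mask avoids materializing a `[q_len, kv_len]` tensor.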

flash_attention_kv_cache[type: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: NDBuffer[type, 3, origin, shape, strides], q_input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], kv_input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], k: cache_t, v: cache_t, mask: mask_t, scale: SIMD[float32, 1], output: NDBuffer[type, 3, origin, shape, strides])

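Entrypoint for ragged tensors. Here `q` and `output` are rank-3 buffers, and `q_input_row_offsets` and `kv_input_row_offsets` locate each sequence's rows in the query and key/value inputs.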
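In a ragged batch, sequences of different lengths are packed back to back instead of being padded. The pure-Python sketch below illustrates only the indexing, not the kernel, and makes these simplifying assumptions:

- It uses the common row-offsets convention: offsets have length batch + 1, and sequence `b` occupies rows `offsets[b]` up to `offsets[b + 1]`.
- It applies a causal mask.
- It replaces the KV cache (`cache_t`) with plain packed lists.
- It drops the head dimension.

`causal_attention` and `ragged_attention` are hypothetical names.

```python
import math
from typing import List

Matrix = List[List[float]]


def causal_attention(q: Matrix, k: Matrix, v: Matrix, scale: float) -> Matrix:
    """softmax(scale * q @ k^T) @ v for one sequence, with a causal mask.

    The len(q) queries are taken to be the last len(q) of the len(k) key
    positions, so this assumes len(k) >= len(q).
    """
    offset = len(k) - len(q)
    out: Matrix = []
    for i, qi in enumerate(q):
        visible = range(i + offset + 1)  # keys this query may attend to
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in visible]
        m = max(scores)  # subtract the row max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * v[j][d] for w, j in zip(weights, visible)) / total for d in range(len(v[0]))]
        )
    return out


def ragged_attention(
    q: Matrix,  # packed query rows of every sequence, back to back
    q_row_offsets: List[int],  # len = batch + 1
    k: Matrix,  # packed key rows (stand-in for the KV cache)
    v: Matrix,  # packed value rows (stand-in for the KV cache)
    kv_row_offsets: List[int],  # len = batch + 1
    scale: float,
) -> Matrix:
    """Run attention per sequence and pack the results in the same row order as q."""
    out: Matrix = []
    for b in range(len(q_row_offsets) - 1):
        q_lo, q_hi = q_row_offsets[b], q_row_offsets[b + 1]
        kv_lo, kv_hi = kv_row_offsets[b], kv_row_offsets[b + 1]
        out.extend(causal_attention(q[q_lo:q_hi], k[kv_lo:kv_hi], v[kv_lo:kv_hi], scale))
    return out
```

For example, with `q_row_offsets = [0, 1, 3]` and `kv_row_offsets = [0, 1, 4]`, the batch holds two sequences. The first has one query row and one key/value row. The second has two query rows and three key/value rows, so one earlier position is visible to both of its queries.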