Mojo function
flash_attention_kv_cache
flash_attention_kv_cache[type: DType, cache_t: KVCacheT, //](q: NDBuffer[type, 4, origin, shape, strides], k: cache_t, v: cache_t, mask: NDBuffer[type, rank, origin, shape, strides], scale: SIMD[float32, 1], output: NDBuffer[type, 4, origin, shape, strides])
flash_attention_kv_cache[type: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: NDBuffer[type, 4, origin, shape, strides], k: cache_t, v: cache_t, mask: mask_t, scale: SIMD[float32, 1], output: NDBuffer[type, 4, origin, shape, strides])
flash_attention_kv_cache[type: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: NDBuffer[type, 3, origin, shape, strides], q_input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], kv_input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], k: cache_t, v: cache_t, mask: mask_t, scale: SIMD[float32, 1], output: NDBuffer[type, 3, origin, shape, strides])
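Entrypoint for ragged tensors. Unlike the two overloads above, which take rank-4 q and output buffers, this overload takes rank-3 q and output buffers together with the rank-1 uint32 buffers q_input_row_offsets and kv_input_row_offsets.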
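As a rough illustration of what "ragged" with row offsets usually means, the sketch below packs variable-length sequences into one flat buffer and records where each sequence starts. It is written in plain Python for illustration only and does not call the Mojo API. The offset layout shown (batch size + 1 entries, starting at 0, with sequence i occupying rows offsets[i] up to offsets[i + 1]) is an assumption based on the common convention; this page does not confirm it. The helper names pack_ragged and sequence_rows are hypothetical.

```python
def pack_ragged(sequences):
    """Concatenate variable-length sequences into one flat list plus row offsets.

    Assumed convention (not confirmed by this page): offsets has
    len(sequences) + 1 entries and starts at 0.
    """
    packed = []
    offsets = [0]
    for seq in sequences:
        packed.extend(seq)
        offsets.append(len(packed))  # end of this sequence, start of the next
    return packed, offsets


def sequence_rows(packed, offsets, i):
    """Return the rows of sequence i from the packed buffer."""
    return packed[offsets[i]:offsets[i + 1]]


# Three sequences of lengths 3, 1, and 2 (one "row" per token).
batch = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]
packed, row_offsets = pack_ragged(batch)

print(row_offsets)                            # [0, 3, 4, 6]
print(sequence_rows(packed, row_offsets, 2))  # ['c0', 'c1']
```

Under this assumed layout, no padding rows are needed: the packed buffer holds only real tokens, which would be consistent with q dropping from rank 4 to rank 3 in the ragged overload.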