Mojo function

flash_attention_split_kv

flash_attention_split_kv[
    type: DType, rank: Int, mask_rank: Int, //,
    input_k_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_v_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_k_cache_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_v_cache_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_mask_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0]
](
    q: NDBuffer[type, rank, origin, shape, strides],
    k_shape: Index[rank],
    v_shape: Index[rank],
    k_cache_shape: Index[(rank + 1)],
    v_cache_shape: Index[(rank + 1)],
    mask_shape: Index[mask_rank],
    output: NDBuffer[type, rank, origin, shape, strides],
    scale: SIMD[float32, 1]
)

Variant of flash attention that takes the previous KV cache (input_{k,v}_cache_fn) and the current KV tensors (input_k_fn and input_v_fn) as separate arguments.

This works around the fact that fusion currently can't look through a concat. Instead, the kernel fuses the concat in place: the wrapped input lambdas (input_{k,v}_cache_fn_wrapper) read previous-sequence KV elements from the KV cache and current-sequence KV elements from the tensors k and v.
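
For illustration, the sketch below shows how the capturing input functions might be wired up so that current-sequence K/V elements come from k and v while previous-sequence elements come from the cache. Only the flash_attention_split_kv signature above is taken from this page; the wrapper name split_kv_attention, the buffer origins, the import paths, and the load calls are assumptions made for the example, not part of the documented API.

```mojo
# Sketch only: import paths, buffer origins, and load calls are assumed for
# illustration and may need adjustment for your build of the MAX kernels.
from buffer import NDBuffer
from nn.flash_attention import flash_attention_split_kv  # assumed module path
from utils.index import IndexList


fn split_kv_attention[
    dtype: DType,
    rank: Int,
    mask_rank: Int,
](
    q: NDBuffer[dtype, rank, MutableAnyOrigin],
    k: NDBuffer[dtype, rank, MutableAnyOrigin],
    v: NDBuffer[dtype, rank, MutableAnyOrigin],
    k_cache: NDBuffer[dtype, rank + 1, MutableAnyOrigin],
    v_cache: NDBuffer[dtype, rank + 1, MutableAnyOrigin],
    mask: NDBuffer[dtype, mask_rank, MutableAnyOrigin],
    output: NDBuffer[dtype, rank, MutableAnyOrigin],
    scale: Float32,
):
    # Current-sequence K/V elements are read from the freshly projected tensors.
    @parameter
    fn input_k_fn[width: Int, r: Int](idx: IndexList[r]) -> SIMD[dtype, width]:
        return k.load[width=width](rebind[IndexList[rank]](idx))

    @parameter
    fn input_v_fn[width: Int, r: Int](idx: IndexList[r]) -> SIMD[dtype, width]:
        return v.load[width=width](rebind[IndexList[rank]](idx))

    # Previous-sequence K/V elements are read from the cache, which carries one
    # extra dimension (rank + 1).
    @parameter
    fn input_k_cache_fn[width: Int, r: Int](idx: IndexList[r]) -> SIMD[dtype, width]:
        return k_cache.load[width=width](rebind[IndexList[rank + 1]](idx))

    @parameter
    fn input_v_cache_fn[width: Int, r: Int](idx: IndexList[r]) -> SIMD[dtype, width]:
        return v_cache.load[width=width](rebind[IndexList[rank + 1]](idx))

    @parameter
    fn input_mask_fn[width: Int, r: Int](idx: IndexList[r]) -> SIMD[dtype, width]:
        return mask.load[width=width](rebind[IndexList[mask_rank]](idx))

    # type, rank, and mask_rank are infer-only parameters (before //), so only
    # the five input functions are passed explicitly.
    flash_attention_split_kv[
        input_k_fn,
        input_v_fn,
        input_k_cache_fn,
        input_v_cache_fn,
        input_mask_fn,
    ](
        q,
        k.get_shape(),
        v.get_shape(),
        k_cache.get_shape(),
        v_cache.get_shape(),
        mask.get_shape(),
        output,
        scale,
    )
```

Because the kernel reads the cache and the current tensors through separate lambdas, the caller never has to materialize a concatenated K/V buffer; the split between cached and current elements is resolved per element inside the kernel.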