
Mojo function

mla_prefill_plan

```mojo
mla_prefill_plan[cache_t: KVCacheT](
    buffer_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    cache_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    buffer_lengths: LayoutTensor[DType.int32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    k_cache: cache_t,
    buffer_token_size: UInt32,
    ctx: DeviceContext,
)
```

Launches a GPU kernel that plans how to process a batch of sequences with varying lengths through a fixed-size buffer.

Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total token count into chunks of `buffer_token_size`, so the number of processing iterations is the total divided by `buffer_token_size`, rounded up: for example, three sequences contributing 100, 50, and 30 tokens with a 64-token buffer need ceil(180 / 64) = 3 iterations.

For each chunk (one per processing iteration), it calculates:

1. Buffer row offsets for each sequence in the chunk, written to `buffer_row_offsets`
2. Cache offsets for each sequence in the chunk, written to `cache_offsets`
3. The total number of buffer tokens in the iteration, written to `buffer_lengths`

The sketch below walks through the same bookkeeping.
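What follows is a minimal host-side sketch of this planning logic, not the kernel itself. It assumes the per-sequence token counts are already known (the real kernel derives them from `k_cache` and `input_row_offsets`), the names `plan_prefill` and `cache_lengths` are hypothetical, and it prints its results where the kernel writes them into the `LayoutTensor` outputs.

```mojo
from collections import List


fn plan_prefill(cache_lengths: List[Int], buffer_token_size: Int):
    var batch = len(cache_lengths)

    # Prefix sums: the global starting position of each sequence's tokens
    # in the concatenated token stream.
    var starts = List[Int]()
    var total = 0
    for i in range(batch):
        starts.append(total)
        total += cache_lengths[i]

    # Number of iterations = total tokens / buffer size, rounded up.
    var num_iters = (total + buffer_token_size - 1) // buffer_token_size

    for it in range(num_iters):
        var chunk_begin = it * buffer_token_size
        var chunk_end = min(chunk_begin + buffer_token_size, total)
        # 3. Total buffer tokens for this iteration (buffer_lengths).
        print("iteration", it, "buffer_length =", chunk_end - chunk_begin)

        for i in range(batch):
            var seq_begin = starts[i]
            var seq_end = seq_begin + cache_lengths[i]
            # Skip sequences with no tokens in this chunk.
            if seq_end <= chunk_begin or seq_begin >= chunk_end:
                continue
            # 1. Where this sequence's tokens land inside the buffer
            #    (buffer_row_offsets).
            var buffer_row_offset = max(seq_begin - chunk_begin, 0)
            # 2. How far into this sequence's tokens the chunk starts
            #    reading (cache_offsets).
            var cache_offset = max(chunk_begin - seq_begin, 0)
            print("  seq", i, "buffer_row_offset =", buffer_row_offset,
                  "cache_offset =", cache_offset)


def main():
    var cache_lengths = List[Int](100, 50, 30)
    plan_prefill(cache_lengths, 64)
```

With the example inputs above (100, 50, and 30 tokens, 64-token buffer), this prints buffer lengths of 64, 64, and 52 across three iterations; sequence 0 straddles the first chunk boundary, so in iteration 1 it appears with `cache_offset = 64`.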
