Mojo function
rms_norm_kv_cache_ragged_paged
rms_norm_kv_cache_ragged_paged[type: DType, num_heads: Int, head_dim: Int, //, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool](kv_collection: PagedKVCacheCollection[type, KVCacheStaticParams(UInt(num_heads), UInt(head_dim)), page_size], gamma: NDBuffer[type, 1, origin, shape, strides], epsilon: SIMD[type, 1], weight_offset: SIMD[type, 1], layer_idx: SIMD[uint32, 1], total_seq_len: SIMD[uint32, 1], input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides], context: DeviceContextPtr)
Performs RMSNorm in place on new entries in the key cache.
This is done by first creating a ragged tensor view of the new token tensor with shape (total_seq_len, num_heads, head_dim). To do this, we need to pass in total_seq_len on the host.
Then, using input_row_offsets, we find the corresponding batch and token index, and use that together with the static head and channel indices to store to/load from the key cache.
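As a rough illustration of the ragged indexing described above, here is a NumPy sketch (not the actual Mojo kernel; the row-offsets values and the find_batch_and_token helper are hypothetical) of how a flat token index and input_row_offsets map to a batch and token index:

```python
import numpy as np

# Hypothetical example: 3 sequences with 2, 4, and 3 new tokens.
# input_row_offsets is the exclusive prefix sum of the sequence lengths,
# so total_seq_len == input_row_offsets[-1] == 9.
input_row_offsets = np.array([0, 2, 6, 9], dtype=np.uint32)

def find_batch_and_token(global_token_idx, row_offsets):
    # The batch is the last row whose starting offset is <= global_token_idx.
    batch_idx = int(np.searchsorted(row_offsets, global_token_idx, side="right")) - 1
    # The token index is the position within that sequence.
    token_idx = int(global_token_idx) - int(row_offsets[batch_idx])
    return batch_idx, token_idx

# Global token 4 falls in batch 1 (offsets 2..5), local position 2.
print(find_batch_and_token(4, input_row_offsets))  # (1, 2)
```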
This uses the input/output lambdas on the RMSNorm kernel.
This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In this case, it operates on a ragged tensor view of the key cache with shape (total_seq_len, num_heads, rms_norm_cols), where rms_norm_cols is the length of gamma and must be less than or equal to head_dim.
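The partial-head behavior can be pictured with the following NumPy sketch; it illustrates the semantics (normalize only the first len(gamma) channels of each head), not the kernel implementation:

```python
import numpy as np

total_seq_len, num_heads, head_dim = 9, 2, 8
rms_norm_cols = 4  # len(gamma); must be <= head_dim

keys = np.random.randn(total_seq_len, num_heads, head_dim).astype(np.float32)
gamma = np.ones(rms_norm_cols, dtype=np.float32)
epsilon = np.float32(1e-6)

# Normalize only the first rms_norm_cols channels of each head, in place.
sub = keys[..., :rms_norm_cols]
rms = np.sqrt(np.mean(sub * sub, axis=-1, keepdims=True) + epsilon)
keys[..., :rms_norm_cols] = (sub / rms) * gamma
# Channels rms_norm_cols..head_dim-1 are left untouched.
```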
weight_offset is a constant offset argument added to the learned weights at runtime. Here, we don't use any offset, so we pass in a zero scalar.
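For intuition, weight_offset enters the computation as an additive shift on gamma; this is a sketch of that behavior, not the exact kernel code:

```python
import numpy as np

gamma = np.array([0.1, -0.2, 0.05, 0.0], dtype=np.float32)

# weight_offset is added to the learned weights before they scale the
# normalized values. Passing 0.0 (as here) uses gamma unchanged; models
# that store zero-centered weights would pass a nonzero offset instead.
weight_offset = np.float32(0.0)
effective_gamma = gamma + weight_offset
```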
multiply_before_cast is a boolean parameter that determines whether to multiply the normalized values by the gamma tensor before casting to the output type. We set it to True by default.
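The effect of multiply_before_cast is in the ordering of the scale and the cast; the following NumPy sketch (with an assumed float16 output type) shows the two variants:

```python
import numpy as np

x_norm = np.random.randn(4).astype(np.float32)  # values already divided by RMS
gamma = np.random.randn(4).astype(np.float32)
out_dtype = np.float16  # assumed lower-precision output type

# multiply_before_cast=True: scale in full precision, then cast once.
out_true = (x_norm * gamma).astype(out_dtype)

# multiply_before_cast=False: cast the normalized values first, then scale.
out_false = x_norm.astype(out_dtype) * gamma.astype(out_dtype)

# The two orderings can differ slightly due to rounding at the cast.
print(np.abs(out_true.astype(np.float32) - out_false.astype(np.float32)).max())
```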