
Mojo function

rms_norm_kv_cache_ragged_continuous_batching

```mojo
rms_norm_kv_cache_ragged_continuous_batching[
    type: DType,
    num_heads: Int,
    head_dim: Int, //,
    target: StringSlice[StaticConstantOrigin],
    multiply_before_cast: Bool,
](
    kv_collection: ContinuousBatchingKVCacheCollection[type, KVCacheStaticParams(UInt(num_heads), UInt(head_dim))],
    gamma: NDBuffer[type, 1, origin, shape, strides],
    epsilon: SIMD[type, 1],
    weight_offset: SIMD[type, 1],
    layer_idx: SIMD[uint32, 1],
    total_seq_len: SIMD[uint32, 1],
    input_row_offsets: NDBuffer[uint32, 1, origin, shape, strides],
    context: DeviceContextPtr,
)
```

Performs RMSNorm in place on new entries in the key cache.

This is done by first creating a ragged tensor view with shape (total_seq_len, num_heads, head_dim) over the new token entries. To do this, total_seq_len must be passed in on the host. Then, using input_row_offsets, we find the corresponding batch and token index, and use those together with the static head and channel indices to load from and store to the key cache. This uses the input/output lambdas on the RMSNorm kernel.
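The batch/token lookup described above can be sketched in NumPy (illustrative only, not the Mojo kernel; the offsets values and the helper name `batch_and_token` are invented for this example):

```python
import numpy as np

# input_row_offsets[b] is the flat index of the first token of sequence b;
# the last entry equals total_seq_len. Here: 3 sequences of lengths 3, 2, 4.
input_row_offsets = np.array([0, 3, 5, 9], dtype=np.uint32)

def batch_and_token(flat_idx):
    # searchsorted(..., side="right") - 1 finds the batch b whose range
    # [offsets[b], offsets[b+1]) contains flat_idx.
    b = int(np.searchsorted(input_row_offsets, flat_idx, side="right")) - 1
    return b, int(flat_idx - input_row_offsets[b])
```

For example, flat token index 4 falls in the second sequence (batch 1) at local token index 1.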

This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In that case, it operates on a ragged tensor view of the key cache with shape (total_seq_len, num_heads, rms_norm_cols), where rms_norm_cols is the length of gamma and must be <= head_dim.
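A minimal NumPy sketch of this partial normalization (not the actual kernel; the function name `rms_norm_partial` is an assumption): only the first len(gamma) channels of each head are normalized and scaled, and the remaining channels are left untouched. It also applies weight_offset as an additive constant on gamma, matching the parameter described below.

```python
import numpy as np

def rms_norm_partial(x, gamma, epsilon=1e-6, weight_offset=0.0):
    # Normalize only the leading rms_norm_cols = len(gamma) channels.
    cols = gamma.shape[-1]  # rms_norm_cols <= head_dim
    head = x[..., :cols]
    rms = np.sqrt(np.mean(head * head, axis=-1, keepdims=True) + epsilon)
    out = x.copy()
    out[..., :cols] = head / rms * (gamma + weight_offset)
    return out  # channels beyond cols are unchanged
```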

weight_offset is a constant offset added to the learned gamma weights at runtime. Here we don't use any offset, so we pass in a zero scalar.

multiply_before_cast is a boolean parameter that determines whether the normalized values are multiplied by the gamma tensor before being cast to the output type. It defaults to True.
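The effect of this parameter is only an ordering of operations; a NumPy sketch (illustrative, not the kernel's implementation; `apply_gamma` is a made-up helper) shows the two orderings when downcasting to a narrower dtype such as float16:

```python
import numpy as np

def apply_gamma(normed, gamma, out_dtype, multiply_before_cast=True):
    if multiply_before_cast:
        # Multiply in the compute dtype, then cast the product down.
        return (normed * gamma).astype(out_dtype)
    # Cast the normalized values down first, then multiply.
    return normed.astype(out_dtype) * gamma.astype(out_dtype)
```

The two orderings can round differently in low precision, which is why the choice is exposed as a parameter.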