Mojo function
gemv_split_k
gemv_split_k[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, simd_width: UInt, tile_m: UInt, tile_n: UInt, num_threads: UInt, elementwise_lambda_fn: OptionalReg[fn[DType, Int, Int](Index[2], SIMD[$0, $1]) capturing -> None] = None, s_type: DType = get_accum_type[::DType,::DType]()](output: NDBuffer[c_type, 2, MutableAnyOrigin, c_shape], act: NDBuffer[a_type, 2, MutableAnyOrigin, a_shape], weight: NDBuffer[b_type, 2, MutableAnyOrigin, b_shape], m: UInt, n: UInt, k: UInt)
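GEMV with tiling in the K dimension. The B (weight) matrix is assumed to be transposed, i.e. stored row-major as N x K. The kernel therefore multiplies a (1 x K) vector by the transpose of the (N x K) weight matrix, producing a (1 x N) output in which each element is the dot product of the activation vector with one row of the weight matrix.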
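The computation described above can be modeled on the host as a plain reference. The following is a minimal sketch, not the Mojo kernel. It assumes integer inputs given as Python lists, and it introduces a hypothetical tile_k parameter for the size of each K chunk, which the real signature does not expose. Each chunk yields a partial sum, standing in for the work one thread or lane group might do, and the partials are then reduced. The way the actual kernel assigns K chunks to its num_threads threads is not described on this page.

```python
def gemv_split_k_ref(act_row, weight, tile_k):
    """Reference model: out[n] = sum_k act_row[k] * weight[n][k].

    weight is row-major N x K (B already transposed), so every output
    element is the dot product of act_row with one row of weight.
    tile_k is a hypothetical K-chunk size used only for illustration.
    """
    k = len(act_row)
    out = []
    for w_row in weight:
        assert len(w_row) == k, "each weight row must have K elements"
        # Split K into chunks and compute one partial sum per chunk.
        partials = []
        for k0 in range(0, k, tile_k):
            k1 = min(k0 + tile_k, k)
            partials.append(sum(act_row[i] * w_row[i] for i in range(k0, k1)))
        # Reduce the per-chunk partial sums into the final output value.
        out.append(sum(partials))
    return out


act = [1, 2, 3, 4, 5]                      # 1 x K vector, K = 5
weight = [[1, 0, 0, 0, 0],                 # N x K, N = 3
          [1, 1, 1, 1, 1],
          [0, 0, 0, 0, 2]]
print(gemv_split_k_ref(act, weight, tile_k=2))  # [1, 15, 10]
```

Because integer addition is associative, the chunked result matches an unsplit dot product exactly. With floating-point inputs, the order of the partial-sum reduction can change the last bits of the result. This is one reason a wider accumulation type (the s_type parameter) is useful.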
The implementation can also handle M > 1, but it is only optimal for very small M, so it is used for M = 1 only.
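Extending the model to M > 1 makes each output element out[m][n] an independent split-K dot product of row m of act with row n of weight. The sketch below walks the M x N output in tile_m x tile_n blocks. That mapping is an assumption that only loosely mirrors the tile_m and tile_n parameters, and tile_k is again a hypothetical name. In this model, nothing is reused across rows of act, so every extra row repeats a full pass over the weight matrix. The source does not say whether this is why the real kernel is limited to tiny M.

```python
def split_k_dot(a_row, w_row, tile_k):
    # One partial sum per K chunk, then a final reduction of the partials.
    k = len(a_row)
    partials = [
        sum(a_row[i] * w_row[i] for i in range(k0, min(k0 + tile_k, k)))
        for k0 in range(0, k, tile_k)
    ]
    return sum(partials)


def gemv_split_k_ref_m(act, weight, tile_m, tile_n, tile_k):
    """Reference model for M >= 1: out[m][n] = sum_k act[m][k] * weight[n][k].

    The M x N output is visited in tile_m x tile_n blocks (an assumed
    mapping for illustration); each element is computed independently.
    """
    m, n = len(act), len(weight)
    out = [[0] * n for _ in range(m)]
    for m0 in range(0, m, tile_m):
        for n0 in range(0, n, tile_n):
            for i in range(m0, min(m0 + tile_m, m)):
                for j in range(n0, min(n0 + tile_n, n)):
                    out[i][j] = split_k_dot(act[i], weight[j], tile_k)
    return out


act = [[1, 2, 3],
       [4, 5, 6]]            # M x K, M = 2, K = 3
weight = [[1, 0, 1],
          [2, 2, 2],
          [0, 1, 0]]         # N x K, N = 3
print(gemv_split_k_ref_m(act, weight, tile_m=1, tile_n=2, tile_k=2))
# [[4, 12, 2], [10, 30, 5]]
```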