Mojo function
multistage_mma_q
multistage_mma_q[BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, num_pipeline_stages: Int, transpose_b: Bool, group_size: Int, pack_factor: Int, c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_smem_layout: Layout, b_type: DType, b_layout: Layout, b_smem_layout: Layout, scales_type: DType, scales_layout: Layout, scales_smem_layout: Layout, /, *, swizzle_a: Bool = True, static_num_iters: Dim = Dim(-31337), prefetch_init: Bool = True, continue_prefetch_b: Bool = False, transpose_b_next: Bool = False, b_next_gmem_layout: Layout = Layout(), b_next_smem_layout: Layout = Layout(), next_op_b_iter_alignment: Int = alignof[::DType,__mlir_type.!kgen.target]()](c: LayoutTensor[c_type, c_layout, origin, address_space=AddressSpace(5)], a_iter_arg: LayoutTensorIter[type, a_layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b_iter_arg: LayoutTensorIter[b_type, b_layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], a_smem_iter_arg: LayoutTensorIter[a_type, a_smem_layout, origin, address_space=AddressSpace(3), alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, origin, address_space=AddressSpace(3), alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], scales_smem_iter_arg: LayoutTensorIter[scales_type, scales_smem_layout, origin, address_space=AddressSpace(3), alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], scales_iter_arg: LayoutTensorIter[scales_type, scales_layout, origin, address_space=address_space, 
alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], num_iters: Int, /, *, num_b_rows: OptionalReg[Int] = None)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!