Mojo function
mma
mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: SIMD[uint32, 1] = __init__[__mlir_type.!pop.int_literal](1)](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: SIMD[uint32, 1], inst_desc: UMMAInsDescriptor[kind])
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
- c_scale (SIMD[uint32, 1]): Scale factor for the C matrix, 0 or 1.
Args:
- a_desc (MMASmemDescriptor): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (SIMD[uint32, 1]): The address of the C matrix in the tensor memory.
- inst_desc (UMMAInsDescriptor[kind]): The descriptor for the MMA instruction.
mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: SIMD[uint32, 1] = __init__[__mlir_type.!pop.int_literal](1)](a_desc: SIMD[uint32, 1], b_desc: MMASmemDescriptor, c_tmem: SIMD[uint32, 1], inst_desc: UMMAInsDescriptor[kind])
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
- c_scale (SIMD[uint32, 1]): Scale factor for the C matrix, 0 or 1.
Args:
- a_desc (SIMD[uint32, 1]): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (SIMD[uint32, 1]): The address of the C matrix in the tensor memory.
- inst_desc (UMMAInsDescriptor[kind]): The descriptor for the MMA instruction.
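The signatures above mix an infer-only parameter (kind), a positional-only parameter (cta_group), and a keyword-only parameter (c_scale), so a call site may be easier to read than the declaration. The following is a minimal sketch of how the first overload might be called, not a complete kernel. The import path gpu.tcgen05 and the helper names issue_first_mma and issue_next_mma are assumptions made for illustration. The descriptors and the tensor memory address are taken as arguments because this page does not describe how to build them. The comments on c_scale are an interpretation of "scale factor for the C matrix, 0 or 1".

```mojo
# Assumed import path; it may differ between Mojo releases.
from gpu.tcgen05 import mma, MMASmemDescriptor, UMMAInsDescriptor, UMMAKind


# Hypothetical helper: first MMA in a sequence that accumulates into c_tmem.
fn issue_first_mma[
    kind: UMMAKind
](
    a_desc: MMASmemDescriptor,
    b_desc: MMASmemDescriptor,
    c_tmem: UInt32,
    inst_desc: UMMAInsDescriptor[kind],
):
    # kind is inferred from inst_desc and cta_group keeps its default of 1.
    # c_scale is keyword-only. With c_scale=0 the C matrix is scaled by 0,
    # so its previous contents should not contribute to the result.
    mma[c_scale=0](a_desc, b_desc, c_tmem, inst_desc)


# Hypothetical helper: later MMAs that add onto the existing accumulator.
fn issue_next_mma[
    kind: UMMAKind
](
    a_desc: MMASmemDescriptor,
    b_desc: MMASmemDescriptor,
    c_tmem: UInt32,
    inst_desc: UMMAInsDescriptor[kind],
):
    # cta_group is positional-only, so it is passed without a name.
    # c_scale=1 matches the default and keeps the existing C values.
    mma[1, c_scale=1](a_desc, b_desc, c_tmem, inst_desc)
```

Because the instruction is tcgen05.mma, these helpers are only meaningful inside a GPU kernel compiled for hardware that supports it. The second overload is called the same way, with a_desc passed as a SIMD[uint32, 1] value instead of an MMASmemDescriptor.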