Mojo function
mma
mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: UInt32 = 1](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind])
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
- c_scale (UInt32): Scale factor for the C matrix, 0 or 1.
Args:
- a_desc (MMASmemDescriptor): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (UInt32): The address of the C matrix in tensor memory.
- inst_desc (UMMAInsDescriptor): The descriptor for the MMA instruction.
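As a mental model only, the arithmetic of a single call can be sketched on the host in plain Python. This is not the Mojo API. The helper `mma_reference` and its list-of-lists matrices are hypothetical stand-ins for the operands described by `a_desc`/`b_desc` and the accumulator at `c_tmem`. The sketch also assumes, by analogy with the PTX `tcgen05.mma` enable-input-d operand, that `c_scale = 0` overwrites the accumulator with A·B and `c_scale = 1` adds A·B to it.

```python
# Host-side reference model of one MMA call (NOT the Mojo API).
# Shapes: A is M x K, B is K x N, C is M x N, all plain lists of lists.

def mma_reference(a, b, c, c_scale):
    """Return A @ B + c_scale * C, with c_scale restricted to 0 or 1.

    Hypothetical helper. Assumed mapping: c_scale == 0 ignores the old C
    (overwrite), c_scale == 1 accumulates into it.
    """
    if c_scale not in (0, 1):
        raise ValueError("c_scale must be 0 or 1")
    m, k, n = len(a), len(b), len(b[0])
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            # The old C value is only read when c_scale is 1.
            row.append(acc + (c[i][j] if c_scale else 0))
        out.append(row)
    return out


a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
c = [[100, 100], [100, 100]]
print(mma_reference(a, b, c, c_scale=0))  # [[19, 22], [43, 50]]
print(mma_reference(a, b, c, c_scale=1))  # [[119, 122], [143, 150]]
```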
mma[kind: UMMAKind, //, cta_group: Int = 1, /](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind], c_scale: UInt32)
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
Args:
- a_desc (MMASmemDescriptor): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (UInt32): The address of the C matrix in tensor memory.
- inst_desc (UMMAInsDescriptor): The descriptor for the MMA instruction.
- c_scale (UInt32): Scale factor for the C matrix. Any non-zero value is translated to 1.
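Taking `c_scale` as a runtime argument lets a single call site both start and continue an accumulation. A common pattern, assumed here rather than stated on this page, is to pass `0` on the first K tile so the accumulator is overwritten, then a non-zero value on later tiles. The host-side Python model below illustrates that pattern. `normalize_c_scale`, `mma_step`, and `tiled_matmul` are hypothetical names, and slicing K into tiles stands in for advancing the A and B descriptors.

```python
# Host-side model (NOT the Mojo API) of a K loop driven by a runtime c_scale.

def normalize_c_scale(c_scale):
    """Mirror the documented rule: any non-zero value is treated as 1."""
    return 1 if c_scale != 0 else 0


def mma_step(a, b, c, c_scale):
    """One modeled MMA: A @ B, plus the old C only if the normalized scale is 1."""
    scale = normalize_c_scale(c_scale)
    m, k, n = len(a), len(b), len(b[0])
    return [
        [sum(a[i][p] * b[p][j] for p in range(k)) + (c[i][j] if scale else 0)
         for j in range(n)]
        for i in range(m)
    ]


def tiled_matmul(a, b, k_tile):
    """Compute A @ B by splitting the K dimension into tiles of width k_tile."""
    m, k, n = len(a), len(b), len(b[0])
    c = [[999] * n for _ in range(m)]  # stale contents; ignored on the first step
    for step, k0 in enumerate(range(0, k, k_tile)):
        a_tile = [row[k0:k0 + k_tile] for row in a]
        b_tile = b[k0:k0 + k_tile]
        # Runtime c_scale: 0 on the first tile, then 1, 2, 3, ... (all treated as 1).
        c = mma_step(a_tile, b_tile, c, c_scale=step)
    return c


print(tiled_matmul([[1, 2, 3, 4]], [[1], [2], [3], [4]], 1))  # [[30]]
```

Because any non-zero value counts as 1, the loop index itself can serve as the scale argument.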
mma[kind: UMMAKind, //, cta_group: Int = 1, /](a_desc: UInt32, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind], c_scale: UInt32)
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
Args:
- a_desc (UInt32): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (UInt32): The address of the C matrix in tensor memory.
- inst_desc (UMMAInsDescriptor): The descriptor for the MMA instruction.
- c_scale (UInt32): Scale factor for the C matrix. Any non-zero value is interpreted as 1.
mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: UInt32 = 1](a_desc: UInt32, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind])
Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction.
Parameters:
- kind (UMMAKind): Data type of the matrices.
- cta_group (Int): Number of CTAs used by the MMA.
- c_scale (UInt32): Scale factor for the C matrix, 0 or 1.
Args:
- a_desc (UInt32): The descriptor for the A matrix.
- b_desc (MMASmemDescriptor): The descriptor for the B matrix.
- c_tmem (UInt32): The address of the C matrix in tensor memory.
- inst_desc (UMMAInsDescriptor): The descriptor for the MMA instruction.
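When `c_scale` is a compile-time parameter, as in this overload, it cannot vary with a loop index. One way to structure a K loop under that constraint is to peel the first step with `c_scale = 0` and let the remaining steps use the default of 1. This is an assumption about usage, not something this page prescribes. The host-side Python model below uses a hypothetical helper `mma_fixed`, whose keyword-only `c_scale` loosely mirrors the compile-time parameter. It also assumes, as a stated mapping rather than documented behavior, that 0 means "overwrite" and 1 means "accumulate".

```python
# Host-side model (NOT the Mojo API) of a K loop where each call site fixes c_scale.

def mma_fixed(a, b, c, *, c_scale=1):
    """Return A @ B + c_scale * C; c_scale must be 0 or 1 and defaults to 1.

    Assumed mapping: 0 overwrites the accumulator, 1 accumulates into it.
    """
    if c_scale not in (0, 1):
        raise ValueError("c_scale must be 0 or 1")
    k, n = len(b), len(b[0])
    return [
        [sum(row[p] * b[p][j] for p in range(k)) + (c[i][j] if c_scale else 0)
         for j in range(n)]
        for i, row in enumerate(a)
    ]


def peeled_matmul(a, b, k_tile):
    """Compute A @ B over K tiles, peeling the first tile so it overwrites C."""
    m, k, n = len(a), len(b), len(b[0])
    tiles = [
        ([row[k0:k0 + k_tile] for row in a], b[k0:k0 + k_tile])
        for k0 in range(0, k, k_tile)
    ]
    c = [[0] * n for _ in range(m)]  # placeholder; overwritten by the peeled call
    a0, b0 = tiles[0]
    c = mma_fixed(a0, b0, c, c_scale=0)  # first tile: overwrite
    for a_t, b_t in tiles[1:]:
        c = mma_fixed(a_t, b_t, c)  # later tiles: default c_scale=1 accumulates
    return c


print(peeled_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]], 1))  # [[19, 22], [43, 50]]
```

Peeling keeps each call site's scale fixed. The trade-off is that the first MMA appears twice in the code, once outside the loop and once inside.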