Mojo struct
SM100TensorAccumulatorTS
@register_passable(trivial)
struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = 3, transpose_b: Bool = True, cta_group: Int = 1]
Fields
- mbar (UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]):
- phase (UInt32):
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
a_frag_size
alias a_frag_size = (MMA_M * 16) // num_softmax_threads
a_t
alias a_t = TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads]
ab_t
alias ab_t = UMMADescriptorTS[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, MMA_M=BM // ((BM * 2) // num_softmax_threads), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]
accum_t
alias accum_t = accum_type
b_offset
alias b_offset = MMAOperandOffsetFn[operand_type, BN, BK, swizzle_b, transpose_b, MMA_N, 16]()
b_t
alias b_t = MMASmemDescriptor
c_frag_size
alias c_frag_size = (MMA_M * MMA_N) // num_softmax_threads
c_t
alias c_t = TMemAccumulator[accum_type, BM // ((BM * 2) // num_softmax_threads), MMA_N, (BM * 2) // num_softmax_threads, BN // MMA_N, num_softmax_threads]
idesc
alias idesc = UMMAInsDescriptor.create[UMMAKind(2), accum_type, operand_type, operand_type, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
alias MMA_K = 16
num_k_mmas
alias num_k_mmas = BK // 16
num_m_blocks_per_warp
alias num_m_blocks_per_warp = (BM * 2) // num_softmax_threads
num_m_mmas
alias num_m_mmas = BM // MMA_M
num_n_mmas
alias num_n_mmas = BN // MMA_N
operand_t
alias operand_t = operand_type
smem_ptr_t
alias smem_ptr_t = UnsafePointer[Scalar[operand_type], address_space=AddressSpace(3)]
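The derived aliases above are floor divisions of the tile and MMA shape parameters. As a worked example, the values below show what they evaluate to under one hypothetical parameterization; these numbers are purely illustrative and are not defaults of the struct.

```mojo
# Hypothetical parameter values, chosen only to illustrate the derived aliases:
# BM = 128, BN = 64, BK = 64, MMA_M = 64, MMA_N = 64, num_softmax_threads = 128
#
# a_frag_size           = (MMA_M * 16) // num_softmax_threads     = 1024 // 128 = 8
# c_frag_size           = (MMA_M * MMA_N) // num_softmax_threads  = 4096 // 128 = 32
# num_m_blocks_per_warp = (BM * 2) // num_softmax_threads         = 256 // 128  = 2
# num_m_mmas            = BM // MMA_M                             = 128 // 64   = 2
# num_n_mmas            = BN // MMA_N                             = 64 // 64    = 1
# num_k_mmas            = BK // 16                                = 64 // 16    = 4
```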
Methods
__init__
__init__(smem: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]) -> Self
check_constraints
static check_constraints()
init
init(self)
a_mma_descriptor
static a_mma_descriptor(a_tmem: UInt32) -> TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads]
Returns:
b_mma_descriptor
static b_mma_descriptor[dtype_b: DType](p_b: UnsafePointer[Scalar[dtype_b], address_space=AddressSpace(3)]) -> MMASmemDescriptor
Returns:
mma
mma(self, a: TMemOperand[operand_type, (BM * 2) // num_softmax_threads, BN // MMA_N, BM // ((BM * 2) // num_softmax_threads), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[accum_type, BM // ((BM * 2) // num_softmax_threads), MMA_N, (BM * 2) // num_softmax_threads, BN // MMA_N, num_softmax_threads], c_scale: UInt32)
wait
wait(mut self, idx: UInt32)
wait_for_mma
wait_for_mma(mut self)
Wait for the MMA to complete.
wait_for_tmem
wait_for_tmem(mut self)
Wait for the output and A tmem to be ready.
tmem_arrive
tmem_arrive(self)
Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.
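Taken together, the synchronization methods describe a handshake between the softmax (consumer) threads and the MMA side: the consumer signals operand readiness with tmem_arrive, the MMA side builds descriptors and issues mma, and the consumer blocks on wait_for_mma before reading the accumulator. The sketch below illustrates that call ordering only; the instantiation parameters and the names smem_barrier, a_tmem, p_b_smem, and c_tmem are hypothetical placeholders supplied by the surrounding kernel, not part of this API.

```mojo
# Illustrative call ordering only; not a complete kernel.
# AccT's parameter values and all lowercase variables are hypothetical.
alias AccT = SM100TensorAccumulatorTS[
    DType.bfloat16,  # operand_type
    DType.float32,   # accum_type
    64, 64,          # MMA_M, MMA_N
    128, 64, 64,     # BM, BN, BK
    128,             # num_softmax_threads
]

var acc = AccT(smem_barrier)  # smem_barrier: SharedMemBarrier pointer in shared memory
acc.init()

# Softmax/consumer side: signal that the A operand and output tmem are ready.
acc.tmem_arrive()

# MMA side: build the operand descriptors and issue the MMA.
var a = AccT.a_mma_descriptor(a_tmem)    # a_tmem: UInt32 tensor-memory address
var b = AccT.b_mma_descriptor(p_b_smem)  # p_b_smem: shared-memory pointer to B
acc.mma(a, b, c_tmem, c_scale=0)         # c_tmem: TMemAccumulator destination

# Consumer side: block until the MMA has completed before reading c_tmem.
acc.wait_for_mma()
```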