Mojo struct
SM100TensorAccumulatorSS
@register_passable(trivial)
struct SM100TensorAccumulatorSS[
    operand_type: DType,
    accum_type: DType,
    MMA_M: Int,
    MMA_N: Int,
    BM: Int,
    BN: Int,
    BK: Int,
    compute_BK: Int,
    num_softmax_threads: Int,
    swizzle_a: TensorMapSwizzle = 3,
    swizzle_b: TensorMapSwizzle = 3,
    *,
    transpose_b: Bool = True,
    cta_group: Int = 1,
    pipeline_stages: Int = 1,
]
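This struct has no prose description on this page. Judging from its name and parameters, it manages a tensor-memory (tmem) accumulator for SM100-class (Blackwell) UMMA operations whose A and B operands both live in shared memory ("SS"), synchronized against a group of softmax threads. As an orientation aid only, the hypothetical parameterization below shows how the parameters fit together; the dtypes and tile sizes are illustrative choices, not values taken from this page.

alias FlashAccum = SM100TensorAccumulatorSS[
    DType.bfloat16,  # operand_type: dtype of the A/B tiles in shared memory (hypothetical)
    DType.float32,   # accum_type: dtype of the tmem accumulator (hypothetical)
    64,              # MMA_M
    128,             # MMA_N
    128,             # BM
    128,             # BN
    64,              # BK
    64,              # compute_BK
    128,             # num_softmax_threads
]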
Fields
- mbar (UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]):
- pipeline (PipelineState[pipeline_stages]):
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
a_offset
alias a_offset = MMAOperandOffsetFn[operand_type, BM, BK, swizzle_a, True, MMA_M, 16]()
a_t
alias a_t = MMASmemDescriptor
ab_t
alias ab_t = UMMADescriptorSS[operand_type]
accum_t
alias accum_t = accum_type
b_offset
alias b_offset = MMAOperandOffsetFn[operand_type, BN, BK, swizzle_b, transpose_b, MMA_N, 16]()
b_t
alias b_t = MMASmemDescriptor
c_t
alias c_t = TMemAccumulator[accum_type, BM // Self.num_m_blocks_per_warp, MMA_N, Self.num_m_blocks_per_warp, BN // MMA_N, num_softmax_threads]
idesc
alias idesc = UMMAInsDescriptor.create[UMMAKind(2), accum_type, operand_type, operand_type, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()
MMA_K
alias MMA_K = 16
num_k_mmas
alias num_k_mmas = compute_BK // MMA_K
num_m_blocks_per_warp
alias num_m_blocks_per_warp = (BM * 2) // num_softmax_threads
num_m_mmas
alias num_m_mmas = BM // MMA_M
num_n_mmas
alias num_n_mmas = BN // MMA_N
operand_t
alias operand_t = operand_type
smem_ptr_t
alias smem_ptr_t = UnsafePointer[Scalar[operand_type], address_space=AddressSpace(3)]
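The derived tile-count aliases above reduce to simple integer divisions once concrete parameters are chosen. As a worked example, with the hypothetical parameters BM = 128, BN = 128, compute_BK = 64, MMA_M = 64, MMA_N = 128, and num_softmax_threads = 128:

num_m_blocks_per_warp = (BM * 2) // num_softmax_threads = 256 // 128 = 2
num_m_mmas = BM // MMA_M = 128 // 64 = 2
num_n_mmas = BN // MMA_N = 128 // 128 = 1
num_k_mmas = compute_BK // MMA_K = 64 // 16 = 4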
Methods
__init__
__init__(smem: UnsafePointer[SharedMemBarrier, address_space=AddressSpace(3), alignment=8]) -> Self
check_constraints
static check_constraints()
init
init(self)
mma_descriptors
static mma_descriptors[dtype_a: DType, dtype_b: DType](p_a: UnsafePointer[Scalar[dtype_a], address_space=AddressSpace(3)], p_b: UnsafePointer[Scalar[dtype_b], address_space=AddressSpace(3)]) -> UMMADescriptorSS[operand_type]
Returns:
UMMADescriptorSS
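A minimal sketch of building the operand descriptors, reusing the hypothetical FlashAccum alias from the sketch near the top of this page; a_tile_ptr and b_tile_ptr stand for AddressSpace(3) pointers to the A and B tiles and are assumed to be produced elsewhere in the kernel.

# Hypothetical shared-memory tile pointers with element type Scalar[operand_type].
var ab_desc = FlashAccum.mma_descriptors(a_tile_ptr, b_tile_ptr)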
mma
mma(mut self, a: MMASmemDescriptor, b: MMASmemDescriptor, c_base: Self.c_t, scale_c: UInt32)
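The scale_c argument is not described on this page. A plausible reading, consistent with how Blackwell tcgen05 MMA instructions treat the accumulator scale flag, is that 0 overwrites the tmem accumulator while a non-zero value accumulates into it; treat the lines below as an assumption, not a documented contract (acc, a_desc, b_desc, and c_base are hypothetical placeholders for an instance and its operands).

acc.mma(a_desc, b_desc, c_base, 0)  # assumed: first K-step overwrites the accumulator
acc.mma(a_desc, b_desc, c_base, 1)  # assumed: subsequent K-steps accumulate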
wait_for_tmem
wait_for_tmem(self)
Wait for the accumulator tmem to finish being read.
wait_for_mma
wait_for_mma(self, c_base: Self.c_t) -> Self.c_t
Wait for the MMA writing the accumulator tmem to complete.
Returns:
TMemAccumulator
tmem_arrive_init
tmem_arrive_init(self)
tmem_arrive
tmem_arrive(mut self)
Indicate that the accumulator is ready to be updated.
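Pulling the methods together, the sketch below shows one plausible call ordering based only on the method names and descriptions on this page. FlashAccum is the hypothetical alias from the first sketch; smem_barrier_ptr, a_desc, b_desc, c_base, and scale_c are placeholders, and the surrounding kernel scaffolding (shared-memory allocation, thread-role selection, loop bounds) is omitted.

# Setup, assumed to run once before the main loop.
var acc = FlashAccum(smem_barrier_ptr)
acc.init()
acc.tmem_arrive_init()  # assumed: pre-arms the "tmem is free" signal for the first iteration

# MMA side, per tile: wait until the previous accumulator has been consumed,
# then issue the next UMMA into tensor memory.
acc.wait_for_tmem()
acc.mma(a_desc, b_desc, c_base, scale_c)

# Softmax side, per tile: wait for the MMA to complete, consume the result,
# then signal that the accumulator may be updated again.
var c = acc.wait_for_mma(c_base)
# ... read the accumulator through c ...
acc.tmem_arrive()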