Mojo struct

TMemOperand

@register_passable(trivial) struct TMemOperand[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_softmax_threads: Int]
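
An MMA accumulator operand that lives in tensor memory (TMEM), identified by a 32-bit TMEM address. Below is a minimal instantiation sketch; the parameter values are illustrative only, and the import of `TMemOperand` from its defining module is elided:

```mojo
# Illustrative parameterization: one 64x64x16 MMA tile, accumulated
# by 128 softmax threads. These values are examples, not a
# recommended configuration.
alias Operand = TMemOperand[
    DType.float32,  # dtype
    1,              # num_m_mmas
    1,              # num_n_mmas
    64,             # MMA_M
    64,             # MMA_N
    16,             # MMA_K
    128,            # num_softmax_threads
]

# Construct from a 32-bit tensor-memory address. A real kernel would
# obtain this from a TMEM allocation; 0 is a placeholder.
var acc = Operand(UInt32(0))
```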

Fields

  • tmem_addr (UInt32): The base address of this operand in tensor memory (TMEM).

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility, WriteableMMAOperandDescriptor

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

frag_size

alias frag_size = 0 if num_softmax_threads == 0 else (MMA_M * MMA_N) // num_softmax_threads

The number of accumulator elements each participating thread holds (floor division; defined as 0 when num_softmax_threads is 0). For example, with MMA_M = MMA_N = 64 and num_softmax_threads = 128, frag_size = 4096 // 128 = 32.

reg_layout

alias reg_layout = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads]

reg_tile_t

alias reg_tile_t = LayoutTensor[dtype, Self.vec_output_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=Layout.row_major(1, 2)]

The register-resident accumulator tile type: a LayoutTensor over vec_output_layout in the local (register) address space, with 1x2 vectorized elements.

vec_output_layout

alias vec_output_layout = Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple(Self.frag_size // 4, num_n_mmas)), IntTuple(IntTuple(2, Self.frag_size), IntTuple(4, Self.frag_size * num_m_mmas)))

That is: shape ((2, num_m_mmas), (frag_size // 4, num_n_mmas)) with strides ((2, frag_size), (4, frag_size * num_m_mmas)).
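
Worked example with the illustrative parameters from the sketch above (MMA_M = MMA_N = 64, num_softmax_threads = 128, num_m_mmas = num_n_mmas = 1): frag_size = 32, so vec_output_layout has shape ((2, 1), (8, 1)) and strides ((2, 32), (4, 32)).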

Methods

__init__

__init__(tmem_addr: UInt32) -> Self

offset

offset[m_mma: Int, k_mma: Int](self) -> UInt32

Returns:

UInt32
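
A hedged usage sketch for `offset`, assuming `acc` is the `Operand` instance from the sketch above; the compile-time `m_mma` and `k_mma` indices select a per-MMA sub-address derived from `tmem_addr`:

```mojo
# TMEM address of the sub-tile at MMA indices (m_mma=0, k_mma=0);
# both indices are compile-time parameters of `offset`.
var sub_addr: UInt32 = acc.offset[0, 0]()
```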

copy_from

copy_from[src_type: DType, src_layout: Layout, src_element_layout: Layout, //](self, src: LayoutTensor[src_type, src_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=src_element_layout])
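
A hedged sketch of `copy_from`, which writes a register-resident tile into tensor memory. Allocating the source tile via `stack_allocation()` and zero-filling it are assumptions about the calling kernel; `reg_tile_t` (defined above) is used so the source layout matches:

```mojo
# Materialize a register tile with the layout this operand expects,
# then write it into tensor memory.
var frag = Operand.reg_tile_t.stack_allocation()
_ = frag.fill(0)
acc.copy_from(frag)
```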

copy_to

copy_to[dst_type: DType, dst_layout: Layout, dst_element_layout: Layout, //](self, dst: LayoutTensor[dst_type, dst_layout, MutableAnyOrigin, address_space=AddressSpace(5), element_layout=dst_element_layout])
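
The reverse direction, sketched under the same assumptions: `copy_to` reads the accumulator out of tensor memory into a register tile, e.g. so the softmax threads can operate on it:

```mojo
# Read the accumulator back from TMEM into registers.
var out_frag = Operand.reg_tile_t.stack_allocation()
acc.copy_to(out_frag)
```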
