Mojo struct
MmaOpAMD
struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, BK: Int, WK: Int]
Fields
- a_reg_tile (
StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
):
- b_reg_tile (
StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
):
- out_reg_tile (
LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()]
):
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), 
simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_n_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial if StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, 
align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial else StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_m_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles].__del__is_trivial
alignment
alias alignment = align_of[SIMD[in_type, simd_width_of[in_type]()]]()
out_reg_layout
alias out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), 4)
OutRegTileType
alias OutRegTileType = LayoutTensor[out_type, Layout.row_major((num_m_mmas * num_n_mmas), 4), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[out_type, simd_width_of[out_type]()]]()]
reg_tile_layout
alias reg_tile_layout[num_mmas: Int] = Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]())
Parameters
- num_mmas (
Int
):
RegTileFragType
alias RegTileFragType[num_mmas: Int] = StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width_of[in_type]()]](), (Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()).shape[0].value[ComptimeOrigin]() // num_k_tiles), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()], num_k_tiles]
Parameters
- num_mmas (
Int
):
RegTileType
alias RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width_of[in_type]()), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width_of[in_type]()]]()]
Parameters
- num_mmas (
Int
):
simd_width
alias simd_width = simd_width_of[in_type]()
swizzle
alias swizzle = Swizzle(3, 0, 1)
tensor_core_mma
alias tensor_core_mma = TensorCoreKGroup[out_type, in_type, shape, k_group_size, transpose_b]()
Methods
__init__
__init__(out self)
smem_tile_layout
mma
mma[k_tile_idx: Int](self)
load_tile_fragment
load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(layout, AddressSpace(3)), _get_index_type(layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(layout, AddressSpace(3)), linear_idx_type=_get_index_type(layout, AddressSpace(3)), masked=_tile_is_masked[layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()], b_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(layout, AddressSpace(3)), _get_index_type(layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(layout, AddressSpace(3)), linear_idx_type=_get_index_type(layout, AddressSpace(3)), masked=_tile_is_masked[layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()])
reset_accumulator
reset_accumulator(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!