Mojo struct
MMATileBuffers
struct MMATileBuffers[mut: Bool, dtype: DType, layout: Layout, origin: Origin[mut], address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, //, _dtype: DType, /, smem_layout: Layout, reg_tile_layout: Layout, swizzle: Swizzle, tensor_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]], thread_layout: Layout, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, stride: Int]
Manages memory for a single matrix (A or B) in a GEMM computation.
This struct encapsulates all memory handling for a matrix, including:
- Shared memory allocation and tiling
- Register buffer allocation
- Data movement between memory levels (DRAM→local→shared)
Fields
- smem_tile (LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): The shared-memory tile for this matrix's block.
- smem_warp_tile (LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): This warp's view into the shared-memory tile, used for MMA operations.
- load_reg_tile (LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]): The thread-local register buffer that stages data loaded from global memory.
- gmem_iter (LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, block_cols]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, block_cols]()]): Iterator over this matrix's tiles in global memory.
- scatter_gather (IteratorScatterGatherAmd[thread_layout]): Helper for AMD scatter/gather transfers from global memory.
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = IteratorScatterGatherAmd[thread_layout].__del__is_trivial if LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, block_cols]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, block_cols]()].__del__is_trivial if LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, 
AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, block_cols]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, block_cols]()].__del__is_trivial if LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), 
layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[_dtype, simd_width_of[_dtype]()]](), warp_rows, warp_cols]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols](), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial if LayoutTensor[_dtype, smem_layout, 
MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial else LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()].__del__is_trivial
iter_type
alias iter_type = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, block_cols]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, block_cols]()]
MMARegTileType
alias MMARegTileType = LayoutTensor[_dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]
SMemTileType
alias SMemTileType = LayoutTensor[_dtype, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[_dtype, simd_width_of[_dtype]()]]()]
Methods
__init__
__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, warp_k_idx: Int, block_idx: Int)
Initialize memory regions for a matrix based on warp coordinates.
Args:

- tensor (LayoutTensor): The tensor to load from global memory.
- warp_idx (Int): The warp index within the computation grid (used for MMA operations).
- warp_k_idx (Int): The warp index along the K dimension of the computation grid.
- block_idx (Int): The block index within the computation grid (used for warp tiling).
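For illustration, a kernel typically constructs one buffer manager per operand. The sketch below is hypothetical: `ATileBuffers` and `BTileBuffers` stand in for fully parameterized `MMATileBuffers` instantiations, and the index variables are assumed to come from the kernel's own warp/block decomposition; only the argument order comes from the signature above.

```mojo
# Hypothetical construction: one MMATileBuffers per operand matrix.
# ATileBuffers/BTileBuffers are placeholder aliases for fully
# parameterized MMATileBuffers types; warp_idx, warp_k_idx, and the
# per-operand block indices come from the kernel's work decomposition.
var a_tiles = ATileBuffers(a_tensor, warp_idx, warp_k_idx, block_idx_m)
var b_tiles = BTileBuffers(b_tensor, warp_idx, warp_k_idx, block_idx_n)
```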
copy_to_smem
copy_to_smem(self)
Copy data from thread-local memory to shared memory.
Uses structured thread cooperation to efficiently transfer data.
load_from_dram
load_from_dram(mut self)
Load data from global memory (DRAM) to thread-local memory.
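Together with `copy_to_smem()`, this forms the two-stage staging pipeline of the GEMM main loop: each iteration first pulls a tile from DRAM into registers, then cooperatively writes it to shared memory. The sketch below is a minimal illustration, assuming the buffers constructed above, a `num_k_tiles` loop bound, and the standard `gpu.sync.barrier()` block synchronization; only the two method calls come from this page.

```mojo
from gpu.sync import barrier

# Hypothetical main-loop body inside a GEMM kernel; a_tiles/b_tiles
# are the MMATileBuffers built earlier, and num_k_tiles is assumed.
for _ in range(num_k_tiles):
    # Stage 1: global memory (DRAM) -> thread-local registers.
    a_tiles.load_from_dram()
    b_tiles.load_from_dram()
    # Stage 2: thread-local registers -> shared memory (cooperative).
    a_tiles.copy_to_smem()
    b_tiles.copy_to_smem()
    barrier()  # shared tiles are now visible to every warp in the block
    # ... issue MMA operations on the smem_warp_tile views ...
    barrier()  # prevent overwriting tiles still being read
```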