Mojo struct
MMATileBuffers
struct MMATileBuffers[tensor_origin: ImmutableOrigin, //, smem_layout: Layout, /, tensor_type: AnyStruct[LayoutTensor[, , , address_space=, element_layout=, layout_int_type=, linear_idx_type=, masked=, alignment=]], thread_layout: Layout, block_rows: Int, warp_rows: Int, stride: Int, num_mmas: Int, mma_type: AnyStruct[AMD_MMA[, , , , , , , , , , , ]]]
Manages memory for a single matrix (A or B) in GEMM computation.
This struct encapsulates all memory handling for a matrix, including:
- Shared memory allocation and tiling
- Register buffer allocation
- Data movement between memory levels (DRAM→local→shared)
Fields
- shared_mem_tile (
LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[in_type, simd_width]]()]
): - shared_mem_warp_tile (
LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, smem_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(smem_layout, AddressSpace(3)), _get_index_type(smem_layout, AddressSpace(3)), False, align_of[SIMD[in_type, simd_width]](), warp_rows, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(smem_layout, AddressSpace(3)), linear_idx_type=_get_index_type(smem_layout, AddressSpace(3)), masked=_tile_is_masked[smem_layout, warp_rows, WK](), alignment=align_of[SIMD[in_type, simd_width]]()]
): - load_reg_tile (
LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width]]()]
): - mma_reg_tile (
StaticTuple[LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width]](), Layout.row_major((num_k_tiles * num_mmas), simd_width).shape[0].value() // num_k_tiles, 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width]]()], num_k_tiles]
): - gmem_iter (
LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, BK]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, BK]()]
): - global_offset (
UInt
): - tensor (
Pointer[tensor_type, tensor_origin]
):
Implemented traits
AnyType
,
UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
iter_type
alias iter_type = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], origin, address_space, element_layout, layout_int_type, linear_idx_type, masked if masked else _tile_is_masked[layout, block_rows, stride](), alignment, block_rows, BK]()[0], origin, address_space=address_space, axis=OptionalReg[Int]({:@stdlib::@builtin::@int::@Int {1}, 0}), layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, block_rows, stride]() if masked if masked else _tile_is_masked[layout, block_rows, stride]() else _tile_is_masked[LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, block_rows, stride]()[0], block_rows, BK]()]
MMARegTileType
alias MMARegTileType = LayoutTensor[in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width]]()]
SharedMemTileType
alias SharedMemTileType = LayoutTensor[in_type, smem_layout, MutableAnyOrigin, address_space=AddressSpace(3), alignment=align_of[SIMD[in_type, simd_width]]()]
Methods
__init__
__init__(out self, ref [tensor_origin] tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, warp_k_idx: Int, block_idx: Int)
Initialize memory regions for a matrix based on warp coordinates.
Args:
- tensor (
LayoutTensor
): The tensor to load from global memory. - warp_idx (
Int
): The warp index within the computation grid (used for MMA operations). - warp_k_idx (
Int
): The warp's index along the K dimension of the computation grid (used for MMA operations). - block_idx (
Int
): The block index within the computation grid (used for warp tiling).
copy_to_shared
copy_to_shared(self)
Copy data from thread-local memory to shared memory.
Uses structured thread cooperation to efficiently transfer data.
load_from_dram
load_from_dram(mut self)
Load data from global memory (DRAM) to thread-local memory.
get_reg_tile
get_reg_tile[k_tile_idx: Int](self) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_k_tiles * num_mmas), simd_width), MutableAnyOrigin, AddressSpace(5), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.row_major((num_k_tiles * num_mmas), simd_width), AddressSpace(5)), _get_index_type(Layout.row_major((num_k_tiles * num_mmas), simd_width), AddressSpace(5)), False, align_of[SIMD[in_type, simd_width]](), Layout.row_major((num_k_tiles * num_mmas), simd_width).shape[0].value() // num_k_tiles, 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5), alignment=align_of[SIMD[in_type, simd_width]]()]
Get a specific K-dimension tile from the register buffer.
Parameters:
- k_tile_idx (
Int
): The K-dimension tile index.
Returns:
LayoutTensor
: A tile view for the specified location in the register buffer.
load_tile_from_shared
load_tile_from_shared[k_tile_idx: Int, is_a: Bool](self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!