
Mojo struct

RaggedMHAOperand

@register_passable("trivial") struct RaggedMHAOperand[dtype_: DType, layout: Layout, cache_layout: Layout]

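An implementation of MHAOperand for ragged LayoutTensor arguments to multi-head attention (MHA) kernels, where variable-length sequences are packed back to back instead of padded.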

Fields

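  • buffer (LayoutTensor[dtype_, layout, MutableAnyOrigin]): The packed ragged tensor data.
  • cache_row_offsets (LayoutTensor[DType.uint32, cache_layout, MutableAnyOrigin]): Row offsets into buffer that mark where each sequence in the batch begins.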
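In a ragged layout, the sequences of a batch sit back to back along the token dimension, and a row-offset array records where each one starts. The sketch below assumes that cache_row_offsets holds an exclusive prefix sum with batch_size + 1 entries, so sequence b occupies rows [offsets[b], offsets[b + 1]). The helper make_row_offsets is hypothetical and is not part of this struct's API; it only shows how such offsets could be built from per-sequence lengths.

```mojo
# Hypothetical helper, not part of RaggedMHAOperand: builds the exclusive
# prefix sum of per-sequence lengths that cache_row_offsets is assumed to hold.
fn make_row_offsets(seq_lens: List[UInt32]) -> List[UInt32]:
    # batch_size + 1 entries: a leading 0, then one running total per sequence.
    var offsets = List[UInt32](capacity=len(seq_lens) + 1)
    var running: UInt32 = 0
    offsets.append(running)
    for i in range(len(seq_lens)):
        running += seq_lens[i]
        offsets.append(running)
    return offsets
```

For sequence lengths [3, 5, 2] this yields [0, 3, 8, 10]. Sequence 1 then occupies rows 3 through 7 of the packed data, and the last entry equals the total number of packed tokens.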

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, MHAOperand, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

device_type

alias device_type = RaggedMHAOperand[dtype_, layout, cache_layout]

dtype

alias dtype = dtype_

Methods

__init__

__init__(buffer: LayoutTensor[dtype_, layout, MutableAnyOrigin], cache_row_offsets: LayoutTensor[DType.uint32, cache_layout, MutableAnyOrigin]) -> Self

get_type_name

static get_type_name() -> String

Returns:

String

get_device_type_name

static get_device_type_name() -> String

Returns:

String

block_paged_ptr

block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> UnsafePointer[Scalar[dtype_]]

Returns:

UnsafePointer
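One way to picture what block_paged_ptr resolves to is the flat element offset of (token, head, head-dim lane) inside the packed buffer. The sketch below is not the library's implementation. It assumes that buffer is a contiguous row-major tensor of shape [total_tokens, num_heads, head_dim] and that cache_row_offsets is an exclusive prefix sum over sequence lengths. The function name ragged_element_offset and its num_heads and head_dim arguments are hypothetical, and the tile_size parameter is not modeled.

```mojo
# Hypothetical sketch, not the library implementation. Assumes a contiguous
# row-major buffer of shape [total_tokens, num_heads, head_dim] and an
# exclusive prefix-sum offset array with batch_size + 1 entries.
fn ragged_element_offset(
    cache_row_offsets: List[UInt32],
    num_heads: Int,
    head_dim: Int,
    batch_idx: UInt32,
    start_tok_idx: UInt32,
    head_idx: UInt32,
    head_dim_idx: UInt32 = 0,
) -> Int:
    # Global token row: where this sequence starts plus the token within it.
    var row = Int(cache_row_offsets[Int(batch_idx)] + start_tok_idx)
    # Row-major strides: one token row spans num_heads * head_dim elements.
    return row * num_heads * head_dim + Int(head_idx) * head_dim + Int(head_dim_idx)
```

With offsets [0, 3, 8, 10], num_heads = 2, and head_dim = 4, token 2 of sequence 1, head 1, lane 3 lands on global row 5 and element offset 5 * 8 + 1 * 4 + 3 = 47.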

cache_length

cache_length(self, batch_idx: Int) -> Int

Returns:

Int
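If cache_row_offsets is an exclusive prefix sum with batch_size + 1 entries (an assumption about the layout, not something stated on this page), then a sequence's length is the difference of two adjacent offsets. The helper ragged_cache_length below is a hypothetical sketch of that arithmetic and does not reproduce the library method.

```mojo
# Hypothetical sketch: number of rows owned by sequence batch_idx, assuming
# cache_row_offsets holds batch_size + 1 exclusive prefix-sum entries.
fn ragged_cache_length(cache_row_offsets: List[UInt32], batch_idx: Int) -> Int:
    return Int(cache_row_offsets[batch_idx + 1] - cache_row_offsets[batch_idx])
```

With offsets [0, 3, 8, 10], sequence 1 has length 8 - 3 = 5 and sequence 2 has length 10 - 8 = 2.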

max_context_length

max_context_length(self) -> UInt32

Returns:

UInt32

row_idx

row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32

Returns the row index of token start_tok_idx in sequence batch_idx when the memory is viewed as a matrix.

Returns:

UInt32
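When the packed memory is viewed as a matrix with one row per token, the row of a given token is the start of its sequence plus its position inside that sequence. The sketch below assumes cache_row_offsets is an exclusive prefix sum over sequence lengths. ragged_row_idx is a hypothetical stand-in for illustration, not the library method.

```mojo
# Hypothetical sketch: global row of token start_tok_idx in sequence batch_idx,
# assuming cache_row_offsets[b] is the first row of sequence b.
fn ragged_row_idx(
    cache_row_offsets: List[UInt32], batch_idx: UInt32, start_tok_idx: UInt32
) -> UInt32:
    return cache_row_offsets[Int(batch_idx)] + start_tok_idx
```

With offsets [0, 3, 8, 10], token 1 of sequence 2 maps to row 8 + 1 = 9.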

col_idx

col_idx(self, head_idx: UInt32) -> UInt32

Returns the column index of head head_idx when the memory is viewed as a matrix.

Returns:

UInt32
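A plausible matrix view flattens the head and head-dimension axes into columns, so the matrix has shape [total_tokens, num_heads * head_dim] and each head owns a contiguous block of head_dim columns. Under that assumption, which this page does not confirm, a head's first column is head_idx * head_dim. ragged_col_idx and its head_dim argument are hypothetical.

```mojo
# Hypothetical sketch: first column of head head_idx, assuming a matrix view of
# shape [total_tokens, num_heads * head_dim] with each head's lanes contiguous.
fn ragged_col_idx(head_idx: UInt32, head_dim: UInt32) -> UInt32:
    return head_idx * head_dim
```

With head_dim = 64, head 3 starts at column 192, and its last lane is column 255.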

create_tma_tile

create_tma_tile[tile_m: Int, tile_n: Int, swizzle_mode: TensorMapSwizzle, *, is_k_major: Bool](self, ctx: DeviceContext) -> TMATensorTile[dtype_, tile_layout_k_major[dtype_, tile_m, tile_n, swizzle_mode]() if is_k_major else tile_layout_mn_major[dtype_, tile_n, tile_m, swizzle_mode](), _tma_desc_tile_layout[dtype_, 2, IndexList[2, DType.int64](tile_m, tile_n, Tuple[]()), is_k_major, swizzle_mode](), is_k_major]

Creates a Tensor Memory Accelerator (TMA) tile for efficient asynchronous GPU memory transfers.

Returns:

TMATensorTile
