Skip to main content

Mojo struct

MHAPosition

@register_passable(trivial) struct MHAPosition[BM: Int, BN: Int, depth: Int, padded_depth: Int, q_num_heads: Int, group: Int, decoding: Bool]

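Position of the MHA kernel. The value of q_head_stride depends on the decoding parameter: when decoding=False, q_head_stride == q_num_heads; when decoding=True, q_head_stride == 1. Correspondingly, the q_stride alias equals depth * q_head_stride, i.e. depth * q_num_heads when decoding=False and depth when decoding=True.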

Fields

  • q_row (UInt32)
  • q_col (UInt32)
  • q_out_offset (Int)
  • num_keys (UInt32)
  • start_pos (UInt32)
  • seq_len (UInt32)
  • head_idx (UInt32)
  • prompt_offset (UInt32)
  • prompt_idx (UInt32)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

q_output_gmem_layout

alias q_output_gmem_layout = Layout.__init__(IntTuple[__origin_of()](BM, depth), IntTuple[__origin_of()](depth if decoding else (depth * q_num_heads), 1))

q_stride

alias q_stride = depth if decoding else (depth * q_num_heads)

Methods

__init__

__init__(q_row: UInt32, q_col: UInt32, q_out_offset: Int, num_keys: UInt32, start_pos: UInt32, seq_info: SeqInfo) -> Self

__eq__

__eq__(self, other: Self) -> Bool

Returns:

Bool

__ne__

__ne__(self, other: Self) -> Bool

Returns:

Bool

q_head_idx

q_head_idx(self) -> UInt32

Returns:

UInt32

kv_head_idx

kv_head_idx(self) -> UInt32

Returns:

UInt32

write_to

write_to(self, mut writer: T)

q_tile_num_rows

q_tile_num_rows(self) -> UInt32

Returns:

UInt32

q_out_gmem_tensor

q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[Scalar[dtype]]) -> LayoutTensor[dtype, Layout.__init__(IntTuple[__origin_of()](BM, depth), IntTuple[__origin_of()](depth if decoding else (depth * q_num_heads), 1)), MutableAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]

Returns:

LayoutTensor

mask_status

mask_status[mask_t: MHAMask](self, mask: mask_t, kv_tile_start_row: UInt32) -> TileMaskStatus

Returns:

TileMaskStatus

exp_sum_qk_max_ptr

exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: UInt32) -> Tuple[UnsafePointer[Scalar[partition_t.accum_dtype]], UnsafePointer[Scalar[partition_t.accum_dtype]]]

Returns:

Tuple

get_start_and_end_for_partitions

get_start_and_end_for_partitions[partition_t: MHAPartitionScheme, //, BN: Int](self, partition: partition_t) -> Tuple[UInt32, UInt32]

Returns:

Tuple

Was this page helpful?