Skip to main content
Log in

Mojo struct

MHAPosition

@register_passable("trivial") struct MHAPosition[BM: Int, BN: Int, depth: Int, num_heads: Int, group: Int, decoding: Bool]

Position within the MHA kernel. When `decoding=False`, `q_head_stride == num_heads`; when `decoding=True`, `q_head_stride == 1`.

Aliases

  • q_stride = depth if decoding else (depth * num_heads):
  • q_output_gmem_layout = Layout(IntTuple(BM, depth), IntTuple(depth if decoding else (depth * num_heads), 1)):

Fields

  • q_out_offset (Int):
  • num_keys (SIMD[uint32, 1]):
  • start_pos (SIMD[uint32, 1]):
  • seq_len (SIMD[uint32, 1]):
  • head_idx (SIMD[uint32, 1]):
  • prompt_offset (SIMD[uint32, 1]):
  • prompt_idx (SIMD[uint32, 1]):

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility

Methods

__init__

__init__(q_out_offset: Int, num_keys: SIMD[uint32, 1], start_pos: SIMD[uint32, 1], seq_info: SeqInfo) -> Self

__eq__

__eq__(self, other: Self) -> Bool

__ne__

__ne__(self, other: Self) -> Bool

q_head_idx

q_head_idx(self) -> SIMD[uint32, 1]

kv_head_idx

kv_head_idx(self) -> SIMD[uint32, 1]

write_to

write_to[W: Writer](self, mut writer: W)

q_tile_num_rows

q_tile_num_rows(self) -> SIMD[uint32, 1]

q_out_gmem_tensor

q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[SIMD[dtype, 1]]) -> LayoutTensor[dtype, Layout(IntTuple(BM, depth), IntTuple(depth if decoding else (depth * num_heads), 1)), MutableAnyOrigin, layout_int_type=int32, linear_idx_type=int32, masked=True]

mask_status

mask_status[mask_t: MHAMask](self, mask: mask_t, kv_tile_start_row: SIMD[uint32, 1]) -> TileMaskStatus

exp_sum_qk_max_ptr

exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: SIMD[uint32, 1]) -> Tuple[UnsafePointer[SIMD[partition_t.accum_dtype, 1]], UnsafePointer[SIMD[partition_t.accum_dtype, 1]]]

get_start_and_end_for_partitions

get_start_and_end_for_partitions[partition_t: MHAPartitionScheme, //, BN: Int](self, partition: partition_t) -> Tuple[SIMD[uint32, 1], SIMD[uint32, 1]]