Mojo struct
MHAPosition
@register_passable(trivial)
struct MHAPosition[BM: Int, BN: Int, depth: Int, num_heads: Int, group: Int, decoding: Bool]
Position of the MHA kernel. When `decoding=False`, `q_head_stride == num_heads`. When `decoding=True`, `q_head_stride == 1`.
Aliases
- `q_stride = depth if decoding else (depth * num_heads)`
- `q_output_gmem_layout`: layout with shape `IntTuple(BM, depth)` and stride `IntTuple(q_stride, 1)`, i.e. `IntTuple(depth if decoding else (depth * num_heads), 1)`.
Fields
- `q_out_offset` (`Int`)
- `num_keys` (`SIMD[uint32, 1]`)
- `start_pos` (`SIMD[uint32, 1]`)
- `seq_len` (`SIMD[uint32, 1]`)
- `head_idx` (`SIMD[uint32, 1]`)
- `prompt_offset` (`SIMD[uint32, 1]`)
- `prompt_idx` (`SIMD[uint32, 1]`)
Implemented traits
`AnyType`, `Copyable`, `ExplicitlyCopyable`, `Movable`, `UnknownDestructibility`
Methods
__init__
__init__(q_out_offset: Int, num_keys: SIMD[uint32, 1], start_pos: SIMD[uint32, 1], seq_info: SeqInfo) -> Self
__eq__
__eq__(self, other: Self) -> Bool
__ne__
__ne__(self, other: Self) -> Bool
q_head_idx
q_head_idx(self) -> SIMD[uint32, 1]
kv_head_idx
kv_head_idx(self) -> SIMD[uint32, 1]
write_to
write_to[W: Writer](self, mut writer: W)
q_tile_num_rows
q_tile_num_rows(self) -> SIMD[uint32, 1]
q_out_gmem_tensor
q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[SIMD[dtype, 1]]) -> LayoutTensor[dtype, q_output_gmem_layout, MutableAnyOrigin, layout_int_type=int32, linear_idx_type=int32, masked=True]
mask_status
mask_status[mask_t: MHAMask](self, mask: mask_t, kv_tile_start_row: SIMD[uint32, 1]) -> TileMaskStatus
exp_sum_qk_max_ptr
exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: SIMD[uint32, 1]) -> Tuple[UnsafePointer[SIMD[partition_t.accum_dtype, 1]], UnsafePointer[SIMD[partition_t.accum_dtype, 1]]]
get_start_and_end_for_partitions
get_start_and_end_for_partitions[partition_t: MHAPartitionScheme, //, BN: Int](self, partition: partition_t) -> Tuple[SIMD[uint32, 1], SIMD[uint32, 1]]
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!