Mojo struct
MHAPosition
@register_passable("trivial")
struct MHAPosition[BM: Int, BN: Int, depth: Int, padded_depth: Int, q_num_heads: Int, group: Int, decoding: Bool]
Position of a tile within the MHA (multi-head attention) kernel. When decoding=False, q_head_stride == q_num_heads; when decoding=True, q_head_stride == 1.
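The stride rule above can be checked with a small sketch. It assumes that q_head_stride is counted in heads, so that the row stride in elements is depth * q_head_stride. This assumption is consistent with the q_stride alias listed under Aliases. The helper functions are hypothetical and are not part of the MHAPosition API.

```mojo
# Hypothetical helpers, not part of the MHAPosition API.
fn q_head_stride(q_num_heads: Int, decoding: Bool) -> Int:
    # Assumed: when decoding, consecutive Q-tile rows are consecutive heads;
    # otherwise consecutive rows are consecutive tokens holding all heads.
    return 1 if decoding else q_num_heads


fn q_stride(depth: Int, q_num_heads: Int, decoding: Bool) -> Int:
    # Each head spans `depth` elements, so the row stride in elements is
    # depth * q_head_stride, matching the q_stride alias.
    return depth * q_head_stride(q_num_heads, decoding)


def main():
    print(q_stride(128, 32, False))  # prints 4096
    print(q_stride(128, 32, True))  # prints 128
```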
Fields
- q_row (UInt32)
- q_col (UInt32)
- q_out_offset (Int)
- num_keys (UInt32)
- start_pos (UInt32)
- seq_len (UInt32)
- head_idx (UInt32)
- prompt_offset (UInt32)
- prompt_idx (UInt32)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
q_output_gmem_layout
alias q_output_gmem_layout = Layout.__init__(IntTuple[__origin_of()](BM, depth), IntTuple[__origin_of()](depth if decoding else (depth * q_num_heads), 1))
q_stride
alias q_stride = depth if decoding else (depth * q_num_heads)
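Together, these two aliases describe a row-major BM x depth view in which consecutive rows are q_stride elements apart. The following minimal sketch computes the element offset that this layout implies. It assumes that q_out_offset is the element offset of the tile's row 0, column 0 from the output base pointer. out_elem_offset is a hypothetical helper, not part of this API.

```mojo
# Hypothetical helper mirroring q_output_gmem_layout:
# shape (BM, depth), strides (q_stride, 1).
fn out_elem_offset(
    q_out_offset: Int,
    row: Int,
    col: Int,
    depth: Int,
    q_num_heads: Int,
    decoding: Bool
) -> Int:
    # Same expression as the q_stride alias.
    var q_stride = depth if decoding else depth * q_num_heads
    # Assumption: q_out_offset locates element (0, 0) of the tile.
    return q_out_offset + row * q_stride + col


def main():
    print(out_elem_offset(0, 2, 5, 128, 32, False))  # prints 8197
    print(out_elem_offset(0, 2, 5, 128, 32, True))  # prints 261
```

The column stride is 1, so each row of depth elements is contiguous, and only the row stride changes between the decoding and non-decoding modes.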
Methods
__init__
__init__(q_row: UInt32, q_col: UInt32, q_out_offset: Int, num_keys: UInt32, start_pos: UInt32, seq_info: SeqInfo) -> Self
__eq__
__ne__
q_head_idx
kv_head_idx
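q_head_idx and kv_head_idx are listed here without descriptions. In grouped-query attention, it is common for each run of group consecutive query heads to share one key/value head. The sketch below assumes that mapping (kv head = q head // group). kv_head_for is a hypothetical helper, and the sketch does not confirm how kv_head_idx is actually implemented.

```mojo
# Hypothetical helper: assumed grouped-query attention mapping in which
# `group` consecutive query heads share one key/value head.
fn kv_head_for(q_head_idx: Int, group: Int) -> Int:
    return q_head_idx // group


def main():
    var q_num_heads = 8
    var group = 4
    # Heads 0-3 map to KV head 0, heads 4-7 to KV head 1.
    for q_head in range(q_num_heads):
        print("q head", q_head, "-> kv head", kv_head_for(q_head, group))
```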
write_to
write_to(self, mut writer: T)
q_tile_num_rows
q_out_gmem_tensor
q_out_gmem_tensor[dtype: DType](self, ptr: UnsafePointer[Scalar[dtype]]) -> LayoutTensor[dtype, Layout.__init__(IntTuple[__origin_of()](BM, depth), IntTuple[__origin_of()](depth if decoding else (depth * q_num_heads), 1)), MutableAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]
Returns:
LayoutTensor
mask_status
mask_status[mask_t: MHAMask](self, mask: mask_t, kv_tile_start_row: UInt32) -> TileMaskStatus
Returns:
TileMaskStatus
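mask_status classifies a key tile that starts at kv_tile_start_row relative to the current query tile, which lets a kernel skip fully masked tiles or avoid element-wise masking. The following sketch assumes a causal mask. It uses hypothetical integer constants in place of the real TileMaskStatus values. It also assumes that q_first and kv_first are the absolute sequence positions of the first query row and the first key in their tiles.

```mojo
# Hypothetical status codes standing in for TileMaskStatus values.
alias FULL_MASK = 0  # every score in the tile is masked out
alias PARTIAL_MASK = 1  # the mask must be applied element-wise
alias NO_MASK = 2  # no score in the tile is masked


# Hypothetical causal-mask tile check: a query at position q may attend
# to a key at position k only if k <= q.
fn causal_tile_status(q_first: Int, kv_first: Int, bm: Int, bn: Int) -> Int:
    var q_last = q_first + bm - 1
    var kv_last = kv_first + bn - 1
    if kv_first > q_last:
        return FULL_MASK  # all keys come after all queries
    if kv_last <= q_first:
        return NO_MASK  # all keys come at or before every query
    return PARTIAL_MASK


def main():
    print(causal_tile_status(0, 64, 64, 64))  # prints 0
    print(causal_tile_status(0, 0, 64, 64))  # prints 1
    print(causal_tile_status(128, 0, 64, 64))  # prints 2
```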
exp_sum_qk_max_ptr
exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: UInt32) -> Tuple[UnsafePointer[Scalar[partition_t.accum_dtype]], UnsafePointer[Scalar[partition_t.accum_dtype]]]
Returns:
Tuple
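The two pointers returned by exp_sum_qk_max_ptr presumably address per-partition softmax statistics for split-KV attention, which would explain why their element type is partition_t.accum_dtype. Suppose each partition stores its maximum QK score (qk_max) and the sum of exp(score - qk_max) over its keys (exp_sum). Then the partitions can be merged by rescaling each partial sum to the global maximum. combine_partitions is a hypothetical helper that illustrates this standard log-sum-exp reduction. It is not the library's code.

```mojo
from math import exp


# Hypothetical helper: merges per-partition softmax statistics, assuming
# qk_max[i] is partition i's maximum score and exp_sum[i] is the sum of
# exp(score - qk_max[i]) over that partition's keys.
fn combine_partitions(qk_max: List[Float64], exp_sum: List[Float64]) -> Float64:
    # Global maximum over all partitions.
    var global_max = qk_max[0]
    for i in range(1, len(qk_max)):
        global_max = max(global_max, qk_max[i])
    # Rescale each partial sum from its local maximum to the global one.
    var total: Float64 = 0.0
    for i in range(len(qk_max)):
        total += exp_sum[i] * exp(qk_max[i] - global_max)
    return total


def main():
    var qk_max = List[Float64](1.0, 1.0)
    var exp_sum = List[Float64](2.0, 3.0)
    print(combine_partitions(qk_max, exp_sum))
```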
get_start_and_end_for_partitions
get_start_and_end_for_partitions[partition_t: MHAPartitionScheme, //, BN: Int](self, partition: partition_t) -> Tuple[UInt32, UInt32]
Returns:
Tuple
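get_start_and_end_for_partitions returns a (start, end) pair of UInt32 values for one partition. The following sketch assumes that the keys are split into BN-sized tiles, that the tiles are divided evenly across partitions, and that each range is half-open. The helpers and their num_partitions and partition_idx parameters are hypothetical and do not come from the source.

```mojo
# Hypothetical helpers: split num_keys into BN-sized tiles, divide the
# tiles evenly across partitions, and return half-open [start, end) bounds.
fn tiles_per_partition(num_keys: Int, num_partitions: Int, bn: Int) -> Int:
    var num_tiles = (num_keys + bn - 1) // bn  # ceil(num_keys / bn)
    return (num_tiles + num_partitions - 1) // num_partitions


fn partition_start(num_keys: Int, num_partitions: Int, partition_idx: Int, bn: Int) -> Int:
    var span = tiles_per_partition(num_keys, num_partitions, bn) * bn
    return min(partition_idx * span, num_keys)


fn partition_end(num_keys: Int, num_partitions: Int, partition_idx: Int, bn: Int) -> Int:
    var span = tiles_per_partition(num_keys, num_partitions, bn) * bn
    return min((partition_idx + 1) * span, num_keys)


def main():
    # 1000 keys, 4 partitions, BN = 64:
    # [0, 256), [256, 512), [512, 768), [768, 1000)
    for p in range(4):
        print(partition_start(1000, 4, p, 64), partition_end(1000, 4, p, 64))
```

With this scheme, every partition except possibly the last starts and ends on a BN boundary. When there are more partitions than tiles, the trailing partitions get empty ranges.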