Skip to main content
Log in

Mojo struct

MHATileSummary

@register_passable(trivial) struct MHATileSummary

Fields

  • `batch_size` (`SIMD[uint32, 1]`)
  • `max_num_prompt_tiles` (`SIMD[uint32, 1]`)
  • `valid_length` (`NDBuffer[uint32, 1, MutableAnyOrigin]`)
  • `max_seq_len` (`SIMD[uint32, 1]`)

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility

Methods

__init__

__init__(batch_size: SIMD[uint32, 1], max_num_prompt_tiles: SIMD[uint32, 1], valid_length: NDBuffer[uint32, 1, MutableAnyOrigin], max_seq_len: SIMD[uint32, 1]) -> Self

get_current_work_info

get_current_work_info[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], schedule: MHASchedule](self, idx: SIMD[uint32, 1]) -> WorkInfo

get_current_work_info[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], schedule: MHASchedule](self, idx: MHATileState) -> WorkInfo

unsafe_get_current_work_info

unsafe_get_current_work_info[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], schedule: MHASchedule](self, idx: SIMD[uint32, 1]) -> WorkInfo

max_idx

max_idx(self, num_heads: SIMD[uint32, 1]) -> SIMD[uint32, 1]

grid_dim

static grid_dim[num_heads: SIMD[uint32, 1]](max_num_prompt_tiles: SIMD[uint32, 1], batch_size: SIMD[uint32, 1]) -> Tuple[Int, Int, Int]

seq_info

seq_info[ragged: Bool](self, work: WorkInfo) -> SeqInfo

unsafe_seq_info

unsafe_seq_info[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], ragged: Bool, schedule: MHASchedule](self, idx: SIMD[uint32, 1]) -> SeqInfo

unsafe_seq_info[tile_shape: SIMD[uint32, 1], num_heads: SIMD[uint32, 1], ragged: Bool, schedule: MHASchedule](self, state: MHATileState) -> SeqInfo