Skip to main content

Mojo struct

TransientScheduler

@register_passable(trivial) struct TransientScheduler[tile_shape: UInt32, num_heads: UInt32]

Implemented traits

AnyType, Copyable, Defaultable, ImplicitlyCopyable, MHATileScheduler, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

may_advance

alias may_advance = False

mha_schedule

alias mha_schedule = MHASchedule(0)

Methods

__init__

__init__() -> Self

get_current_work_info

get_current_work_info(self) -> WorkInfo

Returns:

WorkInfo

get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo

Returns:

WorkInfo

advance

advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization(1)](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]

Returns:

OptionalReg[SeqInfo]

grid_dim

static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]

Returns:

Tuple[Int, Int, Int]

initial_state

initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, address_space=AddressSpace(3)], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState

Returns:

MHATileState

unsafe_seq_info

unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo

Returns:

SeqInfo

Was this page helpful?