Mojo struct
QueuedTileScheduler
@register_passable(trivial)
struct QueuedTileScheduler[tile_shape: UInt32, num_heads: UInt32, /, decoding: Bool, num_ctas: UInt32 = GPUInfo("H100", Vendor(2), "cuda", "hopper", 9, "sm_90a", 132, 32, 2048, 233472, 65536, 1024).sm_count, schedule: MHASchedule = MHASchedule(0)]
If decoding == False, then num_heads is q_num_heads. If decoding == True, then num_heads is kv_num_heads.
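The rule above can be sketched as a small host-side helper. This is a Python illustration only, not part of the Mojo API: the function name scheduler_num_heads and the head counts in the usage note are hypothetical.

```python
def scheduler_num_heads(q_num_heads: int, kv_num_heads: int, decoding: bool) -> int:
    """Pick the value to pass as the scheduler's `num_heads` parameter.

    Hypothetical helper: it only restates the documented rule.
    """
    # decoding == False: num_heads is q_num_heads.
    # decoding == True:  num_heads is kv_num_heads.
    return kv_num_heads if decoding else q_num_heads
```

For example, with 32 query heads and 8 key/value heads, scheduler_num_heads(32, 8, decoding=False) gives 32, and the same call with decoding=True gives 8.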
Fields
- gidx_ptr (UnsafePointer[UInt32, address_space=AddressSpace(1)]):
Implemented traits
AnyType, Copyable, DevicePassable, ImplicitlyCopyable, MHATileScheduler, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
device_type
alias device_type = QueuedTileScheduler[tile_shape, num_heads, decoding, num_ctas, schedule]
may_advance
alias may_advance = True
mha_schedule
alias mha_schedule = schedule
Methods
__init__
__init__(gidx_ptr: UnsafePointer[UInt32]) -> Self
get_current_work_info
get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo
Returns:
WorkInfo
advance
advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization(1)](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]
The parameter func must return a Bool indicating whether the WorkInfo arg is valid. This function returns whether the current idx corresponds to a valid WorkInfo. Note that if MHASchedulerSynchronization is NONE, then we assume it is only called by thread_idx.x==0.
Returns:
OptionalReg[SeqInfo]
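This page does not describe how advance selects the next tile. The sketch below models one common pattern for a queue-based scheduler, under the assumption that gidx_ptr points to a shared counter that each CTA increments atomically to claim its next tile index. It is a host-side Python simulation, not the Mojo implementation. TileCounter, run_cta, and simulate are hypothetical names, and a lock stands in for the GPU atomic add. Having one claimant per CTA mirrors the note that, without synchronization, advance is assumed to be called only by thread_idx.x==0.

```python
import threading


class TileCounter:
    """Stand-in for the UInt32 counter in global memory (assumed role of gidx_ptr)."""

    def __init__(self) -> None:
        self._next = 0
        self._lock = threading.Lock()  # models an atomic fetch-and-add

    def fetch_add(self, n: int = 1) -> int:
        # Return the old value and advance the counter in one atomic step.
        with self._lock:
            old = self._next
            self._next += n
            return old


def run_cta(counter: TileCounter, total_tiles: int, claimed: list) -> None:
    """One simulated CTA: keep claiming tile indices until the queue is exhausted."""
    while True:
        idx = counter.fetch_add()
        if idx >= total_tiles:
            # No valid work left; a real scheduler would report "no work" here.
            return
        claimed.append(idx)


def simulate(num_ctas: int, total_tiles: int) -> list:
    """Run num_ctas simulated CTAs and return the tile indices each one claimed."""
    counter = TileCounter()
    per_cta = [[] for _ in range(num_ctas)]
    threads = [
        threading.Thread(target=run_cta, args=(counter, total_tiles, per_cta[i]))
        for i in range(num_ctas)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return per_cta
```

Calling simulate(num_ctas=4, total_tiles=37) hands out every index from 0 to 36 exactly once. How the tiles split across CTAs depends on thread timing, which is the point of a dynamic queue: CTAs that finish early pick up more work.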
grid_dim
static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]
Returns:
Tuple[Int, Int, Int]
initial_state
initial_state[ValidLengthType: OptionalPointer, //](self, ptr: UnsafePointer[UInt32, address_space=AddressSpace(3)], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState
Returns:
MHATileState
unsafe_seq_info
unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo
Returns:
SeqInfo
get_type_name
static get_type_name() -> String
Gets the name of the host type (the one implementing this trait).
Returns:
String: The host type's name.
get_device_type_name
static get_device_type_name() -> String
Gets device_type's name.
Returns:
String: The device type's name.