Skip to main content

Mojo struct

SplitKTileScheduler

@register_passable(trivial) struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode(0)]

Fields

  • prob_shape (IndexList[3]):
  • block_id_in_cluster (IndexList[2]):
  • blocks_per_problem (UInt32):
  • current_work_linear_idx (UInt32):
  • log_cluster_shape_major (UInt32):
  • log_cluster_shape_minor (UInt32):
  • cluster_blk_major (UInt32):
  • locks_ptr (UnsafePointer[Int32]):

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

k_tiles_per_output_tile

alias k_tiles_per_output_tile = ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))

k_tiles_per_split

alias k_tiles_per_split = splits.__rfloordiv__[DType.uint32, 1](SIMD[DType.uint32, 1](ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))))

log_cluster_size

alias log_cluster_size = log2_floor((cluster_shape.__getitem__[2, DType.int64, Int](0) * cluster_shape.__getitem__[2, DType.int64, Int](1)))

Methods

__init__

__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[NoneType]) -> Self

get_sm_num

get_sm_num(self) -> UInt32

Returns:

UInt32

get_problem_blocks_shape

static get_problem_blocks_shape(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]

Returns:

IndexList[2]

initial_work_tile_info

initial_work_tile_info(mut self) -> WorkInfo

Returns:

WorkInfo

get_current_work_info

get_current_work_info(mut self) -> WorkInfo

Returns:

WorkInfo

get_worktile_m_n_idx

get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)

assign_work

assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)

get_k_start_and_linear_tile_id

get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32

Returns:

UInt32

fetch_next_work

fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo

Returns:

WorkInfo

requires_reduction

requires_reduction(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

advance_to_next_work

advance_to_next_work(mut self)

is_last_split

is_last_split(self, work_tile_info: WorkInfo) -> Bool

Returns:

Bool

get_grid_shape

static get_grid_shape(cluster_shape: IndexList[3], raster_order: RasterOrder = RasterOrder(0)) -> IndexList[3]

Returns:

IndexList[3]

get_num_tiles

static get_num_tiles(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int

Returns:

Int

get_required_locks_buffer_size_bytes

static get_required_locks_buffer_size_bytes[accum_type: DType, num_consumer: UInt32](problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int

Returns:

Int

get_linear_idx_from_m_and_n

get_linear_idx_from_m_and_n(self, tile_m: UInt32, tile_n: UInt32) -> UInt32

Returns:

UInt32

output_tile_index

output_tile_index(self, work_tile_info: WorkInfo) -> UInt32

Returns:

UInt32

reduction

reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)

wait_eq

static wait_eq(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)

wait_lt

static wait_lt(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)

arrive_set

static arrive_set(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)

store_accumulator

store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)

reduce_add

reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)

Was this page helpful?