Mojo struct
SplitKTileScheduler
@register_passable(trivial)
struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode(0)]
Fields
- `prob_shape` (`IndexList[3]`)
- `block_id_in_cluster` (`IndexList[2]`)
- `blocks_per_problem` (`UInt32`)
- `current_work_linear_idx` (`UInt32`)
- `log_cluster_shape_major` (`UInt32`)
- `log_cluster_shape_minor` (`UInt32`)
- `cluster_blk_major` (`UInt32`)
- `locks_ptr` (`UnsafePointer[Int32]`)
Implemented traits
`AnyType`, `Copyable`, `ImplicitlyCopyable`, `Movable`, `UnknownDestructibility`
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
k_tiles_per_output_tile
alias k_tiles_per_output_tile = ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))
k_tiles_per_split
alias k_tiles_per_split = splits.__rfloordiv__[DType.uint32, 1](SIMD[DType.uint32, 1](ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))))
log_cluster_size
alias log_cluster_size = log2_floor((cluster_shape.__getitem__[2, DType.int64, Int](0) * cluster_shape.__getitem__[2, DType.int64, Int](1)))
Methods
__init__
__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: UnsafePointer[NoneType]) -> Self
get_sm_num
get_problem_blocks_shape
static get_problem_blocks_shape(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]
Returns:
initial_work_tile_info
get_current_work_info
get_worktile_m_n_idx
get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)
assign_work
assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)
get_k_start_and_linear_tile_id
get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32
Returns:
fetch_next_work
requires_reduction
advance_to_next_work
advance_to_next_work(mut self)
is_last_split
get_grid_shape
static get_grid_shape(cluster_shape: IndexList[3], raster_order: RasterOrder = RasterOrder(0)) -> IndexList[3]
Returns:
get_num_tiles
static get_num_tiles(problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int
Returns:
get_required_locks_buffer_size_bytes
static get_required_locks_buffer_size_bytes[accum_type: DType, num_consumer: UInt32](problem_shape: IndexList[3], tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int
Returns:
get_linear_idx_from_m_and_n
output_tile_index
reduction
reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)
wait_eq
static wait_eq(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)
wait_lt
static wait_lt(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)
arrive_set
static arrive_set(lock_ptr: UnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)
store_accumulator
store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
reduce_add
reduce_add[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutableAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutableAnyOrigin, address_space=AddressSpace(5)], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!