Skip to main content
Log in

Mojo struct

MHAConfig

@register_passable("trivial") struct MHAConfig

Fields

  • type (DType)
  • num_heads (UInt)
  • depth (UInt)
  • num_queries_per_block (UInt)
  • num_keys_per_block (UInt)
  • BK (UInt)
  • WM (UInt)
  • WN (UInt)
  • num_pipeline_stages (UInt)
  • k_group_size (UInt)
  • algorithm (FlashAttentionAlgorithm)

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility, Writable

Methods

__init__

__init__(type: DType, num_heads: UInt, depth: UInt, num_queries_per_block: OptionalReg[UInt] = None, num_keys_per_block: OptionalReg[UInt] = None, BK: OptionalReg[UInt] = None, WM: OptionalReg[UInt] = None, WN: OptionalReg[UInt] = None, num_pipeline_stages: UInt = UInt(2 if ":90" in _accelerator_arch() else 4), k_group_size: UInt = UInt(1), algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm()) -> Self

block_m

block_m(self) -> UInt

block_n

block_n(self) -> UInt

block_k

block_k(self) -> UInt

warp_m

warp_m(self) -> UInt

warp_n

warp_n(self) -> UInt

num_warps_m

num_warps_m(self) -> UInt

num_warps_n

num_warps_n(self) -> UInt

num_consumer_threads

num_consumer_threads(self) -> UInt

num_producer_threads

num_producer_threads[producer_consumer_kernel: Bool = False](self) -> UInt

num_threads

num_threads[producer_consumer_kernel: Bool = False](self) -> UInt

q_smem_size

q_smem_size(self, sm_90: Bool = False) -> UInt

kv_smem_size

kv_smem_size(self, sm_90: Bool = False) -> UInt

k_smem_size

k_smem_size(self, sm_90: Bool = False) -> UInt

v_smem_size

v_smem_size(self, sm_90: Bool = False) -> UInt

p_smem_size

p_smem_size(self) -> UInt

warp_scratch_smem_size

warp_scratch_smem_size(self) -> UInt

shared_mem_bytes

shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> UInt

__str__

__str__(self) -> String

write_to

write_to[W: Writer](self, mut writer: W)