Mojo struct

WGMMADescriptor

@register_passable(trivial) struct WGMMADescriptor[dtype: DType]

Descriptor for shared memory operands used in warp group matrix multiply operations.

This struct packs the shared memory layout and access pattern of a warp group matrix multiply operand into a single 64-bit value. The descriptor contains the following bit fields:

  • Start address (14 bits): Base address in shared memory.
  • Leading byte offset (14 bits): Leading dimension stride in bytes.
  • Stride byte offset (14 bits): Stride dimension offset in bytes.
  • Base offset (3 bits): Additional offset.
  • Swizzle mode (2 bits): Memory access pattern.

The bit layout is:

+----------+--------+------------+--------+---------+--------+--------+---------+---------+
| 0-13     | 14-15  | 16-29      | 30-31  | 32-45   | 46-48  | 49-51  | 52-61   | 62-63   |
+----------+--------+------------+--------+---------+--------+--------+---------+---------+
| 14 bits  | 2 bits | 14 bits    | 2 bits | 14 bits | 3 bits | 3 bits | 10 bits | 2 bits  |
+----------+--------+------------+--------+---------+--------+--------+---------+---------+
| BaseAddr | 0      | LeadingDim | 0      | Stride  | 0      | Offset | 0       | Swizzle |
+----------+--------+------------+--------+---------+--------+--------+---------+---------+

See: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
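The field positions listed above can be illustrated with a small Python model. This is only a sketch of the bit arithmetic, not the Mojo API: `FIELDS` and `pack_descriptor` are hypothetical names, and the function expects values that are already in their encoded field form.

```python
# Field positions taken from the bit layout above: name -> (shift, width).
FIELDS = {
    "start_address": (0, 14),
    "leading_byte_offset": (16, 14),
    "stride_byte_offset": (32, 14),
    "base_offset": (49, 3),
    "swizzle_mode": (62, 2),
}


def pack_descriptor(**values: int) -> int:
    """Pack already-encoded field values into one 64-bit integer."""
    desc = 0
    for name, value in values.items():
        shift, width = FIELDS[name]
        # Reject values that would spill into the reserved (zero) bits.
        if not 0 <= value < (1 << width):
            raise ValueError(f"{name}={value} does not fit in {width} bits")
        desc |= value << shift
    return desc


desc = pack_descriptor(
    start_address=0x40,
    leading_byte_offset=1,
    stride_byte_offset=64,
    swizzle_mode=1,
)
print(hex(desc))
```

Fields that are not passed stay zero, which matches the reserved bits in the layout.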

Parameters

  • dtype (DType): The data type of the shared memory operand. This affects memory alignment and access patterns for the descriptor.

Fields

  • desc (SIMD[int64, 1]): The 64-bit descriptor value that encodes shared memory layout information. This field stores the complete descriptor with all bit fields packed into a single 64-bit integer:

    • Bits 0-13: Base address in shared memory (14 bits)
    • Bits 16-29: Leading dimension stride in bytes (14 bits)
    • Bits 32-45: Stride dimension offset in bytes (14 bits)
    • Bits 49-51: Base offset (3 bits)
    • Bits 62-63: Swizzle mode for memory access pattern (2 bits)

    Warp group matrix multiply instructions on the NVIDIA Hopper architecture read this descriptor to locate the shared memory operand and to interpret its layout and access pattern.
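To make the packed representation concrete, the following Python sketch splits a 64-bit value back into the fields listed above. `unpack_descriptor` is a hypothetical helper for illustration and is not part of the Mojo struct. Because `desc` is a signed 64-bit integer, a swizzle mode with bit 63 set yields a negative value, so the sketch first reinterprets the input as unsigned.

```python
# Field positions from the list above: name -> (shift, width).
FIELDS = {
    "start_address": (0, 14),
    "leading_byte_offset": (16, 14),
    "stride_byte_offset": (32, 14),
    "base_offset": (49, 3),
    "swizzle_mode": (62, 2),
}


def unpack_descriptor(desc: int) -> dict:
    """Split a 64-bit descriptor value into its named bit fields."""
    # View a possibly negative Int64 value as an unsigned 64-bit pattern.
    desc &= (1 << 64) - 1
    return {
        name: (desc >> shift) & ((1 << width) - 1)
        for name, (shift, width) in FIELDS.items()
    }


fields = unpack_descriptor(0x4000004000010040)
print(fields)
```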

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

@implicit __init__(val: SIMD[int64, 1]) -> Self

Initialize descriptor with raw 64-bit value.

This constructor allows creating a descriptor directly from a 64-bit integer that already contains the properly formatted bit fields for the descriptor.

The @implicit decorator enables implicit conversion from an Int64 value (SIMD[int64, 1]) to WGMMADescriptor.

Args:

  • val (SIMD[int64, 1]): A 64-bit integer containing the complete descriptor bit layout.

__add__

__add__(self, offset: Int) -> Self

Add offset to descriptor's base address.

Args:

  • offset (Int): Byte offset to add to base address.

Returns:

New descriptor with updated base address.

__iadd__

__iadd__(mut self, offset: Int)

Add offset to descriptor's base address in-place.

Args:

  • offset (Int): Byte offset to add to base address.
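One way to picture what `__add__` and `__iadd__` do to the descriptor is the Python sketch below. It rests on two assumptions that this page does not confirm: that the byte offset is converted with the encoding given in the PTX matrix-descriptor documentation, `(x & 0x3FFFF) >> 4`, and that only the start-address field (bits 0-13) changes. `encode_bytes` and `add_offset` are hypothetical names, and the overflow check is an addition of this sketch.

```python
ADDR_MASK = (1 << 14) - 1  # start address occupies bits 0-13


def encode_bytes(x: int) -> int:
    # Assumption (PTX matrix-descriptor encoding, not stated on this page):
    # keep 18 address bits and drop the low 4.
    return (x & 0x3FFFF) >> 4


def add_offset(desc: int, offset_bytes: int) -> int:
    """Return a new descriptor whose start-address field is advanced."""
    new_addr = (desc & ADDR_MASK) + encode_bytes(offset_bytes)
    # Guard added by this sketch: do not carry into the reserved bits 14-15.
    if new_addr > ADDR_MASK:
        raise OverflowError("start address no longer fits in 14 bits")
    return (desc & ~ADDR_MASK) | new_addr


base = (1 << 62) | (64 << 32) | (1 << 16) | 0x40
moved = add_offset(base, 32)  # 32 bytes encode to 2 under the assumption above
print(hex(base), hex(moved))
```

Under these assumptions, all fields other than the start address are carried over unchanged.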

create

static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle(__init__[__mlir_type.!pop.int_literal](0))](smem_ptr: UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]) -> Self

Create a descriptor for shared memory operand.

Parameters:

  • stride_byte_offset (Int): Stride dimension offset in bytes.
  • leading_byte_offset (Int): Leading dimension stride in bytes.
  • swizzle_mode (TensorMapSwizzle): Memory access pattern mode. Defaults to the TensorMapSwizzle value constructed from 0.

Args:

  • smem_ptr (UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]): Pointer to shared memory operand.

Returns:

Initialized descriptor for the shared memory operand.
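As a rough model of how `create` could assemble the descriptor from its inputs, the Python sketch below combines a shared memory address with the two byte offsets. It is not the Mojo implementation. It assumes that the address and both offsets use the PTX matrix-descriptor encoding `(x & 0x3FFFF) >> 4`, leaves the base offset at zero, and takes the raw 2-bit swizzle field instead of a `TensorMapSwizzle` value, since the mapping between the two is not described here. `encode_bytes` and `create_descriptor` are hypothetical names.

```python
def encode_bytes(x: int) -> int:
    # Assumption (PTX matrix-descriptor encoding, not stated on this page):
    # keep 18 address bits and drop the low 4.
    return (x & 0x3FFFF) >> 4


def create_descriptor(
    smem_addr: int,
    stride_byte_offset: int,
    leading_byte_offset: int,
    swizzle_field: int = 0,
) -> int:
    """Build a 64-bit descriptor value from byte quantities."""
    if not 0 <= swizzle_field <= 3:
        raise ValueError("swizzle field is 2 bits wide")
    desc = encode_bytes(smem_addr)                  # bits 0-13
    desc |= encode_bytes(leading_byte_offset) << 16  # bits 16-29
    desc |= encode_bytes(stride_byte_offset) << 32   # bits 32-45
    desc |= swizzle_field << 62                      # bits 62-63
    return desc


desc = create_descriptor(0x400, stride_byte_offset=1024, leading_byte_offset=16)
print(hex(desc))
```

Under the assumed encoding the low 4 bits of each byte quantity are discarded, so the sketch only represents addresses and offsets that are multiples of 16 bytes exactly.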