Mojo trait
MHAMask
The MHAMask trait describes masks for multi-head attention (MHA) kernels, such as the causal mask.
Implemented traits
AnyType, UnknownDestructibility
Aliases
apply_log2e_after_mask
alias apply_log2e_after_mask
Does the mask require log2e to be applied after the mask, or can it be fused with the scaling?
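The sketch below illustrates why this question matters. It assumes that a kernel computes softmax with base-2 exponentials, using the identity exp(x) = 2^(x · log2e), and that it would otherwise fold log2e into the score scale. If the mask only contributes 0 or -inf, folding is exact. A finite additive mask, by contrast, would itself need to be multiplied by log2e, so log2e must be applied after the mask. The Python below models only the arithmetic. None of its function names are part of the Mojo API.

```python
import math

LOG2E = 1.0 / math.log(2.0)  # log2(e), so exp(x) == 2 ** (x * LOG2E)

def softmax_exp(xs):
    """Reference softmax using natural exponentials."""
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]

def softmax_exp2(ys):
    """Softmax using base-2 exponentials; inputs must already be multiplied by LOG2E."""
    m = max(ys)
    e = [2.0 ** (y - m) for y in ys]
    total = sum(e)
    return [v / total for v in e]

def reference(scores, scale, mask):
    # Hypothetical helper: scale, add the mask, then softmax with exp.
    return softmax_exp([s * scale + a for s, a in zip(scores, mask)])

def log2e_after_mask(scores, scale, mask):
    # Hypothetical helper: add the mask first, then multiply by LOG2E.
    return softmax_exp2([(s * scale + a) * LOG2E for s, a in zip(scores, mask)])

def log2e_fused(scores, scale, mask):
    # Hypothetical helper: fold LOG2E into the scale, then add the mask unconverted.
    return softmax_exp2([s * (scale * LOG2E) + a for s, a in zip(scores, mask)])

def close(a, b, tol=1e-12):
    return all(abs(x - y) <= tol for x, y in zip(a, b))

scores = [1.0, 2.0, 3.0]
scale = 0.5
causal_like = [0.0, 0.0, -math.inf]  # mask entries are only 0 or -inf
bias = [0.0, -1.0, 0.5]              # finite additive mask values

print(close(log2e_fused(scores, scale, causal_like), reference(scores, scale, causal_like)))  # True
print(close(log2e_fused(scores, scale, bias), reference(scores, scale, bias)))                # False
print(close(log2e_after_mask(scores, scale, bias), reference(scores, scale, bias)))           # True
```

In this model, a 0/-inf mask gives the same probabilities whether or not log2e is fused into the scale. A mask with finite values gives the correct result only when log2e is applied after it.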
mask_out_of_bound
alias mask_out_of_bound
mask_safe_out_of_bounds
alias mask_safe_out_of_bounds
Is the mask safe to read out of bounds?
Methods
mask
mask[type: DType, width: Int, //, *, element_type: DType = uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[type, width]) -> SIMD[type, width]
Return the mask vector at the given coordinates.
Arguments:
coord: the (seq_id, head, q_idx, k_idx) coordinate.
score_vec: the scores at coord of the score matrix.
The functor could capture a mask tensor and add it to the score, e.g., Replit.
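The following Python models how a mask functor could behave. It does not show the Mojo implementation. The sketch makes one assumption: `score_vec` holds scores for consecutive key positions starting at `k_idx`. The helpers `causal_mask` and `additive_mask` are hypothetical. The first sets scores for keys after the query position to -inf. The second captures a mask tensor and adds it to the scores, in the spirit of the captured-tensor case described above.

```python
import math

def causal_mask(coord, score_vec):
    """Hypothetical causal mask functor.

    coord is (seq_id, head, q_idx, k_idx). score_vec is assumed to cover key
    positions k_idx, k_idx + 1, ..., k_idx + len(score_vec) - 1.
    Scores for keys after the query position become -inf.
    """
    _seq_id, _head, q_idx, k_idx = coord
    return [s if k_idx + i <= q_idx else -math.inf
            for i, s in enumerate(score_vec)]

def additive_mask(mask_tensor):
    """Hypothetical functor that captures a mask tensor and adds it to the scores.

    mask_tensor is assumed to be indexed as mask_tensor[seq_id][head][q_idx][k_idx].
    """
    def apply(coord, score_vec):
        seq_id, head, q_idx, k_idx = coord
        row = mask_tensor[seq_id][head][q_idx]
        return [s + row[k_idx + i] for i, s in enumerate(score_vec)]
    return apply

# Query 2 may attend to keys 0..2; key 3 is masked out.
print(causal_mask((0, 0, 2, 0), [1.0, 1.0, 1.0, 1.0]))  # [1.0, 1.0, 1.0, -inf]

bias = [[[[0.0, -1.0], [0.5, 0.0]]]]  # one sequence, one head, 2x2 mask
print(additive_mask(bias)((0, 0, 1, 0), [1.0, 2.0]))  # [1.5, 2.0]
```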
status
status[*, element_type: DType = uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus
Given a tile's index range, return its masking status.
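A likely reason to ask for a tile's status is to skip work. A fully masked tile need not be computed at all, and an unmasked tile can skip per-element masking. Treat this reading as an assumption. The Python sketch below classifies a tile under a causal mask. It assumes that `tile_offset` is (first query row, first key column) and that `tile_size` is (rows, columns). `TileMaskStatusModel` and its three member names are a hypothetical stand-in for the three cases, not the Mojo `TileMaskStatus` definition.

```python
from enum import Enum

class TileMaskStatusModel(Enum):
    # Hypothetical stand-in for a tile's masking status.
    NO_MASK = 0       # no element in the tile is masked
    PARTIAL_MASK = 1  # some elements are masked
    FULL_MASK = 2     # every element is masked

def causal_tile_status(tile_offset, tile_size):
    """Classify a (q, k) tile under a causal mask, where k > q is masked."""
    q0, k0 = tile_offset
    q_rows, k_cols = tile_size
    q_last = q0 + q_rows - 1
    k_last = k0 + k_cols - 1
    if k0 > q_last:
        # Even the first key is after the last query: everything is masked.
        return TileMaskStatusModel.FULL_MASK
    if k_last <= q0:
        # Even the last key is at or before the first query: nothing is masked.
        return TileMaskStatusModel.NO_MASK
    return TileMaskStatusModel.PARTIAL_MASK

print(causal_tile_status((0, 64), (64, 64)))  # TileMaskStatusModel.FULL_MASK
print(causal_tile_status((64, 0), (64, 64)))  # TileMaskStatusModel.NO_MASK
print(causal_tile_status((0, 0), (64, 64)))   # TileMaskStatusModel.PARTIAL_MASK
```

The check needs only the tile's corner coordinates, so its cost does not depend on the tile size.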