Python module
max.nn.attention.mask_config
Mask configuration for attention.
AttentionMaskVariant
class max.nn.attention.mask_config.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_MASK
CAUSAL_MASK = 'causal'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = 'chunked_causal'
NULL_MASK
NULL_MASK = 'null'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = 'sliding_window_causal'
TENSOR_MASK
TENSOR_MASK = 'tensor_mask'
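The members are string-valued, so a variant can be inspected or looked up through its `.value`. A minimal usage sketch, relying only on standard `Enum` behavior and the member values listed above:

```python
from max.nn.attention.mask_config import AttentionMaskVariant

# Select a mask variant and inspect its string value.
variant = AttentionMaskVariant.CAUSAL_MASK
print(variant.value)  # 'causal'

# Standard Enum behavior: lookup by value returns the same member.
assert AttentionMaskVariant("causal") is AttentionMaskVariant.CAUSAL_MASK
```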
MHAMaskConfig
class max.nn.attention.mask_config.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')
Parameters:
- attention_mask_variant (AttentionMaskVariant)
- positional_encoding_variant (PositionalEncodingVariant)
attention_mask_variant
attention_mask_variant: AttentionMaskVariant
positional_encoding_variant
positional_encoding_variant: PositionalEncodingVariant
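A minimal construction sketch, assuming only the two keyword parameters shown in the signature above:

```python
from max.nn.attention.mask_config import (
    AttentionMaskVariant,
    MHAMaskConfig,
    PositionalEncodingVariant,
)

# Pair a causal attention mask with no positional encoding.
config = MHAMaskConfig(
    attention_mask_variant=AttentionMaskVariant.CAUSAL_MASK,
    positional_encoding_variant=PositionalEncodingVariant.NO_POS,
)
print(config.attention_mask_variant)       # AttentionMaskVariant.CAUSAL_MASK
print(config.positional_encoding_variant)  # PositionalEncodingVariant.NO_POS
```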
MHAMaskVariant
class max.nn.attention.mask_config.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
CAUSAL_ALIBI_MASK
CAUSAL_ALIBI_MASK = '1'
CAUSAL_MASK
CAUSAL_MASK = '0'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = '3'
NULL_MASK
NULL_MASK = '2'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = '4'
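Note that these member values are the strings `'0'` through `'4'`, not integers, so lookup by value requires the string form. A short sketch:

```python
from max.nn.attention.mask_config import MHAMaskVariant

# Member values are strings, not ints.
print(MHAMaskVariant.CAUSAL_MASK.value)        # '0'
print(MHAMaskVariant.CAUSAL_ALIBI_MASK.value)  # '1'

# Lookup by value therefore takes the string form.
assert MHAMaskVariant("2") is MHAMaskVariant.NULL_MASK
```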
PositionalEncodingVariant
class max.nn.attention.mask_config.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
ALIBI_POS
ALIBI_POS = 'alibi_pos'
NO_POS
NO_POS = 'no_pos'
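Taken together, an MHAMaskVariant names a combined choice that an MHAMaskConfig spells out as a (mask, positional-encoding) pair. The sketch below shows one plausible way to expand variants into configs; the mapping is illustrative only, not the library's internal dispatch:

```python
from max.nn.attention.mask_config import (
    AttentionMaskVariant,
    MHAMaskConfig,
    MHAMaskVariant,
    PositionalEncodingVariant,
)

# Illustrative mapping only: how each MHAMaskVariant could expand into an
# MHAMaskConfig. The library's actual dispatch may differ.
MASK_CONFIGS = {
    MHAMaskVariant.CAUSAL_MASK: MHAMaskConfig(
        attention_mask_variant=AttentionMaskVariant.CAUSAL_MASK,
        positional_encoding_variant=PositionalEncodingVariant.NO_POS,
    ),
    MHAMaskVariant.CAUSAL_ALIBI_MASK: MHAMaskConfig(
        attention_mask_variant=AttentionMaskVariant.CAUSAL_MASK,
        positional_encoding_variant=PositionalEncodingVariant.ALIBI_POS,
    ),
    MHAMaskVariant.NULL_MASK: MHAMaskConfig(
        attention_mask_variant=AttentionMaskVariant.NULL_MASK,
        positional_encoding_variant=PositionalEncodingVariant.NO_POS,
    ),
}

config = MASK_CONFIGS[MHAMaskVariant.CAUSAL_ALIBI_MASK]
print(config.positional_encoding_variant)  # PositionalEncodingVariant.ALIBI_POS
```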