Skip to main content

Python module

attention.mask_config

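Mask configuration for attention. This module contains the following:
- Enumerations of attention mask variants (`AttentionMaskVariant`, `MHAMaskVariant`).
- An enumeration of positional encoding variants (`PositionalEncodingVariant`).
- `MHAMaskConfig`, which pairs an attention mask variant with a positional encoding variant.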

AttentionMaskVariant

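class max.nn.attention.mask_config.AttentionMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

An enumeration of attention mask variants. The signature above is the generic `Enum` constructor; the members are listed below.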

CAUSAL_MASK

CAUSAL_MASK = 'causal'

CHUNKED_CAUSAL_MASK

CHUNKED_CAUSAL_MASK = 'chunked_causal'

NULL_MASK

NULL_MASK = 'null'

SLIDING_WINDOW_CAUSAL_MASK

SLIDING_WINDOW_CAUSAL_MASK = 'sliding_window_causal'

TENSOR_MASK

TENSOR_MASK = 'tensor_mask'

MHAMaskConfig

class max.nn.attention.mask_config.MHAMaskConfig(attention_mask_variant: 'AttentionMaskVariant', positional_encoding_variant: 'PositionalEncodingVariant')

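Parameters:

- attention_mask_variant (AttentionMaskVariant) – The attention mask variant to apply.
- positional_encoding_variant (PositionalEncodingVariant) – The positional encoding variant to apply.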

attention_mask_variant

attention_mask_variant: AttentionMaskVariant

positional_encoding_variant

positional_encoding_variant: PositionalEncodingVariant

MHAMaskVariant

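class max.nn.attention.mask_config.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

An enumeration of multi-head attention mask variants. The signature above is the generic `Enum` constructor; the members are listed below.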

CAUSAL_ALIBI_MASK

CAUSAL_ALIBI_MASK = '1'

CAUSAL_MASK

CAUSAL_MASK = '0'

CHUNKED_CAUSAL_MASK

CHUNKED_CAUSAL_MASK = '3'

NULL_MASK

NULL_MASK = '2'

SLIDING_WINDOW_CAUSAL_MASK

SLIDING_WINDOW_CAUSAL_MASK = '4'

PositionalEncodingVariant

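class max.nn.attention.mask_config.PositionalEncodingVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

An enumeration of positional encoding variants. The signature above is the generic `Enum` constructor; the members are listed below.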

ALIBI_POS

ALIBI_POS = 'alibi_pos'

NO_POS

NO_POS = 'no_pos'

Was this page helpful?