
Python module

distributed_transformer

DistributedTransformer

class max.nn.transformer.distributed_transformer.DistributedTransformer(dim: int, n_heads: int, layers: list[max.nn.transformer.distributed_transformer.DistributedTransformerBlock], norm: RMSNormV2 | LayerNormV2, output: LinearV2, embedding: VocabParallelEmbedding, kv_params: KVCacheParams, kv_collection_constructor: FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection | FetchPagedKVCacheCollectionFA3Fallback, devices: list[max.graph.type.DeviceRef], all_logits: bool = False)

Transformer model consisting of TransformerBlock layers.
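The constructor wires together an embedding, a stack of transformer blocks, a final norm, and an output projection. The sketch below illustrates the assumed forward flow only; it is a framework-free toy (the real class builds a MAX graph from the modules listed in the signature, and the exact meaning of `all_logits` is an assumption here: return logits for every position versus only the last).

```python
def transformer_forward(tokens, embedding, layers, norm, output, all_logits=False):
    # Hypothetical sketch: embed -> transformer blocks -> final norm -> projection.
    h = [embedding[t] for t in tokens]   # token ids -> hidden "vectors" (toy scalars)
    for layer in layers:                 # each layer plays a DistributedTransformerBlock
        h = [layer(x) for x in h]
    h = [norm(x) for x in h]
    logits = [output(x) for x in h]
    # Assumption: all_logits=False keeps only the final position's logits.
    return logits if all_logits else logits[-1]

# Toy stand-ins for the real modules.
embedding = {0: 1.0, 1: 2.0}
double = lambda x: 2.0 * x    # "transformer block"
identity = lambda x: x        # "norm"
project = lambda x: x + 0.5   # "output projection"

last = transformer_forward([0, 1], embedding, [double], identity, project)
# -> 4.5 (only the last position's logits)
```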

DistributedTransformerBlock

class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention: Module, mlp: Module, attention_norm: DistributedRMSNorm, mlp_norm: DistributedRMSNorm, devices: list[max.graph.type.DeviceRef])

Stack of Attention, FeedForward, and RMSNorm layers.
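Given the constructor arguments (attention, mlp, attention_norm, mlp_norm), the block plausibly follows the standard pre-norm residual pattern: each sublayer's norm is applied before the sublayer, and the result is added back to the input. A minimal single-device sketch under that assumption, with toy scalar callables standing in for the real modules:

```python
def block_forward(x, attention, mlp, attention_norm, mlp_norm):
    # Assumed pre-norm residual structure of one transformer block:
    #   x = x + attention(attention_norm(x))
    #   x = x + mlp(mlp_norm(x))
    x = x + attention(attention_norm(x))
    x = x + mlp(mlp_norm(x))
    return x

# Toy stand-ins: scalars instead of tensors, identities instead of RMSNorm.
y = block_forward(
    1.0,
    attention=lambda x: 3.0 * x,   # "attention"
    mlp=lambda x: x + 1.0,         # "feed-forward"
    attention_norm=lambda x: x,    # "DistributedRMSNorm" as identity
    mlp_norm=lambda x: x,
)
# x = 1 + 3*1 = 4, then 4 + (4 + 1) = 9 -> y == 9.0
```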

distribute_value()

max.nn.transformer.distributed_transformer.distribute_value(v, devices: list[max.graph.type.DeviceRef])
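No description is given for this helper; from its name and signature it presumably produces one copy of `v` per device in `devices`. The sketch below mocks that assumed behavior with a hypothetical `DeviceRef` stand-in (the real helper operates on MAX graph values and real `max.graph.type.DeviceRef` objects, not plain Python tuples):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DeviceRef:
    """Hypothetical stand-in for max.graph.type.DeviceRef."""
    id: int

def distribute_value(v, devices):
    # Assumed behavior: replicate `v` once per target device.
    # Here we simply pair the value with each device; the real helper
    # would transfer the graph value to each device instead.
    return [(v, d) for d in devices]

copies = distribute_value(42, [DeviceRef(0), DeviceRef(1)])
# -> [(42, DeviceRef(id=0)), (42, DeviceRef(id=1))]
```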