
Python module

distributed_transformer

DistributedTransformer

class max.nn.transformer.distributed_transformer.DistributedTransformer(dim: int, n_heads: int, layers: list[max.nn.transformer.distributed_transformer.DistributedTransformerBlock], norm: RMSNorm | LayerNorm, output: Linear, embedding: VocabParallelEmbedding, kv_params: KVCacheParams, kv_collection_constructor: FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection, devices: list[max.graph.type.DeviceRef], return_logits: ReturnLogits = ReturnLogits.LAST_TOKEN)

Transformer model consisting of DistributedTransformerBlock layers.
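
A minimal assembly sketch, assuming the per-layer blocks, final norm, output projection, embedding, and KV-cache configuration are built elsewhere (see the DistributedTransformerBlock entry below for per-layer construction). All lowercase names are placeholders rather than documented API, and the hyperparameter values are illustrative.

```python
from max.graph import DeviceRef
from max.nn.transformer.distributed_transformer import DistributedTransformer

# Two-GPU tensor-parallel setup.
devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

# Placeholders: `layers`, `final_norm`, `lm_head`, `embedding`, `kv_params`,
# and `kv_fetch` stand in for components built elsewhere with the types
# listed in the signature above.
model = DistributedTransformer(
    dim=4096,
    n_heads=32,
    layers=layers,                       # list[DistributedTransformerBlock]
    norm=final_norm,                     # RMSNorm | LayerNorm
    output=lm_head,                      # Linear projection over the vocabulary
    embedding=embedding,                 # VocabParallelEmbedding
    kv_params=kv_params,                 # KVCacheParams
    kv_collection_constructor=kv_fetch,  # e.g. FetchPagedKVCacheCollection
    devices=devices,
    # return_logits defaults to ReturnLogits.LAST_TOKEN
)
```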

DistributedTransformerBlock

class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention: Module, mlp: Module, attention_norm: DistributedRMSNorm, mlp_norm: DistributedRMSNorm, devices: list[max.graph.type.DeviceRef], use_subgraph: bool = False)

Stack of attention, feed-forward, and DistributedRMSNorm layers.
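
A per-layer construction sketch, assuming distributed attention and feed-forward modules and per-device norms have already been built for the same device list; `attn`, `mlp`, `attn_norm`, and `mlp_norm` are placeholders, not documented API.

```python
from max.graph import DeviceRef
from max.nn.transformer.distributed_transformer import DistributedTransformerBlock

devices = [DeviceRef.GPU(0), DeviceRef.GPU(1)]

block = DistributedTransformerBlock(
    attention=attn,            # distributed attention Module
    mlp=mlp,                   # distributed feed-forward Module
    attention_norm=attn_norm,  # DistributedRMSNorm
    mlp_norm=mlp_norm,         # DistributedRMSNorm
    devices=devices,
    use_subgraph=True,         # opt in to subgraph-based execution
)
```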

build_subgraph()

build_subgraph(name: str) → Module
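
Reading from the signature and the use_subgraph constructor flag, build_subgraph plausibly packages the block as a named Module so that identical layers can share one subgraph definition instead of being inlined per layer; that interpretation is an inference, not documented behavior.

```python
# Hypothetical usage: wrap a block built with use_subgraph=True
# (see above) as a named, reusable subgraph module.
layer_module = block.build_subgraph("decoder_layer_0")
```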

distribute_value()

max.nn.transformer.distributed_transformer.distribute_value(v, devices: list[max.graph.type.DeviceRef])
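
A sketch of the likely behavior: replicate a graph value onto each device in the list and return the per-device copies. The body below mirrors TensorValue.to and is an assumption, not the documented implementation.

```python
from max.graph import DeviceRef, TensorValue

def distribute_value_sketch(
    v: TensorValue, devices: list[DeviceRef]
) -> list[TensorValue]:
    # One copy of `v` per device; distributed layers then consume the
    # resulting list of per-device values.
    return [v.to(device) for device in devices]
```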