Python module
distributed_transformer
DistributedTransformer
class max.nn.transformer.distributed_transformer.DistributedTransformer(dim: int, n_heads: int, layers: list[max.nn.transformer.distributed_transformer.DistributedTransformerBlock], norm: RMSNormV2 | LayerNormV2, output: LinearV2, embedding: VocabParallelEmbedding, kv_params: KVCacheParams, kv_collection_constructor: FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection | FetchPagedKVCacheCollectionFA3Fallback, devices: list[max.graph.type.DeviceRef], all_logits: bool = False)
Transformer model consisting of TransformerBlock layers.
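The constructor signature above suggests the usual layer ordering: embed tokens, run each transformer block in sequence, apply the final norm, then project to logits with the output linear layer. The sketch below illustrates that ordering in plain Python; all names are illustrative stand-ins, not the real MAX API.

```python
def transformer_forward(tokens, embedding, layers, norm, output):
    """Illustrative sketch of the forward-pass ordering implied by the
    DistributedTransformer constructor (embedding -> blocks -> norm -> output).
    Every argument here is a plain callable standing in for the real module."""
    h = embedding(tokens)   # token ids -> hidden states
    for layer in layers:    # stack of DistributedTransformerBlock layers
        h = layer(h)
    h = norm(h)             # final RMSNormV2 / LayerNormV2
    return output(h)        # LinearV2 projection to vocabulary logits
```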
DistributedTransformerBlock
class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention: Module, mlp: Module, attention_norm: DistributedRMSNorm, mlp_norm: DistributedRMSNorm, devices: list[max.graph.type.DeviceRef])
Stack of Attention, FeedForward, and RMSNorm layers.
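Given the constructor takes separate `attention_norm` and `mlp_norm` modules, the block plausibly follows the common pre-norm residual pattern: normalize, apply the sublayer, and add the result back to the input. This is an assumption about the internals, sketched with plain callables rather than the real MAX modules.

```python
def block_forward(x, attention, mlp, attention_norm, mlp_norm):
    """Hypothetical pre-norm residual pattern for a transformer block:
    each sublayer (attention, then MLP) sees a normalized input and its
    output is added back to the running residual stream."""
    x = x + attention(attention_norm(x))  # attention sublayer with residual
    x = x + mlp(mlp_norm(x))              # feed-forward sublayer with residual
    return x
```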
distribute_value()
max.nn.transformer.distributed_transformer.distribute_value(v, devices: list[max.graph.type.DeviceRef])
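The signature suggests this helper places a copy of a graph value on each device in the list. A minimal sketch of that assumed behavior, using a toy value class in place of a real graph value:

```python
class ToyValue:
    """Stand-in for a graph value; only implements .to(device)."""
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Return a copy of this value placed on the given device.
        return ToyValue(self.data, device)

def distribute_value(v, devices):
    """Assumed behavior: replicate v onto every device in `devices`,
    returning one per-device copy in order."""
    return [v.to(d) for d in devices]
```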