Python module
distributed_transformer
DistributedTransformer
class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None)
Transformer model consisting of DistributedTransformerBlock layers.
-
Parameters:
-
- dim (int)
- n_heads (int)
- layers (list[DistributedTransformerBlock])
- norm (ShardableCallable)
- output (ColumnParallelLinear)
- embedding (VocabParallelEmbedding)
- kv_params (KVCacheParams)
- kv_collection_constructor (FetchContinuousBatchingKVCacheCollection | FetchPagedKVCacheCollection)
- devices (list[DeviceRef])
- rope (RotaryEmbedding)
- return_logits (ReturnLogits)
- use_subgraphs (bool)
- subgraph_layer_groups (list[list[int]] | None)
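The default return_logits=ReturnLogits.LAST_TOKEN suggests that the model computes logits only for the final position of each sequence. The sketch below illustrates that selection step. It assumes a ragged batch layout in which all tokens are concatenated and a hypothetical offsets list marks where each sequence starts and ends. Plain Python lists stand in for tensors, and last_token_rows is a hypothetical helper, not part of max.nn.

```python
# Hypothetical sketch: pick the hidden state of the last token of every
# sequence in a ragged batch before an output projection would be applied.
# Assumed layout: sequence i spans rows row_offsets[i] .. row_offsets[i + 1] - 1.

def last_token_rows(hidden: list[list[float]], row_offsets: list[int]) -> list[list[float]]:
    """Return the final row of each sequence in a concatenated batch."""
    if len(row_offsets) < 2 or row_offsets[0] != 0 or row_offsets[-1] != len(hidden):
        raise ValueError("row_offsets must start at 0 and end at len(hidden)")
    if any(end <= start for start, end in zip(row_offsets, row_offsets[1:])):
        raise ValueError("every sequence must contain at least one token")
    # The last token of sequence i sits one row before the next sequence starts.
    return [hidden[end - 1] for end in row_offsets[1:]]


# Two sequences of lengths 3 and 2, each token a 2-dimensional hidden state.
hidden = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1], [4.0, 4.1]]
offsets = [0, 3, 5]
print(last_token_rows(hidden, offsets))  # [[2.0, 2.1], [4.0, 4.1]]
```

Selecting rows before the output projection keeps the vocabulary-sized matmul proportional to the number of sequences rather than the number of tokens.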
DistributedTransformerBlock
class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices, distributed_gemm_config=None)
Stack of Attention, FeedForward, and RMSNorm layers.
-
Parameters:
-
- attention (Module)
- mlp (ShardableCallable)
- attention_norm (ShardableCallable)
- mlp_norm (ShardableCallable)
- devices (list[DeviceRef])
- distributed_gemm_config (DistributedGemmConfig | None)
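The constructor arguments (attention, mlp, attention_norm, mlp_norm) match the common pre-norm residual layout: normalize, attend, add the residual, then normalize, run the MLP, and add the residual again. The sketch below shows that data flow applied to one hidden state per device. The pre-norm ordering is an assumption based on the parameter names, and block_forward and rms_norm are hypothetical helpers. A real tensor-parallel block would also need cross-device communication between shards, which this sketch omits.

```python
import math
from typing import Callable, Sequence

Vec = list[float]


def rms_norm(x: Vec, eps: float = 1e-6) -> Vec:
    # RMSNorm without a learned weight: x / sqrt(mean(x^2) + eps).
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def block_forward(
    xs: Sequence[Vec],
    attention: Callable[[Vec], Vec],
    mlp: Callable[[Vec], Vec],
    attention_norm: Callable[[Vec], Vec] = rms_norm,
    mlp_norm: Callable[[Vec], Vec] = rms_norm,
) -> list[Vec]:
    """Assumed pre-norm block, applied independently to each device's input."""
    outs = []
    for x in xs:  # one hidden state per device
        # h = x + attention(attention_norm(x))
        h = [a + b for a, b in zip(x, attention(attention_norm(x)))]
        # out = h + mlp(mlp_norm(h))
        outs.append([a + b for a, b in zip(h, mlp(mlp_norm(h)))])
    return outs


identity = lambda v: v
print(block_forward([[1.0, 2.0]], lambda v: [2 * a for a in v],
                    lambda v: [1.0] * len(v), identity, identity))  # [[4.0, 7.0]]
```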
ShardableCallable
class max.nn.transformer.distributed_transformer.ShardableCallable(*args, **kwargs)
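The page gives no description for ShardableCallable. Its name, and its use as the type of norm, mlp, attention_norm, and mlp_norm, suggest a structural type for layers that can be called on a tensor and also split across devices. The sketch below models that idea with typing.Protocol. The class ShardableCallableLike, its shard(devices) method, and the Scale layer are all hypothetical stand-ins, and floats replace TensorValue.

```python
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class ShardableCallableLike(Protocol):
    """Hypothetical protocol: callable on an input and shardable across devices."""

    def __call__(self, x: float) -> float: ...

    def shard(self, devices: Sequence[str]) -> list["ShardableCallableLike"]: ...


class Scale:
    """Toy layer that multiplies its input by a fixed weight."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return self.factor * x

    def shard(self, devices: Sequence[str]) -> list["Scale"]:
        # Replicated strategy: every device gets its own copy of the weight.
        return [Scale(self.factor) for _ in devices]


layer = Scale(2.0)
print(isinstance(layer, ShardableCallableLike))  # True
shards = layer.shard(["gpu:0", "gpu:1"])
print([s(3.0) for s in shards])  # [6.0, 6.0]
```

A protocol lets any layer type satisfy the requirement without inheriting from a shared base class; runtime_checkable only checks that the methods exist, not their signatures.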
distribute_value()
max.nn.transformer.distributed_transformer.distribute_value(v, devices)
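No description is given for distribute_value. Going only by its name and signature, a plausible reading is that it returns one copy of v placed on each device in devices. The sketch below shows that replication pattern under this assumption. ToyValue and its to(device) method are hypothetical stand-ins for a graph value and a device transfer.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyValue:
    """Hypothetical stand-in for a value that lives on one device."""
    data: tuple[float, ...]
    device: str = "cpu"

    def to(self, device: str) -> "ToyValue":
        # Model a transfer by returning the same data tagged with a new device.
        return ToyValue(self.data, device)


def distribute_value_sketch(v: ToyValue, devices: list[str]) -> list[ToyValue]:
    """Assumed behavior: one copy of v per device, in device order."""
    return [v.to(device) for device in devices]


copies = distribute_value_sketch(ToyValue((1.0, 2.0)), ["gpu:0", "gpu:1"])
print([c.device for c in copies])  # ['gpu:0', 'gpu:1']
```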
forward_sharded_layers()
max.nn.transformer.distributed_transformer.forward_sharded_layers(layers, xs)
Forward pass through sharded layers.
-
Parameters:
-
- layers (Sequence[Callable[[TensorValue], TensorValue]]) – Sequence of callable layers, each returning a TensorValue
- xs (Sequence[TensorValue]) – Input tensors, one per layer
-
Returns:
-
List of output tensors, one from each layer
-
Raises:
-
AssertionError – If the number of layers and input tensors don’t match
-
Return type:
-
list[TensorValue]
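As documented, forward_sharded_layers pairs the i-th layer with the i-th input and asserts that the two sequences have the same length. The sketch below reproduces that contract with plain callables and floats standing in for TensorValue. The name forward_sharded_layers_sketch is hypothetical, and the body is a minimal reading of the documented behavior, not the library's source.

```python
from typing import Callable, Sequence, TypeVar

T = TypeVar("T")


def forward_sharded_layers_sketch(
    layers: Sequence[Callable[[T], T]], xs: Sequence[T]
) -> list[T]:
    """Apply layers[i] to xs[i] and collect the outputs in order."""
    # Mirrors the documented AssertionError on a length mismatch.
    assert len(layers) == len(xs), f"{len(layers)=} != {len(xs)=}"
    return [layer(x) for layer, x in zip(layers, xs)]


# One "shard" of a layer per device, each fed its own input.
shards = [lambda x: x + 1, lambda x: x * 10]
print(forward_sharded_layers_sketch(shards, [1.0, 2.0]))  # [2.0, 20.0]
```

The explicit assertion matters because zip silently truncates to the shorter sequence, which would otherwise drop a device's output without any error.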
take()
max.nn.transformer.distributed_transformer.take(it, n)
Return the next n items from it as a list.
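The behavior described for take, returning the next n items of an iterator as a list, is the standard itertools.islice recipe. The sketch below shows it under the hypothetical name take_sketch. Because it consumes from an iterator, repeated calls continue where the previous one stopped, which is useful for peeling consecutive groups off a sequence of layers.

```python
from itertools import islice
from typing import Iterator, TypeVar

T = TypeVar("T")


def take_sketch(it: Iterator[T], n: int) -> list[T]:
    """Return up to n further items from it; fewer if it runs out."""
    return list(islice(it, n))


layer_ids = iter(range(10))
print(take_sketch(layer_ids, 3))  # [0, 1, 2]
print(take_sketch(layer_ids, 3))  # [3, 4, 5]
```

Pass an iterator rather than a list: islice over a list starts from the beginning on every call, so successive calls would return the same items.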