
Python module

distributed_transformer

DistributedTransformer

class max.nn.transformer.distributed_transformer.DistributedTransformer(dim, n_heads, layers, norm, output, embedding, kv_params, kv_collection_constructor, devices, rope, return_logits=ReturnLogits.LAST_TOKEN, use_subgraphs=False, subgraph_layer_groups=None)

Transformer model consisting of TransformerBlock layers.

Parameters:
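
The constructor parameters mirror the standard decoder-only layout: an embedding, a stack of transformer blocks, a final norm, and an output projection, distributed across the given devices. As a conceptual illustration only (plain-Python stand-ins, not the MAX graph API; every name below is hypothetical), the composition can be sketched as:

```python
# Conceptual sketch only: plain-Python stand-ins, not the MAX graph API.
from typing import Callable, Sequence

def build_forward(
    embedding: Callable,          # maps token ids -> hidden states
    layers: Sequence[Callable],   # stack of transformer blocks
    norm: Callable,               # final normalization
    output: Callable,             # projection to logits
) -> Callable:
    """Compose the pieces the constructor receives into a forward pass."""
    def forward(tokens):
        h = embedding(tokens)
        for layer in layers:      # each block transforms the hidden states
            h = layer(h)
        return output(norm(h))
    return forward

# Toy usage with trivial stand-ins.
model = build_forward(
    embedding=lambda toks: [float(t) for t in toks],
    layers=[lambda h: [x + 1.0 for x in h]] * 2,
    norm=lambda h: h,
    output=lambda h: h,
)
print(model([1, 2, 3]))  # [3.0, 4.0, 5.0]
```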

DistributedTransformerBlock

class max.nn.transformer.distributed_transformer.DistributedTransformerBlock(attention, mlp, attention_norm, mlp_norm, devices, distributed_gemm_config=None)

Stack of Attention, FeedForward, and RMSNorm layers.

Parameters:
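
Assuming the common pre-norm arrangement (each sub-layer wrapped in a normalization and a residual connection), the block's compute can be sketched with hypothetical plain-Python stand-ins; the real layer additionally handles sharded execution across the given devices and the optional distributed_gemm_config:

```python
# Conceptual pre-norm block, not the MAX implementation.
from typing import Callable, List

def block_forward(
    x: List[float],
    attention: Callable[[List[float]], List[float]],
    mlp: Callable[[List[float]], List[float]],
    attention_norm: Callable[[List[float]], List[float]],
    mlp_norm: Callable[[List[float]], List[float]],
) -> List[float]:
    """Attention and MLP sub-layers, each wrapped in norm + residual."""
    h = [a + b for a, b in zip(x, attention(attention_norm(x)))]
    return [a + b for a, b in zip(h, mlp(mlp_norm(h)))]

# Toy usage with identity stand-ins.
out = block_forward([1.0, 2.0], attention=lambda v: v, mlp=lambda v: v,
                    attention_norm=lambda v: v, mlp_norm=lambda v: v)
print(out)  # [4.0, 8.0]
```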

ShardableCallable

class max.nn.transformer.distributed_transformer.ShardableCallable(*args, **kwargs)

distribute_value()

max.nn.transformer.distributed_transformer.distribute_value(v, devices)

Parameters:

devices (list[DeviceRef])
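
From the signature and the devices parameter, this presumably returns one value per device. A minimal behavior sketch, using plain strings in place of DeviceRef and Python references in place of device transfers:

```python
# Sketch of the presumed behavior: one copy of the value per device.
from typing import List, TypeVar

T = TypeVar("T")

def distribute_value_sketch(v: T, devices: List[str]) -> List[T]:
    """Return a per-device list of v; the real function works on graph
    values and DeviceRef objects rather than plain Python objects."""
    return [v for _ in devices]

print(distribute_value_sketch(42, ["gpu:0", "gpu:1"]))  # [42, 42]
```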

forward_sharded_layers()

max.nn.transformer.distributed_transformer.forward_sharded_layers(layers, xs)

Forward pass through sharded layers.

Parameters:

Returns:

List of output tensors from each layer

Raises:

AssertionError – If the number of layers and input tensors don’t match

Return type:

list[TensorValue]
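
A behavior-equivalent sketch of the documented contract (not the library source), pairing each sharded layer with its input tensor and raising if the counts differ:

```python
# Sketch of the documented behavior, using plain callables in place of layers.
from typing import Callable, List, TypeVar

T = TypeVar("T")

def forward_sharded_layers_sketch(
    layers: List[Callable[[T], T]], xs: List[T]
) -> List[T]:
    """Apply layer i to input i; the number of layers and inputs must match."""
    assert len(layers) == len(xs), "number of layers and inputs must match"
    return [layer(x) for layer, x in zip(layers, xs)]

# One layer per device shard, one input per shard.
outs = forward_sharded_layers_sketch([lambda v: v * 2, lambda v: v + 1], [3, 3])
print(outs)  # [6, 4]
```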

take()

max.nn.transformer.distributed_transformer.take(it, n)

Return the next n items from it as a list.

Parameters:

Return type:

list[Value]
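
A behavior-equivalent sketch using itertools.islice, with a short usage example showing that the iterator is consumed in order:

```python
from itertools import islice
from typing import Iterator, List, TypeVar

T = TypeVar("T")

def take_sketch(it: Iterator[T], n: int) -> List[T]:
    """Return the next n items from it as a list."""
    return list(islice(it, n))

values = iter(range(10))
print(take_sketch(values, 3))  # [0, 1, 2]
print(take_sketch(values, 3))  # [3, 4, 5] -- subsequent calls keep consuming
```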
