Python module
module_v3
Module implementation using eager tensors.
Embedding
class max.nn.module_v3.Embedding(vocab_size, *, dim=None, dims=None)
A vector embedding.
An embedding can be thought of as a lookup table for vectors by index. Given an input tensor of indices into the embedding, the result of the embedding lookup is a tensor whose shape is the input's shape followed by the shape of the embedding vectors: each index is replaced by the vector stored at that location in the embedding table.
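As a rough illustration of these lookup semantics, here is a minimal pure-Python sketch over nested lists. It is not the MAX API: `embedding_lookup` is a hypothetical helper that handles only a 1-D list of indices. Each index is replaced by its row of the table, and indices outside [0, vocab_size) are rejected.

```python
from typing import Sequence

Vector = list[float]


def embedding_lookup(table: Sequence[Vector], indices: Sequence[int]) -> list[Vector]:
    """Replace each index with the vector stored at that row of `table`.

    Hypothetical helper illustrating the semantics only: for a 1-D list of
    indices the result has shape [len(indices), dim].
    """
    vocab_size = len(table)
    result = []
    for i in indices:
        # Indices outside [0, vocab_size) are illegal.
        if not 0 <= i < vocab_size:
            raise IndexError(f"index {i} out of range for vocab_size {vocab_size}")
        result.append(list(table[i]))  # copy the row so the table is not aliased
    return result


table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # vocab_size=3, dim=2
print(embedding_lookup(table, [2, 0, 2]))  # [[2.0, 2.1], [0.0, 0.1], [2.0, 2.1]]
```

Three indices in, three 2-element vectors out: the output shape is the index shape with the vector shape appended.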
The common case for embeddings is a 1-dimensional embedding:
from max.dtype import DType
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding
embedding = Embedding(vocab_size=1000, dim=128)
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 128]
However, they support multi-dimensional embeddings just as easily:
from max.dtype import DType
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Embedding
embedding = Embedding(vocab_size=1000, dims=[16, 128])
tokens = Tensor.ones([10], dtype=DType.uint64)
embedded = embedding(tokens)
assert embedded.shape == [10, 16, 128]
Creates a randomly initialized embedding of the specified size.
Parameters:
- vocab_size (DimLike) – The number of elements in the lookup table. Indices outside the range [0, vocab_size) are illegal in the resulting embedding operation.
- dim (DimLike | None) – The embedding dimension, if there is exactly one. Equivalent to dims=[dim].
- dims (ShapeLike | None) – The shape of the vectors in the embedding, for multi-dimensional embeddings.
dim
property dim: Dim
The dimension of the vectors in the embedding (for a 1d embedding).
Raises: If the embedding is 0-dimensional or has more than one dimension.
dims
The dimensions of the vectors in the embedding.
vocab_size
property vocab_size: Dim
The vocab size of the embedding.
Indices outside the range [0, vocab_size) are illegal.
weight
weight: Tensor
Linear
class max.nn.module_v3.Linear(in_dim, out_dim, *, bias=True)
A unary linear transformation over an input tensor.
Linear is defined as f(x) = x @ W.T + B where W is the weight tensor and B is an optional bias tensor.
If W is not square, the transformation changes the dimensionality of its input. By convention, the weight tensor is stored transposed, with shape [out_dim, in_dim].
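To make the formula concrete, here is a pure-Python sketch of f(x) = x @ W.T + B for a single input vector, with W stored as out_dim rows of in_dim entries. It illustrates the math only; `linear` is a hypothetical helper, not the MAX `Linear` class.

```python
from typing import Optional


def linear(x: list[float], weight: list[list[float]],
           bias: Optional[list[float]] = None) -> list[float]:
    """Compute x @ weight.T + bias for a single input vector.

    `weight` is stored transposed: one row of length in_dim per output feature.
    """
    in_dim = len(x)
    out = []
    for row in weight:  # one row per output feature
        if len(row) != in_dim:
            raise ValueError("weight rows must have length in_dim")
        out.append(sum(w * xi for w, xi in zip(row, x)))
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


# in_dim=2, out_dim=3: a non-square W changes the dimensionality from 2 to 3.
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [0.5, 0.5, 0.5]
print(linear([2.0, 3.0], W, b))  # [2.5, 3.5, 5.5]
```

Storing W as [out_dim, in_dim] means each output feature is a dot product of one weight row with the input.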
from max.nn.module_v3 import Linear
from max.experimental.tensor import Tensor
model = Linear(5, 10)
assert dict(model.parameters) == {
"weight": model.weight, "bias": model.bias
}
result = model(Tensor.ones([5]))
assert result.shape == [10]
Constructs a random linear transformation of the given dimensions.
bias
bias: Tensor | Literal[0]
The bias Tensor for the linear transformation (or 0 if bias is disabled).
in_dim
property in_dim: Dim
The input dimension for the transformation.
out_dim
property out_dim: Dim
The output dimension for the transformation.
weight
weight: Tensor
The weight Tensor for the linear transformation.
Module
class max.nn.module_v3.Module
The core unit of composition for modeling in MAX.
Informally, a Module is a container class. It can contain other Module instances, tensors (the Module’s “local parameters”), or other arbitrary Python data.
A Module also has a __call__() method, which applies that Module to some input. In the simplest case this is a function from one tensor to another tensor.
Formally, modules form a tree, and subtrees of modules can be manipulated directly. A Module may also be thought of as a closure, where the parameters form the data of the closure and __call__() is the application of the closure.
Terminology:
- A “child” of a Module is a sub-Module stored directly on that Module.
- A “descendant” of a Module is one of its children, or one of their descendants.
- A “parameter” is a tensor storing data on the Module or one of its descendants.
- The “qualified path” of a descendant is a period-separated string of the names of the child module attributes which lead to that descendant module, for instance child.sub.last.
- The “qualified path” of a parameter is the qualified path of the descendant directly holding that parameter, followed by a final path component for the attribute name of the tensor. For instance, weight for a local parameter, or child.sub.last.weight for a descendant’s parameter.
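The terminology above can be modeled with a toy tree in plain Python. This is a sketch under stated assumptions, not MAX code: `ToyModule` is hypothetical, float attributes stand in for tensors, and `parameters` is a method here (on a real Module it is a property). It shows how qualified paths are built and the depth-first order (local parameters first, then each child's).

```python
from typing import Iterator


class ToyModule:
    """Hypothetical stand-in for Module: float attributes play the role of
    tensors (parameters) and ToyModule attributes are children."""

    def local_parameters(self) -> Iterator[tuple[str, float]]:
        for name, value in vars(self).items():
            if isinstance(value, float):
                yield name, value

    def children(self) -> Iterator[tuple[str, "ToyModule"]]:
        for name, value in vars(self).items():
            if isinstance(value, ToyModule):
                yield name, value

    def parameters(self, prefix: str = "") -> Iterator[tuple[str, float]]:
        # Depth-first: this module's own parameters first ...
        for name, value in self.local_parameters():
            yield prefix + name, value
        # ... then each child's parameters, qualified by the child's attribute name.
        for name, child in self.children():
            yield from child.parameters(prefix + name + ".")


root = ToyModule()
root.weight = 1.0
root.child = ToyModule()
root.child.sub = ToyModule()
root.child.sub.last = ToyModule()
root.child.sub.last.weight = 2.0
print([name for name, _ in root.parameters()])  # ['weight', 'child.sub.last.weight']
```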
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Module, module_dataclass
@module_dataclass
class Linear(Module):
weight: Tensor
bias: Tensor | int = 0
def __call__(self, x: Tensor) -> Tensor:
return x @ self.weight.T + self.bias
linear = Linear(Tensor.zeros([5, 4]))
print(linear)
print(linear(Tensor.constant([1, 2, 3, 4])))
apply_to_local_parameters()
apply_to_local_parameters(f)
Applies a transformation to each local parameter tensor on the Module.
The transformation is applied in-place, updating the module’s values. It is not applied to the parameters of descendants.
For example:
from max.driver import Accelerator
from max.nn.module_v3 import Linear
model = Linear(2, 3)
# Move the layer's own weight and bias to the accelerator
model.apply_to_local_parameters(lambda _, t: t.to(Accelerator()))
Parameters:
f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each local parameter. The transformation takes two arguments, a name and a tensor:
- The name is the attribute name of the parameter on the module.
- The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name.
Return type:
None
apply_to_parameters()
apply_to_parameters(f)
Applies a transformation to all parameters in the module hierarchy.
This method traverses the module tree and applies the transformation function to each parameter in-place, updating both the current module’s parameters and all nested sub-module parameters. The transformation receives the parameter’s qualified name (dot-separated path) and current tensor value.
For example, to transfer all parameters to an accelerator:
from max.driver import Accelerator
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Module, module_dataclass, Linear
@module_dataclass
class MLP(Module):
fc1: Linear
fc2: Linear
def __call__(self, x: Tensor) -> Tensor:
return self.fc2(self.fc1(x))
model = MLP(
fc1=Linear(10, 20),
fc2=Linear(20, 5)
)
model.apply_to_parameters(lambda name, t: t.to(Accelerator()))
Parameters:
f (Callable[[str, Tensor], Tensor]) – Transformation function taking (name, tensor) and returning the transformed tensor. Its parameters are:
- name (str): Qualified dot-separated path of the parameter (e.g., "fc1.weight", "encoder.layer2.bias").
- tensor (Tensor): Current value of the parameter.
It returns the new tensor value to replace the parameter.
Return type:
None
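The difference between apply_to_local_parameters() and apply_to_parameters() can be sketched on a toy tree in plain Python. Everything here is hypothetical (`ToyModule`, `Layer`, `MLP`, and the extra `prefix` argument used to build qualified names); float attributes stand in for tensors. Both methods update values in place; only the recursive one visits descendants and passes qualified names.

```python
from typing import Callable


class ToyModule:
    """Hypothetical stand-in for Module: float attributes act as parameters,
    ToyModule attributes act as children."""

    def apply_to_local_parameters(self, f: Callable[[str, float], float]) -> None:
        # Only this module's own parameters; f sees the bare attribute name.
        for name, value in list(vars(self).items()):
            if isinstance(value, float):
                setattr(self, name, f(name, value))

    def apply_to_parameters(self, f: Callable[[str, float], float], prefix: str = "") -> None:
        # Whole subtree, in place; f sees the qualified dot-separated name.
        # (`prefix` is an implementation detail of this sketch.)
        for name, value in list(vars(self).items()):
            if isinstance(value, float):
                setattr(self, name, f(prefix + name, value))
            elif isinstance(value, ToyModule):
                value.apply_to_parameters(f, prefix + name + ".")


class Layer(ToyModule):
    def __init__(self, weight: float, bias: float) -> None:
        self.weight = weight
        self.bias = bias


class MLP(ToyModule):
    def __init__(self) -> None:
        self.fc1 = Layer(1.0, 0.0)
        self.fc2 = Layer(2.0, 0.0)


seen: list[str] = []


def scale(name: str, t: float) -> float:
    seen.append(name)  # record the qualified name passed to f
    return t * 10


model = MLP()
model.apply_to_parameters(scale)
print(seen)  # ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']

model.apply_to_local_parameters(lambda name, t: t + 1)  # MLP has no local parameters
model.fc1.apply_to_local_parameters(lambda name, t: t + 1)  # fc1 does
print(model.fc1.weight, model.fc2.weight)  # 11.0 20.0
```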
children
Iterates over the direct child modules of the Module.
Yields:
(name, module) pairs, where name is the attribute name of the child on the module.
compile()
compile(*input_types)
Compiles the module to an optimized executable through graph tracing.
This method performs symbolic tracing of the module’s __call__ method to construct a MAX Graph, which is then compiled and optimized for efficient execution on CPU, GPU, or other accelerators.
The compilation process:
- Creates symbolic Tensor instances based on the provided type specifications.
- Executes __call__ with symbolic tensors to record operations.
- Constructs a Graph representing the computation.
- Includes all module parameters as weights in the graph.
- Compiles and optimizes the graph for the target hardware.
- Returns an executable function with the same signature as __call__.
The input type specifications must match the signature of __call__.
Use positional arguments for positional parameters.
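The trace-then-execute idea behind these steps can be illustrated with a toy tracer in pure Python. This is a sketch of the general technique under simplifying assumptions (scalar values, only + and *), not how MAX builds its Graph; `Graph`, `Symbol`, and `toy_compile` are hypothetical names. The traced function runs once on symbolic inputs, captured Python values become graph constants (loosely analogous to parameters becoming weights), and the returned callable replays the recorded nodes.

```python
import operator
from typing import Any, Callable


class Graph:
    """Records operations as nodes; each node's value depends only on earlier nodes."""

    def __init__(self, num_inputs: int) -> None:
        self.nodes: list[tuple] = [("input", i) for i in range(num_inputs)]

    def append(self, node: tuple) -> "Symbol":
        self.nodes.append(node)
        return Symbol(self, len(self.nodes) - 1)


class Symbol:
    """Symbolic placeholder: arithmetic on it records a node instead of computing."""

    def __init__(self, graph: Graph, index: int) -> None:
        self.graph, self.index = graph, index

    def _binary(self, op: Callable[[Any, Any], Any], other: Any) -> "Symbol":
        if not isinstance(other, Symbol):
            # Captured Python values become constants in the graph.
            other = self.graph.append(("const", other))
        return self.graph.append(("op", op, self.index, other.index))

    def __add__(self, other: Any) -> "Symbol":
        return self._binary(operator.add, other)

    def __mul__(self, other: Any) -> "Symbol":
        return self._binary(operator.mul, other)


def toy_compile(fn: Callable[..., Symbol], num_inputs: int) -> Callable[..., Any]:
    graph = Graph(num_inputs)
    # Trace once with symbolic inputs; fn's Python code does not run again.
    out = fn(*(Symbol(graph, i) for i in range(num_inputs)))

    def run(*args: Any) -> Any:
        if len(args) != num_inputs:
            raise TypeError(f"expected {num_inputs} arguments, got {len(args)}")
        values: list[Any] = []
        for node in graph.nodes:  # replay the recorded nodes in order
            if node[0] == "input":
                values.append(args[node[1]])
            elif node[0] == "const":
                values.append(node[1])
            else:
                _, op, a, b = node
                values.append(op(values[a], values[b]))
        return values[out.index]

    return run


calls: list[str] = []


def affine(x):
    calls.append("traced")
    return x * 3 + 1


compiled = toy_compile(affine, 1)
print(compiled(2), compiled(5))  # 7 16
print(len(calls))  # 1: the Python function ran only during tracing
```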
Basic compilation with fixed shapes:
from max.dtype import DType
from max.experimental.tensor import Tensor, TensorType, defaults
from max.nn.module_v3 import Module, module_dataclass
@module_dataclass
class Linear(Module):
weight: Tensor
bias: Tensor
def __call__(self, x: Tensor) -> Tensor:
return x @ self.weight.T + self.bias
linear = Linear(
weight=Tensor.zeros([10, 5]),
bias=Tensor.zeros([10])
)
# Compile with fixed input shape
_, device = defaults()
input_type = TensorType(DType.float32, [3, 5], device=device)
model = linear.compile(input_type)
# Execute compiled model
input_data = Tensor.ones([3, 5], dtype=DType.float32)
result = model(input_data)
print(result)
Parameters:
*input_types (Type[Any]) – Type specifications for each positional argument to __call__. Must match the number and order of arguments. Each should be a max.graph.Type (typically TensorType) describing the shape and dtype.
Returns:
A compiled executable function with the same signature as __call__. This function runs the optimized graph and returns results with the same structure as __call__ (a single Tensor or a tuple of tensors).
Raises:
- TypeError – If input types don’t match the __call__ signature or if operations in __call__ cannot be traced.
- RuntimeError – If graph construction fails due to incompatible operations or parameter access issues.
Return type:
Callable[…, Any]
descendants
Iterates over the Module’s descendant modules.
Yields:
(name, module) pairs, where name is the qualified path of the descendant with respect to the module.
load_state()
load_state(lookup)
Replaces each parameter in the module and its descendants.
The transformation is applied in-place, updating the module’s values and those of its descendants.
For example, if we have a model with two parameters, weight and bias, we can load the state of the model from a dictionary with the following code:
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Linear
model = Linear(2, 3)
weights = {
"weight": Tensor.zeros([3, 2]),
"bias": Tensor.zeros([3]),
}
model.load_state(weights.__getitem__)
The lookup is defined as a function rather than a dictionary, allowing for functional remapping of names during this process to account for differences in common weight naming and storage conventions.
For instance, certain representations may not store weights as transposed, or may need to be quantized, or split out from a shared qkv block, or may just have slightly different names or paths.
This can also be used, for instance, to provide a default value for initializing LoRA weights.
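The renaming-and-transposing case can be sketched as a plain-Python lookup function. The checkpoint names `dense.kernel` and `dense.bias` and the helper names `make_lookup` and `transpose` are hypothetical, and nested lists stand in for arrays so the sketch runs on its own; per the signature below, a real lookup passed to model.load_state() would need to return DLPack-compatible arrays.

```python
from typing import Callable

# Hypothetical checkpoint with different names, storing the weight
# untransposed as [in_dim, out_dim] (nested lists stand in for arrays).
checkpoint = {
    "dense.kernel": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # [in_dim=2, out_dim=3]
    "dense.bias": [0.0, 0.0, 0.0],
}


def transpose(m: list[list[float]]) -> list[list[float]]:
    return [list(col) for col in zip(*m)]


def make_lookup(ckpt: dict) -> Callable[[str], list]:
    """Build a lookup from a module's qualified parameter name to a value."""
    renames = {"weight": "dense.kernel", "bias": "dense.bias"}

    def lookup(name: str) -> list:
        value = ckpt[renames[name]]
        # Convert the kernel to the transposed [out_dim, in_dim] convention.
        return transpose(value) if name == "weight" else value

    return lookup


lookup = make_lookup(checkpoint)
print(lookup("weight"))  # [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]
```

Because the lookup is a function, remapping logic like this stays outside the module and the checkpoint dictionary is never rewritten.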
Parameters:
lookup (Callable[[str], DLPackArray]) – The lookup function for each parameter:
- The argument to the lookup function is the qualified name of the parameter with respect to the module on which load_state() was called.
- The return value of this function is the new value that will replace the value at that name in the module tree.
load_state_dict()
load_state_dict(state, strict=True)
Loads parameter values from a dictionary into the module hierarchy.
This method updates all module parameters in-place by loading values from the provided state dictionary. The dictionary maps qualified parameter names (dot-separated paths like "fc1.weight") to tensor values.
The strict mode (default) ensures all weights in the dictionary are actually used, catching errors from mismatched architectures or incorrect weight names.
For example, the following loads weights from a dictionary into a model:
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Module, module_dataclass
@module_dataclass
class Linear(Module):
weight: Tensor
bias: Tensor
def __call__(self, x: Tensor) -> Tensor:
return x @ self.weight.T + self.bias
model = Linear(
weight=Tensor.zeros([10, 5]),
bias=Tensor.zeros([10])
)
# Load weights from dictionary
weights = {
"weight": Tensor.zeros([10, 5]),
"bias": Tensor.zeros([10]),
}
model.load_state_dict(weights)
Parameters:
- state (Mapping[str, DLPackArray]) – Dictionary mapping qualified parameter names to tensor values. Keys should match the names from the Module.parameters property. Values should be DLPack-compatible arrays or Tensor objects.
- strict (bool) – If True (default), verify that all keys in state are used (i.e., match actual parameters). If False, silently ignore extra keys that don’t match any parameters.
Raises:
- ValueError – If strict=True and some weights in state don’t match any model parameters (indicates an architecture mismatch or incorrect weight names).
- KeyError – If a required parameter name in the model is missing from state (regardless of the strict setting).
Return type:
None
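The two error conditions above can be summarized in a short pure-Python sketch. `check_state` is a hypothetical helper, not part of MAX, and the order in which it performs the two checks is an assumption of the sketch.

```python
from typing import Mapping


def check_state(param_names: list[str], state: Mapping[str, object],
                strict: bool = True) -> dict:
    """Hypothetical helper mirroring the documented rules of load_state_dict().

    Missing names raise KeyError regardless of `strict`; with strict=True,
    unused keys raise ValueError. The order of the two checks is assumed.
    """
    missing = [n for n in param_names if n not in state]
    if missing:
        raise KeyError(f"missing parameters: {missing}")
    unused = sorted(set(state) - set(param_names))
    if strict and unused:
        raise ValueError(f"unused weights: {unused}")
    return {n: state[n] for n in param_names}


params = ["fc1.weight", "fc1.bias"]
state = {"fc1.weight": 1, "fc1.bias": 2, "fc3.weight": 3}  # fc3.weight is extra
print(check_state(params, state, strict=False))  # {'fc1.weight': 1, 'fc1.bias': 2}
```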
local_parameters
Iterates over the local parameters of the Module.
Yields:
(name, tensor) pairs, where name is the attribute name of the tensor on the module.
map_parameters()
map_parameters(f)
Creates a new Module with its parameters transformed by the function.
The transformation is functional rather than in-place. The module is deep-copied; its descendants are also replaced via the same transform without affecting the original module.
For example:
from max.driver import Accelerator
from max.nn.module_v3 import Linear
model = Linear(2, 3)
model_on_gpu = model.map_parameters(lambda _, t: t.to(Accelerator()))
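The functional (copy-then-transform) behavior can be contrasted with the in-place methods using a toy tree in plain Python. This is a sketch, not MAX's implementation: `ToyModule` is hypothetical and float attributes stand in for tensors. The tree is deep-copied first, so the original keeps its values.

```python
import copy
from typing import Callable


class ToyModule:
    """Hypothetical stand-in for Module: float attributes are parameters."""

    def map_parameters(self, f: Callable[[str, float], float]) -> "ToyModule":
        new = copy.deepcopy(self)  # the original tree is never mutated

        def visit(module: "ToyModule", prefix: str) -> None:
            for name, value in list(vars(module).items()):
                if isinstance(value, float):
                    setattr(module, name, f(prefix + name, value))
                elif isinstance(value, ToyModule):
                    visit(value, prefix + name + ".")

        visit(new, "")
        return new


model = ToyModule()
model.weight = 1.0
model.child = ToyModule()
model.child.bias = 2.0

doubled = model.map_parameters(lambda name, t: t * 2)
print(doubled.weight, doubled.child.bias)  # 2.0 4.0
print(model.weight, model.child.bias)      # 1.0 2.0 (unchanged)
```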
Parameters:
f (Callable[[str, Tensor], Tensor]) – The transformation to apply to each parameter. The transformation takes two arguments, a name and a tensor:
- The name is the qualified name of the parameter with respect to the module on which map_parameters() was called.
- The tensor is the current value of that parameter.
The return value of this function is the new value that will replace the value at that name in the module tree.
Returns:
A new module tree of the same type resulting from mapping the transformation over all model parameters.
Return type:
parameters
Iterates over all parameters in this module and its sub-modules.
This property performs a depth-first traversal of the module hierarchy, yielding each parameter tensor with its qualified name. The qualified name uses dot-notation to represent the module tree structure (e.g., "encoder.layer1.weight").
Parameters are yielded in depth-first order: first the current module’s direct parameters, then recursively each sub-module’s parameters.
For example, to count the total number of parameters:
from max.experimental.tensor import Tensor
from max.nn.module_v3 import Module, module_dataclass
from max.nn.module_v3 import Linear
@module_dataclass
class MLP(Module):
fc1: Linear
fc2: Linear
def __call__(self, x: Tensor) -> Tensor:
return self.fc2(self.fc1(x))
model = MLP(
fc1=Linear(10, 20),
fc2=Linear(20, 5)
)
# Count parameters
total_params = sum(
param.num_elements()
for name, param in model.parameters
)
print(f"Total parameters: {total_params}")
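Assuming both layers keep the default bias=True, each Linear(in_dim, out_dim) holds an [out_dim, in_dim] weight plus an [out_dim] bias, so the total can be checked by hand: 10·20 + 20 = 220 for fc1 and 20·5 + 5 = 105 for fc2. `linear_param_count` below is a hypothetical helper for that arithmetic.

```python
def linear_param_count(in_dim: int, out_dim: int, bias: bool = True) -> int:
    # [out_dim, in_dim] weight elements, plus out_dim bias elements if enabled.
    return in_dim * out_dim + (out_dim if bias else 0)


total = linear_param_count(10, 20) + linear_param_count(20, 5)
print(total)  # 325
```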
Yields:
(name, parameter) tuples, where name is the dot-separated qualified path of the parameter and parameter is the Tensor.
to()
to(device)
Updates the module’s parameters, transferring them to the specified device.
from max.driver import CPU
from max.nn.module_v3 import Linear
model = Linear(2, 3)
model.to(CPU())
Sequential
class max.nn.module_v3.Sequential(*modules)
A Module subclass which holds a sequence of unary modules.
A unary Module is one whose __call__() method has the signature:
def __call__(self, x: Tensor) -> Tensor: ...
Sequential is itself a unary Module. Its __call__() method computes the result of applying each of its child modules in sequence to its input.
For example, this will apply a linear transformation up to a dimension of 10, apply a LayerNorm, and then apply a final linear transformation to reduce back to the input dimension of 5:
from max.experimental.tensor import Tensor
from max.nn.module_v3 import LayerNorm, Linear, Sequential
model = Sequential(
Linear(5, 10),
LayerNorm(10),
Linear(10, 5),
)
result = model(Tensor.ones([5]))
assert result.shape == [5]
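The sequencing behavior amounts to a left-to-right fold of unary callables. The sketch below uses plain Python functions rather than MAX modules; `sequential` is a hypothetical helper, and `sequential(f, g)(x)` computes `g(f(x))`.

```python
from functools import reduce
from typing import Callable, TypeVar

T = TypeVar("T")


def sequential(*modules: Callable[[T], T]) -> Callable[[T], T]:
    """Compose unary callables left to right, in the order given."""
    def call(x: T) -> T:
        # Feed each callable's output into the next one.
        return reduce(lambda acc, module: module(acc), modules, x)
    return call


model = sequential(lambda x: x + 1, lambda x: x * 10, lambda x: x - 3)
print(model(4))  # 47, i.e. ((4 + 1) * 10) - 3
```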
Constructs a sequential from a sequence of modules.
Following PyTorch, Sequential takes its modules as variadic arguments rather than as an iterable. Use the splat operator (*seq) to make a Sequential from an iterable.
For example:
from max.nn.module_v3 import Linear, Sequential
hidden_dims = [5, 10, 15, 20]
model = Sequential(*(
Linear(in_dim, out_dim) for in_dim, out_dim in
zip(hidden_dims, hidden_dims[1:])
))
Parameters:
modules (Module) – The sequence of contained modules, in the order of desired application.
module_dataclass()
max.nn.module_v3.module_dataclass(cls=None, /, *, repr=False, **kwargs)
Converts a class into a MAX module with automatic parameter tracking.
This decorator enables a regular Python class to function as a Module, providing automatic discovery and registration of parameters (Tensor fields) and nested modules. The decorated class gains all capabilities of Module, including parameter iteration, graph compilation via Module.compile(), and hierarchical module composition.
The decorator applies Python’s @dataclass decorator internally while preserving Module’s specialized __repr__ method for a better debugging experience when printing module structures.
Parameters:
- cls (type[Module] | None) – The class to decorate. Must define a __call__ method. When None, returns a decorator function (supports using @module_dataclass with or without parentheses).
- repr (bool) – If True, use dataclass’s default __repr__ instead of Module’s rich representation. Defaults to False.
- **kwargs – Additional keyword arguments forwarded to Python’s @dataclass decorator (e.g., frozen, eq).
Returns:
The decorated class as a Module subclass with automatic parameter tracking and graph compilation capabilities. When cls is None, returns a decorator function.
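The cls=None convention and the repr switch can be illustrated with the standard library alone. This is a sketch of the general decorator pattern, not MAX's implementation; `ToyModule` and `toy_module_dataclass` are hypothetical. It relies on the fact that dataclasses.dataclass(repr=False) generates no __repr__, so an inherited one is kept.

```python
import dataclasses
from typing import Any, Optional


class ToyModule:
    """Hypothetical base class whose custom __repr__ should survive decoration."""

    def __repr__(self) -> str:
        return f"{type(self).__name__}(<toy module>)"


def toy_module_dataclass(cls: Optional[type] = None, /, *, repr: bool = False, **kwargs: Any):
    """Sketch of the cls=None decorator pattern (not MAX's implementation)."""
    def wrap(c: type) -> type:
        # dataclass(repr=False) generates no __repr__, so the inherited
        # ToyModule.__repr__ is kept unless repr=True is requested.
        return dataclasses.dataclass(c, repr=repr, **kwargs)

    if cls is None:
        return wrap  # used with parentheses: @toy_module_dataclass(...)
    return wrap(cls)  # used bare: @toy_module_dataclass


@toy_module_dataclass
class A(ToyModule):
    x: int = 0


@toy_module_dataclass(repr=True, frozen=True)  # extra kwargs go to @dataclass
class B(ToyModule):
    x: int = 0


print(A(1))  # A(<toy module>)
print(B(1))  # B(x=1)
```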