Python module
layer
Layer
class max.nn.layer.Layer
Deprecated since version 25.2.
Base class for neural network components.
Use Module instead.
Provides functionality for adding hooks to the call function of each layer to support testing, debugging, or profiling.
LayerList
Stores a list of layers.
Can be used as a regular Python list.
append()
append(layer: Layer)
extend()
extend(layer: Layer)
insert()
insert(i, layer: Layer)
sublayers
property sublayers: dict[str, max.nn.layer.layer.Module]
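The list-like behavior described above can be sketched in plain Python. This is an illustrative sketch of the contract, not the MAX implementation; the `sublayers` keys-as-indices assumption is inferred from the `dict[str, Module]` return type:

```python
# Illustrative sketch of list-like layer storage (not the MAX source).
class LayerListSketch:
    def __init__(self, layers=None):
        self._layers = list(layers or [])

    def append(self, layer):
        self._layers.append(layer)

    def extend(self, layers):
        self._layers.extend(layers)

    def insert(self, i, layer):
        self._layers.insert(i, layer)

    def __getitem__(self, i):
        return self._layers[i]

    def __len__(self):
        return len(self._layers)

    @property
    def sublayers(self):
        # Keys are list indices rendered as strings, matching dict[str, Module].
        return {str(i): layer for i, layer in enumerate(self._layers)}
```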
Module
class max.nn.layer.Module
Base class for model components with weight management.
Provides functionality to create custom layers and construct networks with automatic weight tracking.
The following example uses the Module
class to create custom layers and build a neural network:
    from max import nn
    from max.dtype import DType
    from max.graph import Weight, ops, DeviceRef

    class Linear(nn.Module):
        def __init__(self, in_dims, out_dims):
            super().__init__()
            self.weight = Weight("weight", DType.float32, (in_dims, out_dims), DeviceRef.CPU())

        def __call__(self, x):
            return x @ self.weight.T

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.up = Linear(5, 10)
            self.gate = Linear(5, 10)
            self.down = Linear(10, 5)

        def __call__(self, x):
            return self.down(ops.silu(self.gate(x)) + self.up(x))

    model = MLP()
    print(model.state_dict())  # {"up.weight": Tensor([5, 10]), ...}
Constructing a graph without Module can result in name collisions among the weights (in this example, there would be three weights, all named weight). With Module, you can use state_dict() or load_state_dict() to initialize or set the weight values, and the weight names are finalized to be unique within the model.
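The uniquification idea can be sketched in plain Python: each attribute path contributes a dotted prefix, so identical local weight names become distinct keys. This is an illustrative sketch of the concept, not the MAX implementation; `Leaf`, `collect_state_dict`, and `MLPSketch` are hypothetical names:

```python
# Illustrative sketch: derive unique weight names from attribute paths.
class Leaf:
    def __init__(self):
        self.weight_names = ["weight"]  # every Leaf uses the same local name

def collect_state_dict(module, prefix=""):
    """Walk attributes depth-first, prefixing each local weight name with its path."""
    names = {}
    for attr, value in vars(module).items():
        if isinstance(value, Leaf):
            for w in value.weight_names:
                names[f"{prefix}{attr}.{w}"] = value
        elif hasattr(value, "__dict__"):
            names.update(collect_state_dict(value, f"{prefix}{attr}."))
    return names

class MLPSketch:
    def __init__(self):
        # Three submodules whose weights all share the local name "weight".
        self.up = Leaf()
        self.gate = Leaf()
        self.down = Leaf()
```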
layer_weights
property layer_weights: dict[str, max.graph.weight.Weight]
load_state_dict()
load_state_dict(state_dict: Mapping[str, DLPackArray | ndarray | WeightData], *, override_quantization_encoding: bool = False, weight_alignment: int | None = None) → None
Sets the values of all weights in this model.
Parameters:
- state_dict – A map from weight name to a numpy array or max.driver.Tensor.
- override_quantization_encoding – Whether to override the weight quantization based on the loaded value.
- weight_alignment – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state_dict must be aligned by the default dtype alignment.

Raises:
An error if any weight in the model is not present in the state dict.
raw_state_dict()
raw_state_dict() → dict[str, max.graph.weight.Weight]
Returns all weights objects in the model.
Unlike state_dict, this returns max.graph.Weight objects instead of the assigned values. Some parameters inside the Weight can be configured before a graph is built; do not change these attributes after building a graph.

Returns:
Map from weight name to the max.graph.Weight object.
set_shared_weight()
state_dict()
state_dict(auto_initialize: bool = True) → dict[str, Union[max._core_types.driver.DLPackArray, numpy.ndarray]]
Returns values of all weights in the model.
The values returned are the same as the values set in load_state_dict. If load_state_dict has not been called and none of the weights have values, then they are initialized to zero.

Parameters:
auto_initialize – Determines whether to initialize weights to zero if the weight value has not been loaded. If this is False, a ValueError is raised if an uninitialized weight is found.

Returns:
Map from weight name to the weight value (can be a numpy array or max.driver.Tensor).
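The auto_initialize contract can be sketched in plain Python: unloaded weights are zero-filled when auto_initialize is True, and raise ValueError otherwise. This is an illustrative sketch of the documented behavior, not the MAX implementation; `state_dict_sketch` is a hypothetical helper:

```python
import numpy as np

# Illustrative sketch of the documented auto_initialize behavior.
def state_dict_sketch(weights, loaded, auto_initialize=True):
    """weights: name -> shape; loaded: name -> np.ndarray of loaded values."""
    out = {}
    for name, shape in weights.items():
        if name in loaded:
            out[name] = loaded[name]
        elif auto_initialize:
            # Unloaded weights default to zero when auto_initialize is True.
            out[name] = np.zeros(shape, dtype=np.float32)
        else:
            # Otherwise an uninitialized weight is an error.
            raise ValueError(f"uninitialized weight: {name}")
    return out
```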
sublayers
property sublayers: dict[str, max.nn.layer.layer.Module]
add_layer_hook()
max.nn.layer.add_layer_hook(fn: Callable[[Layer, tuple[Any, ...], dict[str, Any], Any], Any]) → None
Adds a hook to call a function after each layer’s __call__
.
The function will be passed four inputs:
- layer
- input_args
- input_kwargs
- outputs
The function can either return None or new outputs that will replace the layer's returned outputs.
Note that the inputs and outputs contain graph Values, which expose limited information (like shape and dtype). You can still see the computed values if you include the Value in the graph.ops.output op, or call graph.ops.print.
Example of printing debug inputs:
    def print_info(layer, args, kwargs, outputs):
        print("Layer:", type(layer).__name__)
        print("Input args:", args)
        print("Input kwargs:", kwargs)
        print("Outputs:", outputs)
        return outputs

    add_layer_hook(print_info)
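The return-None-or-replace rule above can be sketched as a plain-Python dispatch loop. This is an illustrative sketch of the hook contract, not the MAX implementation; `add_hook` and `call_with_hooks` are hypothetical names:

```python
# Illustrative sketch of hook dispatch after a layer call.
_hooks = []

def add_hook(fn):
    _hooks.append(fn)

def call_with_hooks(layer_fn, *args, **kwargs):
    outputs = layer_fn(*args, **kwargs)
    for hook in _hooks:
        replaced = hook(layer_fn, args, kwargs, outputs)
        if replaced is not None:
            # A non-None return value replaces the layer's outputs.
            outputs = replaced
    return outputs
```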
clear_hooks()
max.nn.layer.clear_hooks()
Remove all hooks.
recursive_named_layers()
max.nn.layer.recursive_named_layers(parent: Module, prefix: str = '') → Iterable[tuple[str, max.nn.layer.layer.Module]]
Recursively walks through the layers and generates names.
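The recursive name generation can be sketched in plain Python: walk attributes depth-first, joining names with dots. This is an illustrative sketch, not the MAX implementation; `recursive_named_layers_sketch` and `Node` are hypothetical names:

```python
# Illustrative sketch of recursive layer naming with dotted prefixes.
def recursive_named_layers_sketch(parent, prefix=""):
    """Yield (dotted_name, layer) pairs, depth-first from the parent."""
    yield prefix, parent
    for attr, child in vars(parent).items():
        if hasattr(child, "__dict__") and not attr.startswith("_"):
            child_prefix = f"{prefix}.{attr}" if prefix else attr
            yield from recursive_named_layers_sketch(child, child_prefix)

class Node:
    pass

root = Node()
root.block = Node()
root.block.linear = Node()
```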