Python module
layer
Layer
class max.pipelines.nn.layer.Layer
Base Layer class.
Currently, its only functionality is adding hooks to each layer's call function to support testing, debugging, or profiling.
LayerV2
class max.pipelines.nn.layer.LayerV2
(new) Base class for model layers with weight and device management.
This class will be merged with Layer once all layers have been moved to V2.
layer_weights
property layer_weights: dict[str, max.graph.weight.Weight]
sublayers
property sublayers: dict[str, max.pipelines.nn.layer.LayerV2]
to()
to(*devices: DeviceRef, sharding_strategy: ShardingStrategy | None = None) → None
ShardingStrategy
class max.pipelines.nn.layer.ShardingStrategy(host_device: DeviceRef, shard_value: Callable[[Weight], tuple[max.graph.value.TensorValue, ...]])
Defines how to load and shard a weight onto multiple devices.
host_device
host_device: DeviceRef
shard_value
shard_value: Callable[[Weight], tuple[max.graph.value.TensorValue, ...]]
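To make the role of the two fields concrete, here is a minimal pure-Python sketch. It does not use the MAX library: `ToyShardingStrategy` and `split_evenly` are hypothetical stand-ins, a plain list of floats stands in for a Weight, tuples of floats stand in for TensorValue objects, and a string stands in for a DeviceRef. It also assumes that `shard_value` returns one shard per target device, which the reference above does not state explicitly.

```python
from dataclasses import dataclass
from typing import Callable, Sequence

Shards = tuple[tuple[float, ...], ...]


@dataclass(frozen=True)
class ToyShardingStrategy:
    """Hypothetical stand-in that mirrors the two documented fields."""

    host_device: str  # stand-in for a DeviceRef
    shard_value: Callable[[Sequence[float]], Shards]


def split_evenly(num_devices: int) -> Callable[[Sequence[float]], Shards]:
    """Build a shard function that cuts a 1-D weight into equal chunks."""

    def shard(weight: Sequence[float]) -> Shards:
        if len(weight) % num_devices != 0:
            raise ValueError("weight length must be divisible by num_devices")
        size = len(weight) // num_devices
        # Assumption: one shard per device, in device order.
        return tuple(
            tuple(weight[i * size:(i + 1) * size]) for i in range(num_devices)
        )

    return shard


strategy = ToyShardingStrategy(host_device="cpu", shard_value=split_evenly(2))
shards = strategy.shard_value([1.0, 2.0, 3.0, 4.0])
print(shards)
```

In the real API, the callable receives a Weight and returns a tuple of TensorValue objects; the sketch only illustrates the shape of that contract.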
add_layer_hook()
max.pipelines.nn.layer.add_layer_hook(fn: Callable[[Layer, Tuple[Any, ...], Dict[str, Any], Any], Any]) → None
Adds a hook to call a function after each layer’s __call__.
The function will be passed four inputs: the layer, input_args, input_kwargs and outputs. The function can either return None or return new outputs that replace the layer's returned outputs.
Note that inputs and outputs contain graph Values, which expose only limited information (such as shape and dtype). You can still see the computed values if you include the Value in the graph.output op, or call value.print.
Example of printing debug inputs:
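from max.pipelines.nn.layer import add_layer_hook

def print_info(layer, args, kwargs, outputs):
    # Log the layer type along with its inputs and outputs.
    print("Layer:", type(layer).__name__)
    print("Input args:", args)
    print("Input kwargs:", kwargs)
    print("Outputs:", outputs)
    # Return the outputs unchanged so the layer result is not replaced.
    return outputs

add_layer_hook(print_info)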
clear_hooks()
max.pipelines.nn.layer.clear_hooks()
Remove all hooks.
recursive_named_layers()
max.pipelines.nn.layer.recursive_named_layers(parent: LayerV2) → Iterable[tuple[str, max.pipelines.nn.layer.LayerV2]]
Recursively walks through the layers and generates names.
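A recursive walk of this kind can be sketched in pure Python as follows. `ToyLayer` and `walk_named_layers` are hypothetical stand-ins, not the library's code. The sketch assumes that names are joined with dots, that the parent itself is not yielded, and that traversal is pre-order over the `sublayers` mapping; the reference above confirms none of these details.

```python
from typing import Iterator


class ToyLayer:
    """Hypothetical stand-in exposing a `sublayers` dict like LayerV2."""

    def __init__(self, **sublayers: "ToyLayer") -> None:
        self.sublayers: dict[str, "ToyLayer"] = dict(sublayers)


def walk_named_layers(
    parent: ToyLayer, prefix: str = ""
) -> Iterator[tuple[str, ToyLayer]]:
    """Yield (qualified_name, layer) pairs for every nested sublayer."""
    for name, layer in parent.sublayers.items():
        # Assumption: nested names are joined with a dot.
        full_name = f"{prefix}.{name}" if prefix else name
        yield full_name, layer
        yield from walk_named_layers(layer, full_name)


model = ToyLayer(
    attn=ToyLayer(q_proj=ToyLayer(), k_proj=ToyLayer()),
    mlp=ToyLayer(),
)
names = [name for name, _ in walk_named_layers(model)]
print(names)
```

Writing the walk as a generator keeps it lazy, so callers can stop early or filter by name without building the full list.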