Graph
Graph
class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], context=None, kernel_library=None, module=None, **kwargs)
Represents a single MAX graph.
A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.
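The dataflow model can be illustrated with a minimal pure-Python sketch. This is not MAX's implementation; it only shows the difference between imperative execution (each statement runs immediately) and a lazily evaluated op graph, using hypothetical `LazyNode`/`Constant` classes:

```python
# Illustrative sketch only (not MAX internals): a dataflow graph records
# ops lazily and computes them on demand, unlike imperative code which
# executes each statement as it is reached.

class LazyNode:
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def evaluate(self):
        # Recursively evaluate dependencies, then apply this node's op.
        return self.fn(*(i.evaluate() for i in self.inputs))

class Constant(LazyNode):
    def __init__(self, value):
        super().__init__(lambda: value)

# Building the graph performs no computation.
a, b = Constant(2), Constant(3)
total = LazyNode(lambda x, y: x + y, a, b)

# Computation happens only when the graph is executed.
print(total.evaluate())  # → 5
```

Because nothing runs until `evaluate()`, independent nodes could in principle be scheduled in parallel, which is the property MAX's dataflow model exploits.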
When you instantiate a graph, you must specify the input shapes as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:
```python
from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor

linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))],
)
```
You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.
When creating a graph, a global sequence of chains is initialized and stored in Graph._current_chain. Every side-effecting op (e.g., buffer_load, store_buffer, load_slice_buffer, store_slice_buffer) uses the current chain to perform the op and updates Graph._current_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once. The goal of this design choice is to prevent data races.
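The chain mechanism can be sketched in plain Python. This is a conceptual model, not MAX internals: a hypothetical `Chain` token is consumed by each side-effecting op and replaced with a fresh one, so mutations are forced into a single order and a stale chain cannot be reused:

```python
# Conceptual sketch (not MAX internals): side-effecting ops thread a
# "chain" value through the graph so mutations execute in a defined
# order. Each chain may be consumed at most once, preventing data races.

class Chain:
    def __init__(self):
        self.consumed = False

def side_effecting_op(chain, action):
    # Consume the incoming chain and return a fresh one, analogous to
    # how buffer ops consume and replace Graph._current_chain.
    if chain.consumed:
        raise ValueError("chain already consumed")
    chain.consumed = True
    action()
    return Chain()

log = []
current_chain = Chain()
current_chain = side_effecting_op(current_chain, lambda: log.append("store"))
current_chain = side_effecting_op(current_chain, lambda: log.append("load"))
print(log)  # → ['store', 'load']
```

Reusing an already-consumed chain raises an error in this sketch, mirroring the "used at most once" rule stated above.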
Parameters:

- name (str) – A name for the graph.
- forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
- input_types (Iterable[Type[Any]]) – The data type(s) for the input tensor(s).
- path (Optional[Path]) – The path to a saved graph (internal use only).
- custom_extensions (list[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
- context (Optional[mlir.Context])
- kernel_library (Optional[KernelLibrary])
- module (Optional[mlir.Module])
add_subgraph()
add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[])
Creates and adds a subgraph to the current graph.
Creates a new Graph instance configured as a subgraph of the current graph. The subgraph inherits the parent graph's MLIR context, module, and symbolic parameters. A chain type is automatically appended to the input types to enable proper operation sequencing within the subgraph.
The created subgraph is marked with special MLIR attributes to identify it as a subgraph and is registered in the parent graph’s subgraph registry.
Parameters:

- name (str) – The name identifier for the subgraph.
- forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – An optional callable that defines the sequence of operations for the subgraph's forward pass. If provided, the subgraph is built immediately using this callable.
- input_types (Iterable[Type[Any]]) – The data types for the subgraph's input tensors. A chain type will be automatically added to these input types.
- path (Path | None) – An optional path to a saved subgraph definition to load from disk instead of creating a new one.
- custom_extensions (list[Path]) – The list of paths to custom operation libraries to load for the subgraph. Supports .mojopkg files and Mojo source directories.

Return type: Graph
add_weight()
add_weight(weight, force_initial_weight_on_host=True)
Adds a weight to the graph.
If the weight is already in the graph, the existing value is returned.

Parameters:

- weight – The weight to add to the graph.
- force_initial_weight_on_host (bool)

Returns: A TensorValue that contains this weight.

Raises: ValueError – If a weight with the same name already exists in the graph.

Return type: TensorValue
always_ready_chain
property always_ready_chain: _ChainValue
A graph-global, immutable chain that is always ready.
Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (e.g., host→device transfers for staging).
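The idea of an always-ready chain can be sketched conceptually. This is not MAX's implementation; it only shows a hypothetical sentinel that ops may depend on without waiting, in contrast to ordinary chains that become ready only when their producing op completes:

```python
# Conceptual sketch (not MAX internals): an "always ready" chain is a
# sentinel created once and never advanced or merged. Ops that need no
# per-device ordering (e.g., host-to-device staging copies) can depend
# on it without blocking.

ALWAYS_READY = object()  # created once per graph in this model

def is_ready(chain, completed):
    # The sentinel is ready unconditionally; ordinary chains are ready
    # only once their producing op has completed.
    return chain is ALWAYS_READY or chain in completed

print(is_ready(ALWAYS_READY, completed=set()))  # → True
print(is_ready("pending-op", completed=set()))  # → False
```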
current
current
device_chains
device_chains: defaultdict[DeviceRef, _ChainValue]
inputs
The input values of the graph.
kernel_libraries_paths
Returns the list of extra kernel libraries paths for the custom ops.
merge_device_chains()
merge_device_chains()
Joins device execution to a common point by merging chains.
Return type: None
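Conceptually, merging device chains creates a join point that every device's pending work must reach before later ops proceed. The sketch below is illustrative only (not MAX internals), using a hypothetical `merge_chains` helper over a plain dict of per-device chains:

```python
# Conceptual sketch (not MAX internals): merging per-device chains
# produces a single join chain that depends on every device's current
# chain; afterwards each device continues from that common point.

def merge_chains(device_chains):
    # The merged chain records all incoming chains as dependencies.
    merged = ("join", tuple(sorted(device_chains.items())))
    return {device: merged for device in device_chains}

chains = {"gpu:0": ("chain", 1), "gpu:1": ("chain", 2)}
merged = merge_chains(chains)
print(merged["gpu:0"] == merged["gpu:1"])  # → True
```

After the merge, an op sequenced on any device's chain implicitly waits for all devices, which is the synchronization behavior merge_device_chains() provides.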
output()
output(*outputs)
Sets the output nodes of the Graph.
output_types
View of the types of the graph output terminator.