
Graph

class max.graph.Graph(name, forward=None, input_types=(), path=None, *args, custom_extensions=[], context=None, kernel_library=None, module=None, **kwargs)

Represents a single MAX graph.

A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.
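The distinction between the two models can be illustrated with a minimal sketch in plain Python (this is not the MAX API, just an illustration): a dataflow program records operations into a structure first and computes results only when the structure is executed.

```python
# A minimal sketch (plain Python, NOT the MAX API) contrasting eager,
# imperative execution with lazy dataflow recording.

ops_log = []  # the "graph": a record of operations, not yet executed

def lazy_add(a, b):
    """Record an add instead of performing it immediately."""
    node = ("add", a, b)
    ops_log.append(node)
    return node

def execute(node):
    """Walk the recorded dataflow and compute a result on demand."""
    if isinstance(node, tuple):
        op, lhs, rhs = node
        if op == "add":
            return execute(lhs) + execute(rhs)
    return node  # a literal input value

result_node = lazy_add(lazy_add(1, 2), 3)  # nothing is computed yet
assert len(ops_log) == 2                   # two ops were recorded
assert execute(result_node) == 6           # computed only when executed
```

Because the ops are recorded rather than run in order, an executor is free to schedule independent ops in parallel, which is the point of the dataflow model.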

When you instantiate a graph, you must specify the input shapes as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:

from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor

linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))]
)

You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.
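Although the graph itself only executes inside MAX Engine, the computation it stages is ordinary linear algebra. The following eager NumPy snippet shows what the `linear` graph above computes numerically (this runs NumPy directly and does not use the MAX API):

```python
import numpy as np

# Eager NumPy equivalent of the staged computation in the "linear" graph:
# ops.matmul(x, weight) + bias, with weight = ones((2, 2)), bias = ones((2,)).
weight = np.ones((2, 2), dtype=np.float32)
bias = np.ones((2,), dtype=np.float32)

# An input matching TensorType(DType.float32, (2,)).
x = np.array([1.0, 2.0], dtype=np.float32)
y = x @ weight + bias
# x @ weight = [3.0, 3.0]; adding the bias gives [4.0, 4.0]
```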

When a graph is created, a global sequence of chains is initialized and stored in Graph._current_chain. Every side-effecting op (e.g., buffer_load, store_buffer, load_slice_buffer, store_slice_buffer) uses the current chain to perform the op and updates Graph._current_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once; this design choice prevents data races.
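The at-most-once chain discipline can be modeled in a few lines of plain Python (a hypothetical sketch, not the MAX implementation): each side-effecting op consumes the current chain and returns a fresh one, so mutable ops form a strict sequence and concurrent conflicting access is ruled out by construction.

```python
# A plain-Python sketch (NOT the MAX API) of single-use chain threading.

class Chain:
    """A token that orders side-effecting ops; usable at most once."""
    def __init__(self):
        self.used = False

def side_effecting_op(chain, buffer, value):
    """Perform a mutation ordered by `chain`; return a fresh chain."""
    if chain.used:
        raise RuntimeError("chain already consumed")  # at-most-once use
    chain.used = True
    buffer.append(value)   # the side effect, sequenced by the chain
    return Chain()         # the new current chain for the next mutable op

buf = []
chain = Chain()
chain = side_effecting_op(chain, buf, 1)  # each op threads the chain forward
chain = side_effecting_op(chain, buf, 2)
assert buf == [1, 2]
```

Because a stale chain raises rather than silently reordering, two mutable ops can never race on the same buffer.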

Parameters:

  • name (str) – A name for the graph.
  • forward (Callable[..., None | Value[Any] | Iterable[Value[Any]]] | None) – The sequence of graph ops for the forward pass (inference).
  • input_types (Iterable[Type[Any]]) – The data type(s) for the input tensor(s).
  • path (Optional[Path]) – The path to a saved graph (internal use only).
  • custom_extensions (list[Path]) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
  • context (Optional[mlir.Context])
  • kernel_library (Optional[KernelLibrary])
  • module (Optional[mlir.Module])

add_subgraph()

add_subgraph(name, forward=None, input_types=(), path=None, custom_extensions=[])

Creates and adds a subgraph to the current graph.

Creates a new Graph instance configured as a subgraph of the current graph. The subgraph inherits the parent graph’s MLIR context, module, and symbolic parameters. A chain type is automatically appended to the input types to enable proper operation sequencing within the subgraph.

The created subgraph is marked with special MLIR attributes to identify it as a subgraph and is registered in the parent graph’s subgraph registry.

Parameters:

  • name (str) – The name identifier for the subgraph.
  • forward (Callable[[...], None | Value[Any] | Iterable[Value[Any]]] | None) – The optional callable that defines the sequence of operations for the subgraph’s forward pass. If provided, the subgraph will be built immediately using this callable.
  • input_types (Iterable[Type[Any]]) – The data types for the subgraph’s input tensors. A chain type will be automatically added to these input types.
  • path (Path | None) – The optional path to a saved subgraph definition to load from disk instead of creating a new one.
  • custom_extensions (list[Path]) – The list of paths to custom operation libraries to load for the subgraph. Supports .mojopkg files and Mojo source directories.

Return type:

Graph

add_weight()

add_weight(weight, force_initial_weight_on_host=True)

Adds a weight to the graph.

If the weight is already in the graph, the existing value is returned.

Parameters:

  • weight (Weight) – The weight to add to the graph.
  • force_initial_weight_on_host (bool) – If true, forces weights to be allocated on the host initially before being moved to the indicated device. This is a stopgap until a more complete ownership model for external constants is in place.

Returns:

A TensorValue that contains this weight.

Raises:

ValueError – If a weight with the same name already exists in the graph.

Return type:

TensorValue

always_ready_chain

property always_ready_chain: _ChainValue

A graph-global, immutable chain that is always ready.

Created once per graph and never advanced/merged by the graph itself. Use it for operations that are safe to schedule without threading per-device ordering (e.g., host→device transfers for staging).

current

current

device_chains

device_chains: defaultdict[DeviceRef, _ChainValue]

inputs

property inputs: Sequence[Value[Any]]

The input values of the graph.

kernel_libraries_paths

property kernel_libraries_paths: list[Path]

Returns the list of extra kernel library paths for the custom ops.

merge_device_chains()

merge_device_chains()

Joins device execution to a common point by merging chains.

Return type:

None

output()

output(*outputs)

Sets the output nodes of the Graph.

Parameters:

outputs (Value[Any] | Value[TensorType] | TensorValue | Shape | Dim | HasTensorValue | int | float | integer[Any] | floating[Any] | ndarray[Any, dtype[number[Any]]])

Return type:

None

output_types

property output_types: list[Type[Any]]

View of the types of the graph output terminator.
