Python class
Graph
class max.graph.Graph(name: str, forward: Callable | None = None, input_types: Iterable[Type] = (), path: Path | None = None, *args, context: _BaseContext | None = None, **kwargs)
Represents a single MAX graph.
A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.
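The contrast between imperative and dataflow execution can be sketched in plain Python. The `Node`, `const`, `add`, and `mul` names below are hypothetical and are not part of MAX; the toy only shows that building a graph records dependencies, while the arithmetic runs later when evaluation is requested.

```python
# Toy illustration only: these names are hypothetical and are not part of MAX.
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class Node:
    """A deferred operation: a function plus the nodes that feed it."""

    fn: Callable[..., float]
    deps: tuple["Node", ...] = ()

    def evaluate(self) -> float:
        # Work happens only here, driven by data dependencies.
        return self.fn(*(dep.evaluate() for dep in self.deps))


def const(value: float) -> Node:
    return Node(lambda: value)


def add(a: Node, b: Node) -> Node:
    return Node(lambda x, y: x + y, (a, b))


def mul(a: Node, b: Node) -> Node:
    return Node(lambda x, y: x * y, (a, b))


# Building the graph performs no arithmetic; it only records dependencies.
x, y = const(2.0), const(3.0)
result_node = mul(add(x, y), add(x, y))

# Evaluation is requested explicitly, as a separate step from construction.
result = result_node.evaluate()
```

Because the two `add` nodes depend only on `x` and `y`, a real dataflow runtime would be free to run them in parallel; this toy evaluates them sequentially for simplicity.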
When you instantiate a graph, you must specify the input shapes as one or more TensorType values. Then, build a sequence of ops and set the graph output with output(). For example:
from dataclasses import dataclass

import numpy as np
from max.dtype import DType
from max.graph import Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32)
        bias_tensor = ops.constant(self.bias, dtype=DType.float32)
        return ops.matmul(x, weight_tensor) + bias_tensor

linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))]
)
You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to build a graph with MAX Graph.
When a graph is created, a global sequence of chains is initialized and stored in Graph._current_chain. Every side-effecting op (for example buffer_load, store_buffer, load_slice_buffer, or store_slice_buffer) uses the current chain to perform the op and then updates Graph._current_chain with a new chain. Currently, the input and output chains for mutable ops can each be used at most once. The goal of this design choice is to prevent data races.
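The chain mechanism can be pictured with a small plain-Python model. `Chain` and `ToyGraph` are hypothetical stand-ins rather than MAX classes, and the real implementation may differ; the sketch only shows each side-effecting op consuming the current chain once and replacing it with a fresh one, which fixes the order of the ops.

```python
# Toy model only: Chain and ToyGraph are hypothetical, not MAX classes.
class Chain:
    """A token that may be consumed by at most one side-effecting op."""

    def __init__(self, step: int) -> None:
        self.step = step
        self.used = False


class ToyGraph:
    def __init__(self) -> None:
        self.buffer = [0, 0]
        self.log: list[str] = []
        self._current_chain = Chain(0)

    def _advance(self, chain: Chain) -> Chain:
        # Consume the given chain exactly once and hand back its successor.
        if chain.used:
            raise RuntimeError("chain already consumed")
        chain.used = True
        return Chain(chain.step + 1)

    def store(self, index: int, value: int) -> None:
        chain = self._current_chain
        self._current_chain = self._advance(chain)
        self.buffer[index] = value
        self.log.append(f"store@{chain.step}")

    def load(self, index: int) -> int:
        chain = self._current_chain
        self._current_chain = self._advance(chain)
        self.log.append(f"load@{chain.step}")
        return self.buffer[index]


graph = ToyGraph()
stale = graph._current_chain  # keep a reference to the first chain
graph.store(0, 7)
loaded = graph.load(0)

# Reusing a chain that an op has already consumed is rejected.
try:
    graph._advance(stale)
    reused = True
except RuntimeError:
    reused = False
```

Threading a single-use token through every mutable op gives the load a dependency on the store, so the two cannot be reordered or run concurrently.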
add_weight()
add_weight(weight: Weight) → TensorValue
Adds a weight to the graph.
If the weight is in the graph already, return the existing value.
Parameters:
weight – The weight to add to the graph.
Returns:
A TensorValue that contains this weight.
Raises:
ValueError – If a weight with the same name already exists in the graph.
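One plausible reading of the description above is that adding the same weight object twice returns the existing value, while adding a different weight under an already-used name raises ValueError. The sketch below models that reading with hypothetical `ToyWeight` and `ToyRegistry` classes; it is an assumption about the semantics, not the MAX implementation.

```python
# Toy model only: ToyWeight and ToyRegistry are hypothetical, not MAX classes.
from dataclasses import dataclass


@dataclass(eq=False)  # identity-based equality, so two same-named weights stay distinct
class ToyWeight:
    name: str


class ToyRegistry:
    def __init__(self) -> None:
        # Maps a weight name to the registered weight and its graph value.
        self._weights: dict[str, tuple[ToyWeight, str]] = {}

    def add_weight(self, weight: ToyWeight) -> str:
        existing = self._weights.get(weight.name)
        if existing is not None:
            registered, value = existing
            if registered is weight:
                # Same weight object: return the value created earlier.
                return value
            # Assumed behavior: a different weight reusing the name is an error.
            raise ValueError(f"weight {weight.name!r} already exists")
        value = f"value:{weight.name}"  # placeholder for a graph value
        self._weights[weight.name] = (weight, value)
        return value


registry = ToyRegistry()
w = ToyWeight("linear.weight")
first = registry.add_weight(w)
second = registry.add_weight(w)  # same object, same value

try:
    registry.add_weight(ToyWeight("linear.weight"))
    clashed = False
except ValueError:
    clashed = True
```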
current
inputs
inputs: tuple[max.graph.value.Value, ...]
output()
Sets the output nodes of the Graph.
unique_symbolic_dim()
unique_symbolic_dim(tag: str) → SymbolicDim
Creates a new symbolic dimension whose name differs from that of every other dimension.
Parameters:
tag – An additional identifier to help identify the dimension for debugging purposes.
Returns:
The dimension.
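A common way to guarantee distinct names is to append a per-graph counter to the tag. The sketch below uses a hypothetical `ToyDimNamer` class and a hypothetical `unique_<tag>_<n>` name format to illustrate the idea; the naming scheme MAX actually uses is not stated here and may differ.

```python
# Toy model only: ToyDimNamer and the name format are hypothetical.
import itertools


class ToyDimNamer:
    def __init__(self) -> None:
        self._counter = itertools.count()

    def unique_symbolic_dim(self, tag: str) -> str:
        # The tag aids debugging; the counter guarantees uniqueness.
        return f"unique_{tag}_{next(self._counter)}"


namer = ToyDimNamer()
first_dim = namer.unique_symbolic_dim("batch")
second_dim = namer.unique_symbolic_dim("batch")  # same tag, different name
```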
weights