Python module
ops
Implements operations used when staging a graph.
This module provides operations for building computational graphs in MAX. These operations create, transform, and manipulate tensor values within the graph.
You can also add constant values to your graph with operations like constant().
The TensorValue type (returned by most operations) implements various dunder methods to support operations between TensorValues, such as + for addition, * for multiplication, and @ for matrix multiplication. It also provides convenience methods like reshape() and flatten().
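For example, the following minimal sketch (the graph name, shapes, and CPU device are illustrative) stages a small graph using these operators:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

matrix_type = TensorType(DType.float32, shape=(2, 4), device=DeviceRef.CPU())
weight_type = TensorType(DType.float32, shape=(4, 8), device=DeviceRef.CPU())
with Graph("dunder_example", input_types=(matrix_type, weight_type)) as graph:
    x, w = graph.inputs
    y = x.tensor @ w.tensor  # matrix multiplication via __matmul__
    y = y + 1.0              # elementwise addition via __add__
    graph.output(y.reshape((2, 8)))  # convenience method on TensorValue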
Casting
broadcast_to()
max.graph.ops.broadcast_to(x: TensorValue, shape: TensorValue | Iterable[int | str | Dim | integer], out_dims: Iterable[int | str | Dim | integer] | None = None) → TensorValue
Broadcasts a symbolic tensor.
Broadcasts the input tensor to the specified shape. Dimensions in the input must be one or match the target dimension.
Parameters:
- x – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions.
- shape – The new shape as a list of dimensions. Dynamic dimensions are not allowed.
- out_dims – The output dimensions; used only when shape is given as a tensor value.
Returns:
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.
Raises:
ValueError – If a tensor-valued shape is passed without out_dims.
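A minimal sketch (values and the graph name are illustrative) that broadcasts a (1, 3) constant to (2, 3):

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("broadcast_example") as graph:
    x = ops.constant(np.zeros((1, 3)), DType.float32, device=DeviceRef.CPU())
    y = ops.broadcast_to(x, (2, 3))  # each input dim is 1 or matches the target
    graph.output(y)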
cast()
max.graph.ops.cast(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, dtype: DType) → TensorValue
Casts a symbolic tensor to a different data type.
Parameters:
- x – The input tensor to cast.
- dtype – The target dtype to which the tensor is cast.
Returns:
A new symbolic tensor with the same shape as the input and the specified dtype.
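A minimal sketch (values illustrative) that casts an int64 constant to float32:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("cast_example") as graph:
    x = ops.constant(np.array([1, 2, 3]), DType.int64, device=DeviceRef.CPU())
    y = ops.cast(x, DType.float32)  # same shape (3,), new dtype
    graph.output(y)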
rebind()
max.graph.ops.rebind(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, shape: Iterable[int | str | Dim | integer], message: str = '') → TensorValue
Rebinds a symbolic tensor to a specified set of dimensions.
This does not mutate the symbolic tensor passed in, but instead adds a runtime assertion that the input's symbolic shape is equivalent to shape. For example, if the input tensor has dynamic/unknown dimensions, this asserts the fixed sizes that may be required for a subsequent operation.
Parameters:
- x – The input symbolic tensor to rebind.
- shape – The symbolic shape to assert for x, as a list of Dim values.
- message – The message printed if the rebind fails at runtime.
Returns:
A symbolic tensor with the same elements and shape as the given tensor, but with its symbolic shape asserted to shape.
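For example, the following sketch (the symbolic "batch" dimension and the fixed size 8 are illustrative assumptions) pins a dynamic dimension to a static size required by a later operation:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

input_type = TensorType(DType.float32, shape=("batch", 128), device=DeviceRef.CPU())
with Graph("rebind_example", input_types=(input_type,)) as graph:
    x = graph.inputs[0]
    fixed = ops.rebind(x, (8, 128), message="expected batch size 8")
    graph.output(fixed)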
reshape()
max.graph.ops.reshape(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, shape: Iterable[int | str | Dim | integer]) → TensorValue
Reshapes a symbolic tensor.
The number and order of the elements in the tensor is unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same.
If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified dimensions. Its length becomes the number of elements in the original tensor divided by the product of elements of the reshape.
Parameters:
- x – The input symbolic tensor to reshape. This tensor may not contain any dynamic dimensions.
- shape – The new shape as a list of dimensions. Dynamic dimensions are not allowed. A single dimension may be -1.
Returns:
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.
Raises:
ValueError – If the input and target shapes have different numbers of elements.
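A minimal sketch (values illustrative) showing the automatically calculated -1 dimension:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("reshape_example") as graph:
    x = ops.constant(np.arange(12), DType.int64, device=DeviceRef.CPU())
    y = ops.reshape(x, (3, -1))  # -1 is inferred as 12 / 3 = 4
    graph.output(y)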
shape_to_tensor()
max.graph.ops.shape_to_tensor(shape: Iterable[int | str | Dim | integer]) → TensorValue
Converts a shape to a tensor.
This is useful for using a shape attribute in an op that expects a tensor value.
Parameters:
shape – The shape attribute of a tensor value.
Returns:
The TensorValue containing the same value as shape.
Example
>>> x = ops.constant(np.zeros((1,)), DType.int64, device=DeviceRef.CPU())
>>> result = ops.stack([
...     x,
...     ops.shape_to_tensor(x.shape),
... ])
>>> result
TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])
squeeze()
max.graph.ops.squeeze(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis: int) → TensorValue
Removes a size-1 dimension from a symbolic tensor.
Parameters:
- x – The input symbolic tensor to squeeze.
- axis – The dimension to remove from the input's shape. If negative, this indexes from the end of the tensor. For example, squeeze(v, -1) squeezes the last dimension.
Returns:
A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.
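A minimal sketch (values illustrative):

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("squeeze_example") as graph:
    x = ops.constant(np.zeros((2, 1, 3)), DType.float32, device=DeviceRef.CPU())
    y = ops.squeeze(x, 1)  # shape (2, 1, 3) -> (2, 3)
    graph.output(y)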
transpose()
max.graph.ops.transpose(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, dim_1: int, dim_2: int) → TensorValue
Transposes two dimensions of a symbolic tensor.
Parameters:
- x – The input symbolic tensor to transpose.
- dim_1 – One of the two dimensions to transpose. If negative, this indexes from the end of the tensor. For example, transpose(v, -1, -2) transposes the last two dimensions.
- dim_2 – The other dimension to transpose. May also be negative to index from the end of the tensor.
Returns:
A new symbolic tensor with the two specified dimensions transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.
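A minimal sketch (values illustrative):

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("transpose_example") as graph:
    x = ops.constant(np.zeros((2, 3)), DType.float32, device=DeviceRef.CPU())
    y = ops.transpose(x, -1, -2)  # shape (2, 3) -> (3, 2)
    graph.output(y)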
unsqueeze()
max.graph.ops.unsqueeze(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis: int) → TensorValue
Inserts a size-1 dimension into a symbolic tensor.
Parameters:
- x – The input symbolic tensor to unsqueeze.
- axis – The index at which to insert a new dimension into the input's shape. Elements at that index or higher are shifted back. If negative, it indexes relative to 1 plus the rank of the tensor. For example, unsqueeze(v, -1) adds a new dimension at the end, and unsqueeze(v, -2) inserts the dimension immediately before the last dimension.
Returns:
A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result's shape at the axis dimension is a static dimension of size 1.
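A minimal sketch (values illustrative):

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("unsqueeze_example") as graph:
    x = ops.constant(np.zeros((2, 3)), DType.float32, device=DeviceRef.CPU())
    y = ops.unsqueeze(x, -1)  # shape (2, 3) -> (2, 3, 1)
    graph.output(y)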
Complex
as_interleaved_complex()
max.graph.ops.as_interleaved_complex(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Reshapes the input symbolic tensor as complex from alternating (real, imag).
Parameters:
x – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.
Returns:
A symbolic tensor representing the complex-valued tensor, with the values pulled out as complex numbers. The result has the same dimensions for all dimensions except the last dimension, which is halved, and then a final dimension of size 2 representing the complex value.
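A minimal sketch (values illustrative); the last dimension of size 8 holds four (real, imag) pairs:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("interleaved_complex_example") as graph:
    x = ops.constant(np.zeros((4, 8)), DType.float32, device=DeviceRef.CPU())
    y = ops.as_interleaved_complex(x)  # shape (4, 8) -> (4, 4, 2)
    graph.output(y)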
Constant
constant()
max.graph.ops.constant(value: ndarray | int | float | integer | floating, dtype: DType, device: DeviceRef) → TensorValue
Adds a node representing a constant operation.
The value of this constant will have the type TensorType with the same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions.
The constant will be loaded with the specified dtype. If the constant does not fit within the specified dtype, an error is raised.
Warning: Loading the constant could result in precision loss. For example, loading 16777217 as a float32 will result in 16777216.0.
Parameters:
- value – The constant's value.
- dtype – The constant tensor's element type.
- device – The device the constant lives on.
Returns:
A graph value containing the constant data as an attribute.
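A minimal sketch (values illustrative) showing a scalar (0-dimensional) and an array constant:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("constant_example") as graph:
    scalar = ops.constant(3.14, DType.float32, device=DeviceRef.CPU())  # 0-D tensor
    matrix = ops.constant(np.ones((2, 2)), DType.float32, device=DeviceRef.CPU())
    graph.output(matrix + scalar)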
Convolution
conv2d()
max.graph.ops.conv2d(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, filter: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, stride: tuple[int, int] = (1, 1), dilation: tuple[int, int] = (1, 1), padding: tuple[int, int, int, int] = (0, 0, 0, 0), groups: int = 1, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None) → TensorValue
Computes the 2-D convolution product of the input with the given filter, bias, strides, dilations, paddings, and groups.
The op supports 2-D convolution, with the following layout assumptions:
- input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
- filter has layout RSCF, i.e., (height, width, in_channels / num_groups, out_channels)
- bias has shape (out_channels,)
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding 0’s before and after the indicated spatial dimensions in input. In 2-D convolution, dim1 here represents H and dim2 represents W. In Python like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:
input = [
    [1, 2, 3],
    [4, 5, 6]
]
# Shape is 2x3

padded_input = [
    [0, 0, 1, 2, 3, 0],
    [0, 0, 4, 5, 6, 0],
    [0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
This op currently only supports strides and padding on the input.
Parameters:
- x – An NHWC input tensor to perform the convolution upon.
- filter – The convolution filter in RSCF layout: (height, width, in_channels / num_groups, out_channels).
- stride – The stride of the convolution operation.
- dilation – The spacing between the kernel points.
- padding – The amount of padding applied to the input.
- groups – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
- bias – An optional bias tensor of shape (out_channels,).
Returns:
A symbolic tensor value with the convolution applied.
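A minimal sketch (the image and filter shapes are illustrative) of a 3x3 convolution over an NHWC input:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

input_type = TensorType(DType.float32, shape=(1, 28, 28, 3), device=DeviceRef.CPU())   # NHWC
filter_type = TensorType(DType.float32, shape=(3, 3, 3, 16), device=DeviceRef.CPU())   # RSCF
with Graph("conv2d_example", input_types=(input_type, filter_type)) as graph:
    x, filt = graph.inputs
    out = ops.conv2d(x, filt, stride=(1, 1), padding=(1, 1, 1, 1))
    graph.output(out)  # padding of 1 on each side keeps the 28x28 spatial shape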
conv3d()
max.graph.ops.conv3d(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, filter: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, stride: tuple[int, int, int] = (1, 1, 1), dilation: tuple[int, int, int] = (1, 1, 1), padding: tuple[int, int, int, int, int, int] = (0, 0, 0, 0, 0, 0), groups: int = 1, bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None) → TensorValue
Computes the 3-D convolution product of the input with the given filter, strides, dilations, paddings, and groups.
The op supports 3-D convolution, with the following layout assumptions:
- input has NDHWC layout, i.e., (batch_size, depth, height, width, in_channels)
- filter has layout RSCF, i.e., (depth, height, width, in_channels / num_groups, out_channels)
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after…) and represent padding 0’s before and after the indicated spatial dimensions in input. In 3-D convolution, dim1 here represents D, dim2 represents H and dim3 represents W. In Python like syntax, padding a 2x3 spatial input with [0, 1, 2, 1] would yield:
input = [
    [1, 2, 3],
    [4, 5, 6]
]
# Shape is 2x3

padded_input = [
    [0, 0, 1, 2, 3, 0],
    [0, 0, 4, 5, 6, 0],
    [0, 0, 0, 0, 0, 0]
]
# Shape is 3x6
This op currently only supports strides and padding on the input.
Parameters:
- x – An NDHWC input tensor to perform the convolution upon.
- filter – The convolution filter in RSCF layout: (depth, height, width, in_channels / num_groups, out_channels).
- stride – The stride of the convolution operation.
- dilation – The spacing between the kernel points.
- padding – The amount of padding applied to the input.
- groups – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
- bias – An optional bias tensor of shape (out_channels,).
Returns:
A symbolic tensor value with the convolution applied. Output shape = (batch_size, depth, height, width, out_channels).
conv2d_transpose()
max.graph.ops.conv2d_transpose(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, filter: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, stride: tuple[int, int] = (1, 1), dilation: tuple[int, int] = (1, 1), padding: tuple[int, int, int, int] = (0, 0, 0, 0), output_paddings: tuple[int, int] = (0, 0), bias: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None) → TensorValue
Computes the 2-D deconvolution of the input with the given filter, strides, dilations, paddings, and groups.
The op supports the transpose (gradient) of convolution, with the following layout assumptions: (note the out_channel is w.r.t. the original convolution)
- input x has NHWC layout, i.e., (batch_size, height, width, in_channels)
- filter has layout RSCF, i.e., (kernel_height, kernel_width, out_channels, in_channels)
- bias has shape (out_channels,)
This op effectively computes the gradient of a convolution with respect to its input (as if the original convolution operation had the same filter and hyperparameters as this op). A visualization of the computation can be found in https://d2l.ai/chapter_computer-vision/transposed-conv.html.
The padding values are expected to take the form (pad_dim1_before, pad_dim1_after, pad_dim2_before, pad_dim2_after) and represent padding 0's before and after the indicated spatial dimensions in input. In 2-D ConvTranspose, dim1 here represents H_out and dim2 represents W_out. In Python-like syntax, padding a 2x4 spatial output with [0, 1, 2, 1] would yield:
output = [
    [1, 2, 3, 4],
    [5, 6, 7, 8]
]
# Shape is 2x4

padded_output = [
    [3],
]
# Shape is 1x1
Parameters:
- x – An NHWC input tensor to perform the convolution upon.
- filter – The convolution filter in RSCF layout: (height, width, out_channels, in_channels).
- stride – The stride of the sliding window for each spatial dimension of the input. If a single value is given, it is replicated for the H and W dimensions.
- dilation – The spacing between the kernel points.
- padding – The amount of padding applied to the input.
- output_paddings – Resolves the ambiguity between the multiple possible output shapes that arise when any stride is greater than 1, by appending output_paddings[i] zeros at the end of the output's i-th spatial axis. Only output_paddings = 0 is currently supported.
- bias – An optional bias tensor of shape (out_channels,).
Returns:
A symbolic tensor value with the convolution applied.
Control flow
cond()
max.graph.ops.cond(pred: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, out_types: Iterable[Type] | None, then_fn: Callable, else_fn: Callable) → list[max.graph.value.TensorValue]
Conditionally execute one of two branches based on a boolean predicate.
Both branches must return the same number and types of values as specified in out_types. Buffer mutations in branches are tracked automatically through the chain mechanism.
Examples:
- Basic conditional with return values:

def then_fn():
    return ops.constant(1, DType.int32, device=DeviceRef.CPU())

def else_fn():
    return ops.constant(0, DType.int32, device=DeviceRef.CPU())

result = ops.cond(
    pred,
    [TensorType(DType.int32, [], device=device)],
    then_fn,
    else_fn
)

- Conditional with buffer mutations:

def then_fn():
    ops.inplace_custom("increment", [buffer])

def else_fn():
    ops.inplace_custom("decrement", [buffer])

ops.cond(pred, None, then_fn, else_fn)
Parameters:
- pred – Boolean scalar tensor of type DType.bool determining branch execution.
- out_types – Expected output types for both branches. Use None for branches that don't return values.
- then_fn – Callable executed when pred is True. Must return values matching out_types if out_types is not None.
- else_fn – Callable executed when pred is False. Must return values matching out_types if out_types is not None.
Returns:
A list of output values from the executed branch. Returns an empty list when out_types is None.
Raises:
ValueError – If the branches return different numbers of results or the result types don't match out_types.
NOTE
Buffer operations in branches automatically update the global chain state to maintain mutation ordering constraints.
while_loop()
max.graph.ops.while_loop(initial_values: Iterable[Value] | Value, predicate: Callable[[...], TensorValue], body: Callable[[...], Iterable[Value]]) → list[max.graph.value.TensorValue]
Execute a loop until the predicate evaluates to false.
Both the predicate and body functions must take as arguments the same number and types of values as specified in initial_values. The predicate function must return only a boolean scalar tensor of type DType.bool. The body function must return a list of values matching the types of initial_values.
The following example demonstrates a basic while loop with a single argument:
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x):
        return x < 10

    def body(x):
        return x + 1

    result = ops.while_loop(x, pred, body)
    print(result)
The following example shows a while loop with multiple arguments:
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())
    y = ops.constant(5, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x, y):
        return ops.logical_and(x < 10, y < 15)

    def body(x, y):
        return [x + 1, y + 1]

    results = ops.while_loop((x, y), pred, body)
    print(results)
Parameters:
- initial_values – Initial values for loop arguments. Must be non-empty.
- predicate – Callable that takes loop arguments and returns a boolean scalar tensor of type DType.bool determining loop continuation.
- body – Callable that takes loop arguments and returns updated values matching the types of initial_values.
Returns:
List of output values from the final loop iteration.
Raises:
- ValueError – If initial_values is empty.
- NotImplementedError – If any initial value is a BufferValue.
NOTE
Buffer operations are currently not supported.
Custom
A custom operation (op) is a user-defined kernel written in Mojo that is registered and executed within the computation graph. It allows you to extend the graph’s capabilities by implementing your own specialized operations.
For example, you might write an add_one_custom function in Mojo that adds 1 to each element of a matrix. Then you'd call the operation by its string name in the max.graph.Graph:
def create_graph(rows: int, columns: int, dtype: DType) -> Graph:
    """Configure a graph with a custom operation."""
    graph = Graph(
        "addition",
        forward=lambda x: ops.custom(
            name="add_one_custom",
            values=[x],
            out_types=[TensorType(dtype=x.dtype, shape=x.tensor.shape)],
        )[0].tensor,
        input_types=[
            TensorType(dtype, shape=[rows, columns]),
        ],
    )
    return graph
Custom ops also support parametrization on int, str, and dtype. This means you can define custom parametric Mojo types, then use those types as inputs to custom ops staged in the graph API. For example, given the following Counter Mojo type:
struct Counter[stride: Int](Movable):
    var a: Int
    var b: Int

    fn __init__(out self):
        self.a = 0
        self.b = 0

    fn __init__(out self, a: Int, b: Int):
        self.a = a
        self.b = b

    fn __moveinit__(out self, owned other: Self):
        self.a = other.a
        self.b = other.b

    fn bump(mut self):
        self.a += Self.stride
        self.b += self.a
The following inplace_custom() call stages an op that bumps the parametric Counter type. Notice that we're using _OpaqueType here, which is a Python-based graph type that represents a Mojo value (from max.graph.type), but it's currently an internal API and subject to change.
counter_type = _OpaqueType("Counter")

# ... create counter object.

# Stage a graph that bumps the counter, parametrized on stride.
bumper_graph = Graph(
    "bumper",
    forward=lambda x: ops.inplace_custom(
        "bump_counter",
        values=[x],
        out_types=[],
        parameters={"stride": 2},
    ),
    input_types=[counter_type],
)
custom()
max.graph.ops.custom(name: str, values: Sequence[Value], out_types: Sequence[Type], parameters: Mapping[str, bool | int | str | DType] | None = None, device: DeviceRef | None = None) → list[max.graph.value.Value]
Creates a node to execute a custom graph operation in the graph.
The custom op should be registered by annotating a function with the @compiler.register decorator.
Parameters:
- name – The op name provided to @compiler.register.
- values – The op function's arguments.
- out_types – The list of the op function's return types.
- parameters – Dictionary of extra parameters expected by the kernel.
- device – Device that the op is assigned to. This becomes a target parameter to the kernel.
Returns:
Symbolic values representing the outputs of the op in the graph. These correspond 1:1 with the types passed as out_types.
inplace_custom()
max.graph.ops.inplace_custom(name: str, values: Iterable[Value], out_types: Iterable[Type] | None = None, parameters: dict[str, bool | int | str | max._core.dtype.DType] | None = None, device: DeviceRef | None = None) → list[max.graph.value.Value]
Creates a node to execute an in-place custom graph operation in the graph.
The custom op should be registered by annotating a function with the @compiler.register decorator.
Parameters:
- name – The op name provided to @compiler.register.
- values – The op function's arguments.
- out_types – The list of the op function's return types, if any.
- parameters – Dictionary of extra parameters expected by the kernel.
- device – Device that the op is assigned to. This becomes a target parameter to the kernel.
Debug
Operations used to help debug your graph.
print()
max.graph.ops.print(value: str | TensorValue, label: str = 'debug_tensor')
Prints the value of a tensor or a string during graph execution.
This function outputs the current value of a tensor and is primarily used for debugging within the context of the MAX Engine and its graph execution framework. It is particularly useful for verifying that the intermediate results of your computations are as expected.
By printing the tensor values, you can visualize the data flowing through the graph, which helps in understanding how the operations are transforming the data.
Assigning a label to the output makes it easier to identify which tensor's value is being printed, especially when there are multiple print statements in a complex graph.
def add_tensors(a: np.ndarray, b: np.ndarray) -> dict[str, Any]:
    input_type = TensorType(dtype=DType.float32, shape=(1,), device=DeviceRef.CPU())
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        ops.print(out, label="addition_output")  # Pass the output tensor here
        graph.output(out)
        print("final graph:", graph)
Parameters:
- value – The value to print. Can be either a string or a TensorValue.
- label – A label to identify the printed value. Defaults to debug_tensor.
Distributed
allgather()
max.graph.ops.allgather(inputs: Iterable[TensorValue], dim: int = 0) → list[max.graph.value.TensorValue]
Collective allgather operation.
This op is a collective op which takes in tensors from different devices and outputs tensors on different devices. In particular, this operation gathers the inputs across devices and concatenates them along the specified dimension. The result is then broadcast back to the same devices that the inputs came from.
Parameters:
- inputs – The input tensors to gather.
- dim – Dimension along which to concatenate the input tensors. Defaults to 0.
Returns:
An iterable of outputs, one per device, which all hold the gathered output.
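A hedged sketch of staging an allgather over two devices; DeviceRef.GPU and the shard shapes here are assumptions for illustration, not part of this op's documented contract:

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

shard_types = [
    TensorType(DType.float32, shape=(4, 8), device=DeviceRef.GPU(id=i))  # assumed devices
    for i in range(2)
]
with Graph("allgather_example", input_types=shard_types) as graph:
    gathered = ops.allgather([inp.tensor for inp in graph.inputs], dim=0)
    graph.output(*gathered)  # each output holds the concatenated (8, 8) result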
sum()
max.graph.ops.allreduce.sum(inputs: Iterable[TensorValue], signal_buffers: Iterable[BufferValue]) → list[max.graph.value.TensorValue]
Collective allreduce summation operation.
This op is a collective op which takes in tensors from different devices and outputs tensors on different devices. In particular, this operation gathers the inputs across devices and reduces them via a summation operation. The result is then broadcast back to the same devices that the inputs came from.
This version of the allreduce sum op uses device-to-device transfers.
Parameters:
- inputs – The input tensors to reduce.
- signal_buffers – Device buffer values used for synchronization.
Returns:
An iterable of outputs which all hold the reduction output.
Elementwise
An elementwise operation performs the same calculation on each element of an input tensor. These operations take tensors of compatible shapes and apply the specified operation to each element pair.
For example, the following demonstrates how to add two tensors using the add() function:
import numpy as np
from max import engine
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def main():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("simple_add_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # First addend
        y = graph.inputs[1]  # Second addend
        out = ops.add(x, y)
        graph.output(out)

    session = engine.InferenceSession()
    model = session.load(graph)
    input_0 = np.array([10.0, 8.0], dtype=np.float32)
    input_1 = np.array([2.0, 4.0], dtype=np.float32)
    ret = model.execute(input_0, input_1)
    print("\nAddition computation:")
    print("Result ", ret["output0"])

if __name__ == "__main__":
    main()
abs()
max.graph.ops.abs(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise absolute value of a symbolic tensor.
Creates a new op node to compute the elementwise absolute value of a symbolic tensor and adds it to the graph, returning the symbolic result.
The following demonstrates how to compute the absolute value using the abs() function:
def abs_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("abs_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.abs(x)
        graph.output(out)
Parameters:
- x – The symbolic tensor to use as the input to the absolute value computation.
Returns:
A new symbolic tensor value representing the output of the absolute value computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
add()
max.graph.ops.add(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Adds two symbolic tensors.
Creates a new op node to compute the addition of two symbol tensor values and adds it to the graph, returning the symbolic result.
The following shows an example of the add() function with two inputs:
def add_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("add_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]
        y = graph.inputs[1]
        out = ops.add(x, y)
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the addition.
- rhs – The symbol to use as right side of the addition.
Returns:
A symbolic tensor value representing the output of the addition. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
cos()
max.graph.ops.cos(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise cosine of a symbolic tensor.
Creates a new op node to compute the elementwise cosine of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the cos computation. If it's not a floating-point DType, an exception will be raised.
Returns:
A new symbolic tensor value representing the output of the cosine computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
div()
max.graph.ops.div(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Divides two symbolic tensors.
Creates a new op node to compute the division of two symbol tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the division.
- rhs – The symbol to use as right side of the division.
Returns:
A symbolic tensor value representing the output of the division. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
equal()
max.graph.ops.equal(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise equality comparison between two symbolic tensors.
Creates a new op node to compute the equality comparison of two symbol tensor values and adds it to the graph, returning the symbolic result.
def equal_graph():
    input_type = TensorType(dtype=DType.float32, shape=(3,), device=DeviceRef.CPU())
    with Graph("equal_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # First input
        y = graph.inputs[1]  # Second input
        out = ops.equal(x, y)
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the equality comparison.
- rhs – The symbol to use as right side of the equality comparison.
Returns:
A symbolic tensor value representing the output of the equality comparison. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
erf()
max.graph.ops.erf(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise error function of a symbolic tensor.
Creates a new op node to compute the elementwise error function of a symbolic tensor and adds it to the graph, returning the symbolic result.
The error function erf is defined as erf(x) = 2/sqrt(pi) * integral from 0 to x of exp(-t^2) dt, and is related to the probability that a normally distributed random variable falls within a given range.
Parameters:
- x – The symbolic tensor to use as the input to the error function computation.
Returns:
A new symbolic tensor value representing the output of the error function computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
exp()
max.graph.ops.exp(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise exp function of a symbolic tensor.
Creates a new op node to compute the elementwise exp function of a symbolic tensor and adds it to the graph, returning the symbolic result.
exp is defined as exp(x) = e^x, where e is Euler's number.
Parameters:
- x – The symbolic tensor to use as the input to the exp computation.
Returns:
A new symbolic tensor value representing the output of the exp computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
floor()
max.graph.ops.floor(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise floor of a symbolic tensor.
Creates a new op node to compute the elementwise floor of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the floor computation. If it's not a floating-point DType, an exception will be raised.
Returns:
A new symbolic tensor value representing the output of the floor computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
gelu()
max.graph.ops.gelu(x: TensorValue, approximate: str = 'none')
Computes the elementwise gelu of a symbolic tensor.
Creates a new op node to compute the elementwise gelu of a symbolic tensor and adds it to the graph, returning the symbolic result.
For approximate == "none"
, the exact gelu function is computed.
For approximate == "tanh"
, the approximation:
is used.
For approximate == "quick"
, the approximation:
is used.
-
Parameters:
value – The symbolic tensor to use as the input to the gelu computation.
-
Returns:
A new symbolic tensor value representing the output of the gelu : value computation.
-
Raises:
- Error – If the symbol doesn’t represent a tensor value.
- ValueError – If the approximation method is invalid.
greater()
max.graph.ops.greater(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise greater than comparison between two symbolic tensors.
Creates a new op node to compute the greater than comparison of two symbol tensor values and adds it to the graph, returning the symbolic result.
def greater_than_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("greater_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Left hand side
        y = graph.inputs[1]  # Right hand side
        out = ops.greater(x, y)
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the greater than comparison.
- rhs – The symbol to use as right side of the greater than comparison.
Returns:
A symbolic tensor value representing the output of the greater than comparison. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
greater_equal()
max.graph.ops.greater_equal(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise greater-or-equal comparison between two symbolic tensors.
Creates a new op node to compute the greater-or-equal comparison of two symbolic tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the greater-or-equal comparison.
- rhs – The symbol to use as right side of the greater-or-equal comparison.
Returns:
A symbolic tensor value representing the output of the greater-or-equal comparison. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
is_inf()
max.graph.ops.is_inf(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise is_inf of a symbolic tensor.
Creates a new op node to compute the elementwise is_inf of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the is_inf computation.
Returns:
The result will have:
- element type bool, true if the element at a given position is plus or minus infinity, false otherwise
- the same shape as the input value.
Raises:
Error – If the symbol doesn't represent a tensor value.
is_nan()
max.graph.ops.is_nan(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise is_nan of a symbolic tensor.
Creates a new op node to compute the elementwise is_nan of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the is_nan computation.
Returns:
The result will have:
- element type bool, true if the element at a given position is NaN, false otherwise
- the same shape as the input value.
Raises:
Error – If the symbol doesn't represent a tensor value.
log()
max.graph.ops.log(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise natural logarithm of a symbolic tensor.
Creates a new op node to compute the elementwise natural logarithm of a symbolic tensor and adds it to the graph, returning the symbolic result.
The natural logarithm function log is defined as the inverse of the exponential function exp(). In other words, it computes the value y in the equation x = e^y, where e is Euler's number.
log(x) is undefined for x <= 0 for real numbers. Complex numbers are currently unsupported.
Parameters:
- x – The symbolic tensor to use as the input to the natural logarithm computation.
Returns:
A new symbolic tensor value representing the output of the natural logarithm computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
log1p()
max.graph.ops.log1p(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise logarithm of 1 plus a symbolic tensor.
Creates a new op node to compute the elementwise log1p of a symbolic tensor and adds it to the graph, returning the symbolic result.
The log1p function is defined as log1p(x) = log(1 + x), where log() is the natural logarithm.
Using log1p(x) rather than computing log(1 + x) can give greater numerical precision.
log1p(x) is undefined for x <= -1 for real numbers. Complex numbers are currently unsupported.
Parameters:
- x – The symbolic tensor to use as the input to the log1p computation.
Returns:
A new symbolic tensor value representing the output of the log1p computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
logical_not()
max.graph.ops.logical_not(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise logical_not of a symbolic tensor.
Creates a new op node to compute the elementwise logical_not of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the logical_not computation.
Returns:
The result will have:
- element type bool, true if the element at a given position is false, false otherwise
- the same shape as the input value.
Raises:
Error – If the symbol doesn't represent a tensor value.
logsoftmax()
max.graph.ops.logsoftmax(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise logsoftmax of a symbolic tensor.
Creates a new op node to compute the elementwise logsoftmax of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the logsoftmax computation.
Returns:
A new symbolic tensor value representing the output of the logsoftmax computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
max()
max.graph.ops.max(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, y: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, /, axis: int | None = None) → TensorValue
Overload for ops.elementwise.max and ops.reduction.max.
- If two tensors are provided, axis is ignored and returns an elementwise maximum.
- If one tensor is provided, compute ops.reduction.max on the tensor and axis.
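A minimal sketch (values illustrative) showing both forms; min() behaves analogously:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("max_example") as graph:
    x = ops.constant(np.array([[1.0, 5.0], [3.0, 2.0]]), DType.float32, device=DeviceRef.CPU())
    y = ops.constant(np.array([[2.0, 4.0], [1.0, 6.0]]), DType.float32, device=DeviceRef.CPU())
    elementwise = ops.max(x, y)    # elementwise maximum, shape (2, 2)
    reduced = ops.max(x, axis=-1)  # reduction over the last axis
    graph.output(elementwise, reduced)  # passing multiple outputs is assumed supported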
min()
max.graph.ops.min(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, y: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, /, axis: int | None = None) → TensorValue
Overload for ops.elementwise.min and ops.reduction.min.
- If two tensors are provided, axis is ignored and returns an elementwise minimum.
- If one tensor is provided, compute ops.reduction.min on the tensor and axis.
mod()
max.graph.ops.mod(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise modulus of two symbolic tensors.
Creates a new op node to compute the modulus of two symbol tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the modulus operation.
- rhs – The symbol to use as right side of the modulus operation.
Returns:
A symbolic tensor value representing the output of the modulus operation. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
mul()
max.graph.ops.mul(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise multiplication of two symbolic tensors.
Creates a new op node to compute the multiplication of two symbol tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the multiplication.
- rhs – The symbol to use as right side of the multiplication.
Returns:
A symbolic tensor value representing the output of the multiplication. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
negate()
max.graph.ops.negate(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise negation of a symbolic tensor.
Creates a new op node to compute the elementwise negation of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the negation computation.
Returns:
The result will have:
- the same dtype as the input value
- the same shape as the input value, with each element negated.
Raises:
Error – If the symbol doesn't represent a tensor value.
not_equal()
max.graph.ops.not_equal(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise inequality comparison between two symbolic tensors.
Creates a new op node to compute the inequality comparison of two symbol tensor values and adds it to the graph, returning the symbolic result.
def not_equal_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("not_equal_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Left hand side
        y = graph.inputs[1]  # Right hand side
        out = ops.not_equal(x, y)
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the inequality comparison.
- rhs – The symbol to use as right side of the inequality comparison.
Returns:
A symbolic tensor value representing the output of the inequality comparison. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
outer()
max.graph.ops.outer(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the outer product of two symbolic vectors.
Parameters:
- lhs – The left side of the product. Whatever its shape, it will be flattened to a rank-1 vector.
- rhs – The right side of the product. Whatever its shape, it will be flattened to a rank-1 vector. Must have the same number of elements as lhs.
Returns:
A symbolic tensor representing the outer product of the two input vectors. It will have rank 2, with the dimension sizes being the number of elements of lhs and rhs respectively.
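A minimal sketch (values illustrative):

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("outer_example") as graph:
    a = ops.constant(np.array([1.0, 2.0, 3.0]), DType.float32, device=DeviceRef.CPU())
    b = ops.constant(np.array([4.0, 5.0, 6.0]), DType.float32, device=DeviceRef.CPU())
    out = ops.outer(a, b)  # rank 2, shape (3, 3)
    graph.output(out)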
pow()
max.graph.ops.pow(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise exponentiation of two symbolic tensors.
Creates a new op node to compute the exponentiation of two symbol tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
Parameters:
- lhs – The symbol to use as left side of the exponentiation.
- rhs – The symbol to use as right side of the exponentiation.
Returns:
A symbolic tensor value representing the output of the exponentiation. The result will have:
- the same dtype as the type promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
Raises:
- Error – If the input values' shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
relu()
max.graph.ops.relu(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise relu of a symbolic tensor.
Creates a new op node to compute the elementwise relu of a symbolic tensor and adds it to the graph, returning the symbolic result.
Parameters:
- x – The symbolic tensor to use as the input to the relu computation.
Returns:
A new symbolic tensor value representing the output of the relu computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
round()
max.graph.ops.round(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise round of a symbolic tensor.
Creates a new op node to compute the elementwise round of a symbolic tensor and adds it to the graph, returning the symbolic result. Ties are rounded to the nearest even number.
For example, if the model has one input tensor:
def round_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4,), device=DeviceRef.CPU())
    with Graph("round_graph_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.round(x)
        graph.output(out)
Parameters:
- x – The symbolic tensor to use as the input to the round computation. If it's not a floating-point DType, an exception will be raised.
Returns:
A new symbolic tensor value representing the output of the round computation.
Raises:
Error – If the symbol doesn't represent a tensor value.
rsqrt()
max.graph.ops.rsqrt(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise inverse-square-root of a symbolic tensor.
Creates a new op node to compute the elementwise rsqrt of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the rsqrt computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the rsqrt computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sigmoid()
max.graph.ops.sigmoid(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sigmoid of a symbolic tensor.
Creates a new op node to compute the elementwise sigmoid of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sigmoid computation.
-
Returns:
A new symbolic tensor value representing the output of the sigmoid computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
silu()
max.graph.ops.silu(x: TensorValue)
Computes the elementwise silu of a symbolic tensor.
Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result.
silu is defined as silu(x) = x * sigmoid(x).
-
Parameters:
x – The symbolic tensor to use as the input to the silu computation.
-
Returns:
A new symbolic tensor value representing the output of the silu computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sin()
max.graph.ops.sin(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sine of a symbolic tensor.
Creates a new op node to compute the elementwise sine of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sin computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the sin computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
softmax()
max.graph.ops.softmax(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise softmax of a symbolic tensor.
Creates a new op node to compute the elementwise softmax of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the softmax computation.
-
Returns:
A new symbolic tensor value representing the output of the softmax computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sqrt()
max.graph.ops.sqrt(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sqrt of a symbolic tensor.
Creates a new op node to compute the elementwise sqrt of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sqrt computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the sqrt computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sub()
max.graph.ops.sub(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise subtraction of two symbolic tensors.
Creates a new op node to compute the subtraction of two symbolic tensor values and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def sub_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,), device=DeviceRef.CPU())
    with Graph("sub_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Minuend (number being subtracted from)
        y = graph.inputs[1]  # Subtrahend (number being subtracted)
        out = ops.sub(x, y)
        graph.output(out)
-
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as left side of the subtraction.
- rhs – The symbol to use as right side of the subtraction.
-
Returns:
A symbolic tensor value representing the output of the subtraction. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
tanh()
max.graph.ops.tanh(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise tanh of a symbolic tensor.
Creates a new op node to compute the elementwise tanh of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the tanh computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the tanh computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
trunc()
max.graph.ops.trunc(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise truncation of a symbolic tensor.
Creates a new op node to compute the elementwise truncation of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the truncation computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the truncation computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
Fast Fourier transforms
irfft()
max.graph.ops.irfft(input_tensor: TensorValue, n: int | None = None, axis: int = -1, normalization: Normalization | str = Normalization.BACKWARD, input_is_complex: bool = False)
Compute the inverse real FFT of the input tensor.
-
Parameters:
- input_tensor – The input tensor to compute the inverse real FFT of.
- n – The size of the output tensor. Must be an int, and cannot be a symbolic Tensor. The input tensor will be padded or truncated to n // 2 + 1 along the specified axis.
- axis – The axis to compute the inverse real FFT of.
- normalization – The normalization to apply to the output tensor. Can be “backward”, “ortho”, or “forward”. When “backward”, the output is divided by n. When “ortho”, the output is divided by sqrt(n). When “forward”, no normalization is applied.
- input_is_complex – Whether the input tensor is already interleaved complex. The last dimension of the input tensor must be 2, and is excluded from the dimension referred to by axis.
-
Returns:
The inverse real FFT of the input tensor. The shape of the output tensor is the same as the shape of the input tensor, except for the axis that the inverse real FFT is computed over, which is replaced by n.
Linalg
band_part()
max.graph.ops.band_part(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, num_lower: int, num_upper: int, exclude: bool = False) → TensorValue
Masks out everything except a diagonal band of an input matrix.
Copies a tensor, setting everything outside the central diagonal band of the matrices to zero, where all but the last two axes are effectively batches, and the last two axes define sub-matrices.
Assuming the input has dimensions [I, J, …, M, N], the output tensor has the same shape as the input, and the values are given by
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
with the indicator function:
in_band(m, n) = (num_lower < 0 || (m - n) <= num_lower) &&
                (num_upper < 0 || (n - m) <= num_upper)
-
Parameters:
- x – The input to mask out.
- num_lower – The number of diagonal bands to include below the central diagonal. If -1, include the entire lower triangle.
- num_upper – The number of diagonal bands to include above the central diagonal. If -1, include the entire upper triangle.
- exclude – If true, invert the selection of elements to mask. Elements in the band are set to zero.
-
Returns:
A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.
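For example, a sketch that keeps only the upper triangle of each matrix (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def band_part_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4, 4), device=DeviceRef.CPU())
    with Graph("band_part_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Keep the main diagonal and everything above it (upper triangle)
        out = ops.band_part(x, num_lower=0, num_upper=-1)
        graph.output(out)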
layer_norm()
max.graph.ops.layer_norm(input: TensorValue, gamma: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, beta: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, epsilon: float) → TensorValue
Performs layer normalization.
-
Parameters:
- input – The input tensor to normalize.
- gamma – The gamma parameter of the normalization.
- beta – The beta parameter of the normalization.
- epsilon – The epsilon parameter of the normalization.
-
Returns:
A graph tensor value with the normalization applied.
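A minimal sketch, assuming numpy arrays are accepted for gamma and beta (the parameter types above include ndarray) and that graph inputs expose a .tensor view:
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def layer_norm_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 8), device=DeviceRef.CPU())
    with Graph("layer_norm_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0].tensor
        gamma = np.ones(8, dtype=np.float32)   # Scale parameter
        beta = np.zeros(8, dtype=np.float32)   # Shift parameter
        out = ops.layer_norm(x, gamma, beta, epsilon=1e-5)
        graph.output(out)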
matmul()
max.graph.ops.matmul(lhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the matrix multiplication of two tensor graph values.
Performs general matrix multiplication with broadcasting.
If the lhs is 1D, it will be reshaped to 1xD.
If the rhs is 1D, it will be reshaped to Dx1.
In both cases, the additional 1 dimensions will be removed from the output shape.
For the multiplication, the innermost (rightmost) 2 dimensions are treated as a matrix.
The lhs matrix will have the shape MxK.
The rhs matrix will have the shape KxN.
The output will have the shape MxN.
The K dimensions must be equivalent in both matrices.
The remaining outer dimensions will be broadcast.
-
Parameters:
- lhs – The left-hand-side of the matmul.
- rhs – The right-hand-side of the matmul.
-
Returns:
A tensor graph value representing the result of broadcasting the two matrices together and then performing a matrix multiply along the innermost two dimensions of each tensor.
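For example, a staging sketch for a plain 2D matmul (shapes and names here are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def matmul_graph():
    lhs_type = TensorType(dtype=DType.float32, shape=(3, 4), device=DeviceRef.CPU())
    rhs_type = TensorType(dtype=DType.float32, shape=(4, 5), device=DeviceRef.CPU())
    with Graph("matmul_example", input_types=(lhs_type, rhs_type)) as graph:
        lhs, rhs = graph.inputs
        out = ops.matmul(lhs, rhs)  # Result shape: (3, 5)
        graph.output(out)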
Buffer operations
buffer_load()
max.graph.ops.buffer_load(x: BufferValue) → TensorValue
Loads the input buffer into a tensor.
It loads the in-place mutable tensor to an immutable tensor graph value. This is semantically equivalent to a copy from the mutable tensor x to the immutable value-semantic tensor output.
-
Parameters:
x – The buffer to be loaded to a tensor.
-
Returns:
A tensor graph value representing a copy of the buffer loaded.
buffer_store()
max.graph.ops.buffer_store(destination: BufferValue, source: TensorValue) → None
Stores the input tensor into the inout buffer.
It stores the immutable input tensor source in the mutable buffer destination. This is semantically equivalent to a copy from the source tensor to the destination buffer.
-
Parameters:
- destination – The buffer to store the tensor in.
- source – The tensor to be stored in the buffer.
buffer_store_slice()
max.graph.ops.buffer_store_slice(destination: BufferValue, source: TensorValue, indices: Sequence[TensorValue | int | slice | tuple[slice, Union[int, str, max.graph.type.Dim, numpy.integer]] | EllipsisType]) → None
Stores the input tensor into a slice of the input buffer.
It stores the immutable input tensor source in the mutable tensor destination. This is semantically equivalent to a copy from the source tensor to a slice of the destination buffer at the index specified by indices.
-
Parameters:
- destination – The buffer to store the tensor in.
- source – The tensor to be stored in the buffer.
- indices – The index in the buffer where the tensor should be stored.
Call operations
call()
max.graph.ops.call(graph: Graph, *args: Value | Value) → list[max.graph.value.Value]
Call a graph with the provided arguments and return its results.
This function invokes a previously defined graph, passing in the provided arguments and the current chain value, and returns the results.
The body of the graph is ultimately inlined into the caller, so the chain value is only used for serialization if the subgraph’s body contains an operation that makes use of it in the first place.
The current advantage of using subgraphs is that it offers a way to improve compile times for operations that are used repeatedly in a model. As a secondary benefit, it also makes the IR more readable by allowing control flow to be expressed in a more natural way.
-
Parameters:
- graph – The graph to call.
- *args – Arguments to pass to the called graph.
-
Returns:
A list of Values representing the graph outputs (excluding the chain value, which is handled internally).
Flatten
flatten()
max.graph.ops.flatten(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, start_dim: int = 0, end_dim: int = -1) → TensorValue
Flattens the specified dims of a symbolic tensor.
The number and order of the elements in the tensor is unchanged. All dimensions from start_dim to end_dim (inclusive) are merged into a single output dim.
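For example, a sketch that keeps the batch dimension and flattens the rest (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def flatten_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3, 4), device=DeviceRef.CPU())
    with Graph("flatten_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.flatten(x, start_dim=1)  # (2, 3, 4) -> (2, 12)
        graph.output(out)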
Fold
fold()
max.graph.ops.fold(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, output_size: tuple[int, int], kernel_size: tuple[int, int], stride: int | tuple[int, int] = 1, dilation: int | tuple[int, int] = 1, padding: int | tuple[int, int] = 0) → TensorValue
Combines an array of sliding blocks into a larger containing tensor.
The input tensor must have shape (N, C * kernel_sizes, L) where N is the batch dimension, C is the number of channels, kernel_sizes is the product of the kernel sizes, and L is the number of local blocks.
The resulting output tensor will have shape (N, C, output_size[0], output_size[1]).
L, the number of blocks, must be equivalent to:
prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)
where d is over all spatial dimensions.
-
Parameters:
- input – The 3D tensor to fold with shape (N, C * kernel_sizes, L).
- output_size – Spatial dimensions of the output tensor. Must be a tuple of two ints.
- kernel_size – The size of the sliding blocks. Must be a tuple of two ints.
- stride – The stride of the sliding blocks in the input dimension (can be an int or a tuple of two ints).
- dilation – The spacing between the kernel elements (can be an int or a tuple of two ints).
- padding – Zero-padding added on both sides of the inputs (can be an int or a tuple of two ints).
-
Returns:
The folded 4D tensor with shape (N, C, output_size[0], output_size[1]).
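A worked sketch under assumed shapes: with N=1, C=1, a 2x2 kernel, stride 2, and no padding, the block count is L = ((4 - 2)/2 + 1)**2 = 4, so the input must be (1, 4, 4):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def fold_graph():
    # N=1, C=1, kernel 2x2 (so C * kernel_sizes = 4), L = 4 blocks
    input_type = TensorType(dtype=DType.float32, shape=(1, 4, 4), device=DeviceRef.CPU())
    with Graph("fold_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # Four non-overlapping 2x2 blocks tile a 4x4 output
        out = ops.fold(x, output_size=(4, 4), kernel_size=(2, 2), stride=2)
        graph.output(out)  # Shape: (1, 1, 4, 4)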
Pad
pad()
max.graph.ops.pad(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, paddings: Iterable[int], mode: Literal['constant'] = 'constant', value: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray = 0) → TensorValue
Pads the input tensor with a constant value along each dimension, according to the paddings argument.
Permute
permute()
max.graph.ops.permute(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, dims: list[int]) → TensorValue
Permutes all dimensions of a symbolic tensor.
-
Parameters:
- x – The input symbolic tensor to transpose.
- dims – The desired ordering of the dimensions in the output tensor.
-
Returns:
A new symbolic tensor with the dimensions permuted to match the passed in order. It has the same elements and dtype, but the order of the elements is different according to the permutation.
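For example, a sketch that moves the last dimension to the front (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def permute_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3, 4), device=DeviceRef.CPU())
    with Graph("permute_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.permute(x, [2, 0, 1])  # (2, 3, 4) -> (4, 2, 3)
        graph.output(out)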
Quantized
dequantize()
max.graph.ops.dequantize(encoding: QuantizationEncoding, quantized: TensorValue) → TensorValue
Dequantizes a quantized tensor to floating point.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
- encoding – The quantization encoding to use.
- quantized – The quantized tensor to dequantize.
-
Returns:
The dequantized result (a floating point tensor).
qmatmul()
max.graph.ops.qmatmul(encoding: QuantizationEncoding, config: QuantizationConfig | None, lhs: TensorValue, *rhs: TensorValue) → TensorValue
Performs matrix multiplication between floating point and quantized tensors.
This quantizes the lhs floating point value to match the encoding of the rhs quantized value, performs matmul, and then dequantizes the result.
Beware that, compared to a regular matmul op, this one expects the rhs value to be transposed. For example, if the lhs shape is [32, 64], and the quantized rhs shape is also [32, 64], then the output shape is [32, 32].
That is, this function returns the result from:
dequantize(quantize(lhs) @ transpose(rhs))
The last two dimensions in lhs are treated as matrices and multiplied by rhs (which must be a 2D tensor). Any remaining dimensions in lhs are broadcast dimensions.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
- encoding – The quantization encoding to use.
- lhs – The non-quantized, left-hand-side of the matmul.
- *rhs – The transposed and quantized right-hand-side of the matmul and auxiliary tensor (if present). Must be rank 2 and in a supported [quantization encoding](/max/api/mojo/graph/quantization/).
-
Returns:
The dequantized result (a floating point tensor).
Range
range()
max.graph.ops.range(start: ~max._mlir._mlir_libs._mlir.ir.Value | ~max.graph.value.BufferValue | ~max.graph.value.TensorValue | ~max.graph.type.Shape | ~max.graph.type.Dim | int | float | ~numpy.integer | ~numpy.floating | ~numpy.ndarray, stop: ~max._mlir._mlir_libs._mlir.ir.Value | ~max.graph.value.BufferValue | ~max.graph.value.TensorValue | ~max.graph.type.Shape | ~max.graph.type.Dim | int | float | ~numpy.integer | ~numpy.floating | ~numpy.ndarray, step: ~max._mlir._mlir_libs._mlir.ir.Value | ~max.graph.value.BufferValue | ~max.graph.value.TensorValue | ~max.graph.type.Shape | ~max.graph.type.Dim | int | float | ~numpy.integer | ~numpy.floating | ~numpy.ndarray, out_dim: int | str | ~max.graph.type.Dim | ~numpy.integer | None = None, device: ~max.graph.type.DeviceRef = cpu:0, dtype: ~max._core.dtype.DType = DType.float32) → TensorValue
Creates a sequence of numbers. The sequence goes from start with increments of size step up to (but not including) stop. All arguments are mandatory and must have the same element type.
Note the following restrictions on input values:
- step must be non-zero
- stop - start must be zero or have the same sign as step
-
Parameters:
- start – The start of the range to generate.
- stop – The range will be generated up to, but not including, this value.
- step – The step size for the range.
- out_dim – The expected output dimension returned by the range op. This will be asserted at graph execution time to be correct.
- device – Device of the result tensor.
- dtype – Data type of the result tensor.
-
Returns:
A symbolic tensor value containing the defined range of values.
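For instance, a sketch staging a five-element float range (the graph name and argument values are illustrative):
from max.graph import DeviceRef, Graph, ops

def range_graph():
    with Graph("range_example", input_types=()) as graph:
        # 0.0, 2.0, 4.0, 6.0, 8.0 -- five elements, stop is excluded
        out = ops.range(0.0, 10.0, 2.0, out_dim=5, device=DeviceRef.CPU())
        graph.output(out)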
Repeat
repeat_interleave()
max.graph.ops.repeat_interleave(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, repeats: int | TensorValue, axis: int | None = None, out_dim: int | str | Dim | integer | None = None) → TensorValue
Repeats elements of a tensor along the given dimension.
Modeled after torch.repeat_interleave. For example, given repeats=2 and the following input:
# Input tensor with shape (2, 2)
input = TensorValue(x)  # Contains [[1.0, 2.0], [3.0, 4.0]]
repeat_interleave with axis=0:
# Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
repeat_interleave with axis=1:
# Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
# Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
repeat_interleave with axis=None (the default):
# Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2)  # axis = None
# Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
repeat_interleave with repeats=[2, 3] and axis=0:
repeat_value = TensorValue([2, 3])
# Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
# Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
-
Parameters:
- x – The input tensor.
- repeats – The number of repetitions for each element.
- axis – The dimension along which to repeat values. If axis is not specified or None (the default), flatten the input array and repeat the flattened values.
-
Returns:
A symbolic tensor with the elements interleaved.
-
Raises:
ValueError – If repeats is non-positive or if axis is out of range.
Tile
tile()
max.graph.ops.tile(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, repeats: Iterable[int | str | Dim | integer]) → TensorValue
Returns a new Tensor as the result of copying the input tensor N_i times on each dimension, where N_i = repeats[i].
The i-th dimension of the output shape will be the i-th dimension of the input shape multiplied by N_i.
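For example, a sketch tiling a (2, 3) tensor twice along the first dimension and three times along the second (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def tile_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 3), device=DeviceRef.CPU())
    with Graph("tile_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.tile(x, repeats=[2, 3])  # (2, 3) -> (4, 9)
        graph.output(out)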
Transfer
transfer_to()
max.graph.ops.transfer_to(x: TensorValue, device: DeviceRef) → TensorValue
Device-to-Device transfer operation.
This op transfers the input tensor from its current device to another. A device represents a computation unit, like CPU, GPU, etc. This op is useful, for instance, when working with accelerators where one may need to move data from one GPU to another, or from a GPU to the CPU.
-
Parameters:
- x – The input tensor to transfer.
- device – The device to transfer to.
-
Returns:
A tensor transferred to the specified device.
TopK
top_k()
max.graph.ops.top_k(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, k: int, axis: int = -1) → tuple[max.graph.value.TensorValue, max.graph.value.TensorValue]
Returns a tensor containing only the top K values along the given axis.
-
Parameters:
- input – The input tensor from which to select top k.
- k – The number of values to select from input.
- axis – The axis from which to select top k.
-
Returns:
Top K values, Top K indices
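For example, a sketch selecting the two largest values per row (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def top_k_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 5), device=DeviceRef.CPU())
    with Graph("top_k_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        values, indices = ops.top_k(x, k=2, axis=-1)  # Both have shape (2, 2)
        graph.output(values, indices)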
Reduction
argmax()
max.graph.ops.argmax(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis=-1) → TensorValue
Reduces a symbolic tensor using an argmax operation.
When provided with a tensor in which all elements are identical, this returns the index of the first element on CPU, and an arbitrary index on GPU.
-
Parameters:
- x – The input tensor for the operation.
- axis – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
-
Returns:
A symbolic tensor representing the result of the argmax operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension, which will have size 1.
argmin()
max.graph.ops.argmin(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis=-1) → TensorValue
Reduces a symbolic tensor using an argmin operation.
When provided with a tensor in which all elements are identical, this returns the index of the first element on CPU, and an arbitrary index on GPU.
-
Parameters:
- x – The input tensor for the operation.
- axis – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
-
Returns:
A symbolic tensor representing the result of the argmin operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension, which will have size 1.
mean()
max.graph.ops.mean(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis=-1) → TensorValue
Reduces a symbolic tensor using a mean operation.
-
Parameters:
- x – The input tensor for the operation.
- axis – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
-
Returns:
A symbolic tensor representing the result of the mean operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension, which will have size 1.
sum()
max.graph.ops.sum(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis=-1) → TensorValue
Reduces a symbolic tensor using a sum operation.
-
Parameters:
- x – The input tensor for the operation.
- axis – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
-
Returns:
A symbolic tensor representing the result of the sum operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension, which will have size 1.
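The reductions share the same calling pattern; a combined sketch (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def reduce_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2, 4), device=DeviceRef.CPU())
    with Graph("reduce_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        total = ops.sum(x, axis=-1)       # Shape: (2, 1)
        avg = ops.mean(x, axis=-1)        # Shape: (2, 1)
        hottest = ops.argmax(x, axis=-1)  # Shape: (2, 1)
        graph.output(total, avg, hottest)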
Indexing
argsort()
max.graph.ops.argsort(x: TensorValue, ascending: bool = True) → TensorValue
Returns the indices that would sort a tensor.
This function returns the indices that would sort the input tensor along its first dimension. The returned indices are of type int64.
-
Parameters:
- x – Input tensor to be sorted.
- ascending – If True (default), sort in ascending order. If False, sort in descending order.
-
Returns:
A tensor of indices of the same shape as the input tensor.
nonzero()
max.graph.ops.nonzero(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, out_dim: int | str | Dim | integer) → TensorValue
Returns the indices of all nonzero elements in a tensor.
Returns a tensor of indices of the nonzero values in the given tensor. The return value is a 2D tensor of shape [out_dim x rank_in], where out_dim is the number of nonzero elements in the input tensor, and rank_in is the rank of the input tensor. Indices are generated in row-major order.
-
Parameters:
- x – The input symbolic tensor.
- out_dim – The newly generated dimension that is sized for the number of nonzero elements.
-
Returns:
A symbolic tensor of indices
Cumulative operations
cumsum()
max.graph.ops.cumsum(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis: int = -1, exclusive: bool = False, reverse: bool = False) → TensorValue
Computes the cumulative sum of the input tensor along the given axis.
-
Parameters:
- x – The input tensor to sum over.
- axis – The axis along which to compute the sum. If negative, indexes from the last dimension. For example, a value of -1 will compute the sum along the last dimension.
- exclusive – If set, start at 0 and exclude the final element. Otherwise, start with the first element. Said another way, cumsum computes [sum(x[…, :i, …]) for i in range(x.shape[axis])]. If exclusive is set, the bounds are instead range(1, x.shape[axis]).
- reverse – If set, start from the end. In other words, the first element will be the total sum, with each element following counting downwards; or [sum(x[…, i:, …]) for i in range(x.shape[axis])].
-
Returns:
A symbolic tensor representing the result of the cumsum operation. The tensor will have the same type as the input tensor. The computed values will be the cumulative sum of the values along the given axis, according to the specified parameters:
- if exclusive is set, the first value will be 0, and the last value will be excluded from the sum
- if reverse is set, the sum will be computed starting at the back of the axis back to the front, rather than front-to-back
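As a sketch of the inclusive and exclusive variants side by side (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def cumsum_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4,), device=DeviceRef.CPU())
    with Graph("cumsum_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        # For input [a, b, c, d]: inclusive -> [a, a+b, a+b+c, a+b+c+d]
        inclusive = ops.cumsum(x, axis=0)
        # exclusive -> [0, a, a+b, a+b+c]
        exclusive = ops.cumsum(x, axis=0, exclusive=True)
        graph.output(inclusive, exclusive)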
Audio processing
hann_window()
max.graph.ops.hann_window(window_length: int, device: DeviceRef, periodic: bool = True, dtype: DType = DType.float32) → TensorValue
Calculate a Hann window for a given length.
Hann window function:
w[n] = 0.5 * (1 - cos(2 * pi * n / (N - 1)))
where N is window_length.
-
Parameters:
- window_length – The length of the window.
- device – The device to run the operation on.
- periodic – A bool flag that determines whether the returned window trims off the last duplicate value from the symmetric window, making it ready for use as a periodic window with functions like stft(). hann_window(L, periodic=True) == hann_window(L + 1, periodic=False)[:-1].
- dtype – The desired data type of the output tensor.
-
Returns:
A 1-D tensor of size (window_length,) containing the window.
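A minimal staging sketch (the window length is illustrative):
from max.graph import DeviceRef, Graph, ops

def hann_graph():
    with Graph("hann_example", input_types=()) as graph:
        window = ops.hann_window(400, device=DeviceRef.CPU(), periodic=True)
        graph.output(window)  # Shape: (400,)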
Slicing
chunk()
max.graph.ops.chunk(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, chunks: int, axis: int = 0) → list[max.graph.value.TensorValue]
Chunk the tensor into an exact number of chunks along the specified dim.
-
Parameters:
- x – The tensor to chunk.
- chunks – The number of chunks to split the tensor into. chunks must statically evenly divide x.shape[axis].
- axis – The axis to split the tensor along.
-
Returns:
A list of chunk tensors.
Example
>>> a = TensorValue([1, 2, 3, 4])
>>> chunk(a, 2, 0)
[TensorValue([1, 2]), TensorValue([3, 4])]
concat()
max.graph.ops.concat(original_vals: Iterable[Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray], axis: int = 0) → TensorValue
Concatenates a list of symbolic tensors along an axis.
-
Parameters:
- original_vals – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than axis.
- axis – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, concat(vs, -1) will concat along the last dimension.
-
Returns:
A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will be the same as each input tensor’s for each dimension other than axis, which will have size equal to the sum of all input tensors’ sizes for that dimension.
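For example, a sketch concatenating two tensors that agree on every dimension except the concat axis (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def concat_graph():
    a_type = TensorType(dtype=DType.float32, shape=(2, 3), device=DeviceRef.CPU())
    b_type = TensorType(dtype=DType.float32, shape=(2, 5), device=DeviceRef.CPU())
    with Graph("concat_example", input_types=(a_type, b_type)) as graph:
        a, b = graph.inputs
        out = ops.concat([a, b], axis=1)  # (2, 3) + (2, 5) -> (2, 8)
        graph.output(out)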
gather()
max.graph.ops.gather(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, indices: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis: int = -1) → TensorValue
Selects elements out of an input tensor by index.
-
Parameters:
- input – The input symbolic tensor to select elements from.
- indices – A symbolic tensor of index values to use for selection.
- axis – The dimension which indices indexes from input. If negative, indexes relative to the end of the input tensor. For instance, gather(input, indices, axis=-1) will index against the last dimension of input.
-
Returns:
A new symbolic tensor representing the result of the gather operation.
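For instance, a sketch that selects two rows by index (shapes, dtypes, and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def gather_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4, 8), device=DeviceRef.CPU())
    indices_type = TensorType(dtype=DType.int64, shape=(2,), device=DeviceRef.CPU())
    with Graph("gather_example", input_types=(input_type, indices_type)) as graph:
        x, idx = graph.inputs
        out = ops.gather(x, idx, axis=0)  # Selects 2 rows: shape (2, 8)
        graph.output(out)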
masked_scatter()
max.graph.ops.masked_scatter(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, mask: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, updates: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Creates a new symbolic tensor in which the updates are written to input at the positions where mask is true.
-
Parameters:
- input – The input symbolic tensor to write elements to.
- mask – A symbolic tensor of boolean values to update.
- updates – A symbolic tensor of elements to write to input.
-
Returns:
A new symbolic tensor representing the result of the masked_scatter operation.
scatter()
max.graph.ops.scatter(input: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, updates: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, indices: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, axis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray = -1) → TensorValue
Creates a new symbolic tensor where the updates are written to input according to indices.
-
Parameters:
- input – The input symbolic tensor to write elements to.
- updates – A symbolic tensor of elements to write to input.
- indices – The positions in input to update.
- axis – The axis along which indices indexes into.
-
Returns:
A new symbolic tensor representing the result of the scatter operation.
select()
max.graph.ops.select(cond: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, y: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Returns condition ? x : y (element-wise), where cond, x and y are input tensors.
-
Parameters:
- cond – The condition tensor to use for selecting elementwise values.
- x – If the condition is true at a position, the value from the same position in this tensor will be selected.
- y – If the condition is false at a position, the value from the same position in this tensor will be selected.
-
Returns:
A new symbolic tensor holding values from either x or y, based on the elements in cond.
split()
max.graph.ops.split(x: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray, split_sizes: Sequence[int | str | Dim | integer], axis: int = 0) → list[max.graph.value.TensorValue]
Splits the input tensor into multiple tensors along a given dimension.
-
Parameters:
- x – The input symbolic tensor to split.
- split_sizes – Sizes of each output tensor. Must add up to the size of the split dimension axis.
- axis – Dimension to split the input tensor.
-
Returns:
A list of tensors with the same length as split_sizes, where each tensor has the same shape as the input except along the split dimension axis, where the size is given by the corresponding element in split_sizes.
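For example, a sketch splitting a (6, 2) tensor into a (2, 2) piece and a (4, 2) piece (shapes and names are illustrative):
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

def split_graph():
    input_type = TensorType(dtype=DType.float32, shape=(6, 2), device=DeviceRef.CPU())
    with Graph("split_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        first, rest = ops.split(x, split_sizes=[2, 4], axis=0)
        graph.output(first, rest)  # Shapes: (2, 2) and (4, 2)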
stack()
max.graph.ops.stack(vals: Iterable[Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray], axis: int = 0) → TensorValue
Stacks a list of tensors along a new axis.
-
Parameters:
- vals – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension.
- axis – The axis to stack along. If negative, indexes relative to the end of the tensor shape plus 1. For instance, stack(vs, -1) will create and stack along a new axis as the last dimension, and stack(vs, -2) will create and stack along a new dimension which is inserted immediately before the last dimension.
-
Returns:
A new symbolic tensor representing the result of the stack. It will have rank n+1 where n is the rank of each input tensor. Its size on each dimension other than axis will be the same as each input tensor’s, with the new axis inserted. Along the new dimension it will have size len(vals).
Random operations
normal()
max.graph.ops.random.normal(like: TensorType, mean: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray = 0.0, std: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray = 1.0) → TensorValue
Creates a symbolic tensor of values drawn from a normal (Gaussian) distribution with the given mean and std; the shape, dtype, and device of the result are taken from like.