Python module
ops
Implements ops used when staging a graph.
Although the following modules provide a lot of the ops you want when building a graph, you can also use functions in Graph to add constant values, such as constant(), vector(), and scalar().
The TensorValue type (returned by most ops) also implements various dunder methods to support operations between TensorValues, such as + add, * multiply, and @ matmul, plus convenience methods such as reshape() and swapaxes().
Casting
broadcast_to()
max.graph.ops.broadcast_to(x: TensorValue, shape: TensorValue | Iterable[int | str | Dim | integer], out_dims: Iterable[int | str | Dim | integer] | None = None) → TensorValue
Broadcasts a symbolic tensor.
Broadcasts the input tensor to the specified shape. Each dimension in the input must either be 1 or match the corresponding target dimension.
-
Parameters:
- x – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions.
- shape – The new shape as a list of dimensions. Dynamic dimensions are not allowed.
- out_dims – Output dimensions, used only when shape is passed as a tensor value.
-
Returns:
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.
-
Raises:
ValueError – if a tensor-valued shape is passed without out_dims.
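The dimension rule above can be checked with plain Python before building a graph. The sketch below models only the shape check, not the MAX implementation; it assumes NumPy-style right-aligned matching (so the input may have fewer dimensions than the target), and the helper `check_broadcast_to` is hypothetical.

```python
def check_broadcast_to(in_shape, target_shape):
    """Return target_shape if in_shape can be broadcast to it, else raise.

    Hypothetical helper: it models only the rule that each input dimension
    must be 1 or equal to the matching target dimension. Dimensions are
    compared right-aligned, an assumption borrowed from NumPy.
    """
    in_shape = tuple(in_shape)
    target_shape = tuple(target_shape)
    if len(in_shape) > len(target_shape):
        raise ValueError(
            f"cannot broadcast rank {len(in_shape)} to rank {len(target_shape)}"
        )
    # Align the trailing dimensions of the input with those of the target.
    offset = len(target_shape) - len(in_shape)
    for i, dim in enumerate(in_shape):
        target_dim = target_shape[offset + i]
        if dim != 1 and dim != target_dim:
            raise ValueError(
                f"input dim {dim} at axis {i} does not match target dim {target_dim}"
            )
    return target_shape


print(check_broadcast_to((3, 1), (2, 3, 4)))  # (2, 3, 4)
```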
cast()
max.graph.ops.cast(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, dtype: DType) → TensorValue
Casts a symbolic tensor to a different data type.
-
Parameters:
- x – The input tensor to cast.
- dtype – The target dtype to which the tensor is cast.
-
Returns:
A new symbolic tensor with the same shape as the input and the specified dtype.
rebind()
max.graph.ops.rebind(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, shape: Iterable[int | str | Dim | integer], message: str = '') → TensorValue
Rebinds a symbolic tensor to a specified set of dimensions.
This does not mutate the symbolic tensor passed in, but instead adds a runtime assert that the input symbolic shape is equivalent to shape. For example, if the input tensor shape has dynamic/unknown sizes, this asserts fixed sizes that may be required for a subsequent operation.
-
Parameters:
- x – The input symbolic tensor to rebind.
- shape – The symbolic shape to assert for x, as a list of Dim values.
- message – The message printed if the rebind fails at runtime.
-
Returns:
A symbolic tensor with the same elements and shape as the given tensor, but with the symbolic shape asserted to shape.
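To make the runtime assert concrete, the plain-Python sketch below models what a rebind check might verify once real sizes are known. It assumes that integer entries in shape must match exactly and that each string entry names a symbolic dimension that must take the same value everywhere it appears; `check_rebind` is a hypothetical helper, not part of max.graph.

```python
def check_rebind(actual_shape, asserted_shape, message=""):
    """Model of the runtime check that rebind adds (hypothetical helper).

    Integer dims must match exactly; string dims are symbolic names that
    bind to the first concrete size seen and must match on every later use.
    Returns the mapping from symbolic names to concrete sizes.
    """
    if len(actual_shape) != len(asserted_shape):
        raise ValueError(message or "rank mismatch")
    bindings = {}
    for actual, expected in zip(actual_shape, asserted_shape):
        if isinstance(expected, str):
            # Bind the symbolic name on first use, then require consistency.
            if bindings.setdefault(expected, actual) != actual:
                raise ValueError(
                    message or f"{expected!r} is both {bindings[expected]} and {actual}"
                )
        elif actual != expected:
            raise ValueError(message or f"expected {expected}, got {actual}")
    return bindings


print(check_rebind((8, 8, 3), ("n", "n", 3)))  # {'n': 8}
```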
reshape()
max.graph.ops.reshape(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, shape: Iterable[int | str | Dim | integer]) → TensorValue
Reshapes a symbolic tensor.
The number and order of the elements in the tensor are unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same.
If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified dimensions. Its length is the number of elements in the original tensor divided by the product of the other dimensions in the new shape.
-
Parameters:
- x – The input symbolic tensor to reshape. This tensor may not contain any dynamic dimensions.
- shape – The new shape as a list of dimensions. Dynamic dimensions are not allowed. A single dimension may be -1.
-
Returns:
A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as shape.
-
Raises:
ValueError – if input and target shapes’ number of elements mismatch.
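The -1 rule above is simple arithmetic. The hypothetical helper `infer_reshape` below is a plain-Python sketch of how the automatic dimension could be resolved from the element count; it does not call MAX and assumes all dimensions are static integers.

```python
import math


def infer_reshape(num_elements, shape):
    """Resolve a single -1 entry in a reshape target (hypothetical helper).

    The -1 dimension's length is the element count divided by the
    product of the other target dimensions.
    """
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = math.prod(d for d in shape if d != -1)
    if -1 in shape:
        if known == 0 or num_elements % known != 0:
            raise ValueError("cannot infer the -1 dimension")
        shape[shape.index(-1)] = num_elements // known
    # Mirrors the documented ValueError for an element-count mismatch.
    if math.prod(shape) != num_elements:
        raise ValueError("number of elements mismatch")
    return tuple(shape)


# A (2, 3, 4) tensor has 24 elements; reshaping it to (4, -1) gives (4, 6).
print(infer_reshape(2 * 3 * 4, (4, -1)))  # (4, 6)
```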
shape_to_tensor()
max.graph.ops.shape_to_tensor(shape: Iterable[int | str | Dim | integer]) → TensorValue
Converts a shape to a tensor.
This is useful for using a shape attribute in an op that expects a tensor value.
-
Parameters:
shape – the shape attribute of a tensor value.
-
Returns:
The TensorValue containing the same value as shape.
Example
>>> x = ops.constant(np.zeros((1,)), DType.int64)
>>> result = ops.stack([
...     x,
...     ops.shape_to_tensor(x.shape),
... ])
>>> result
TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])
squeeze()
max.graph.ops.squeeze(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, axis: int) → TensorValue
Removes a size-1 dimension from a symbolic tensor.
-
Parameters:
- x – The input symbolic tensor to squeeze.
- axis – The dimension to remove from the input’s shape. If negative, this indexes from the end of the tensor. For example, squeeze(v, -1) squeezes the last dimension.
-
Returns:
A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.
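The shape bookkeeping for squeeze() can be sketched in plain Python. The helper `squeeze_shape` is hypothetical, and raising an error for a dimension whose size is not 1 is an assumption about how the op rejects invalid input.

```python
def squeeze_shape(shape, axis):
    """Shape produced by removing `axis` from `shape` (hypothetical helper)."""
    rank = len(shape)
    if not -rank <= axis < rank:
        raise IndexError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # Negative axes count from the end.
    if shape[axis] != 1:
        # Assumption: only size-1 dimensions may be squeezed.
        raise ValueError(f"cannot squeeze a dimension of size {shape[axis]}")
    return tuple(shape[:axis]) + tuple(shape[axis + 1:])


print(squeeze_shape((2, 3, 1), -1))  # (2, 3)
```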
transpose()
max.graph.ops.transpose(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, dim_1: int, dim_2: int) → TensorValue
Transposes two dimensions of a symbolic tensor.
-
Parameters:
- x – The input symbolic tensor to transpose.
- dim_1 – One of the two dimensions to transpose. If negative, this indexes from the end of the tensor. For example, transpose(v, -1, -2) transposes the last two dimensions.
- dim_2 – The other dimension to transpose. May also be negative to index from the end of the tensor.
-
Returns:
A new symbolic tensor with the two specified dimensions transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.
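The resulting shape can be sketched in plain Python, with negative axes normalized as the parameter descriptions above describe. `transpose_shape` is a hypothetical helper that tracks shapes only, not element data.

```python
def transpose_shape(shape, dim_1, dim_2):
    """Shape after swapping two axes (hypothetical helper)."""
    rank = len(shape)
    for d in (dim_1, dim_2):
        if not -rank <= d < rank:
            raise IndexError(f"axis {d} out of range for rank {rank}")
    dims = list(shape)
    # Normalize negative axes so that -1 means the last dimension.
    a, b = dim_1 % rank, dim_2 % rank
    dims[a], dims[b] = dims[b], dims[a]
    return tuple(dims)


print(transpose_shape((2, 3, 4), -1, -2))  # (2, 4, 3)
```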
unsqueeze()
max.graph.ops.unsqueeze(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, axis: int) → TensorValue
Inserts a size-1 dimension into a symbolic tensor.
-
Parameters:
- x – The input symbolic tensor to unsqueeze.
- axis – The index at which to insert a new dimension into the input’s shape. Elements at that index or higher are shifted back. If negative, it indexes relative to 1 plus the rank of the tensor. For example, unsqueeze(v, -1) adds a new dimension at the end, and unsqueeze(v, -2) inserts the dimension immediately before the last dimension.
-
Returns:
A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result’s shape at the axis dimension is a static dimension of size 1.
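The negative-axis rule for unsqueeze() (relative to the rank plus 1) is easy to get wrong, so here is a plain-Python sketch of the output shape. `unsqueeze_shape` is hypothetical; it reproduces the two examples given for the axis parameter.

```python
def unsqueeze_shape(shape, axis):
    """Shape after inserting a size-1 dimension at `axis` (hypothetical helper).

    Valid axes run from -(rank + 1) to rank; a negative axis is taken
    relative to rank + 1, so -1 appends a new last dimension.
    """
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise IndexError(f"axis {axis} out of range for rank {rank}")
    if axis < 0:
        axis += rank + 1
    return tuple(shape[:axis]) + (1,) + tuple(shape[axis:])


print(unsqueeze_shape((2, 3), -1))  # (2, 3, 1)
print(unsqueeze_shape((2, 3), -2))  # (2, 1, 3)
```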
Complex
as_interleaved_complex()
max.graph.ops.as_interleaved_complex(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Reshapes the input symbolic tensor as complex from alternating (real, imag).
-
Parameters:
x – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.
-
Returns:
A symbolic tensor representing the complex-valued tensor, but with the values pulled out as complex numbers. The result has the same dimensions for all dimensions except the last dimension, which is halved, and then a final dimension of size 2 representing the complex value.
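The plain-Python sketch below illustrates the interleaved layout: consecutive (real, imag) pairs in the last dimension each form one complex value, and the output shape splits that dimension into [n, 2]. Both helpers are hypothetical and only model the data layout described above.

```python
def interleaved_to_complex(row):
    """Pair alternating (real, imag) values into Python complex numbers.

    Hypothetical helper modeling one innermost row: a last dimension of
    size 2n holds n complex values.
    """
    if len(row) % 2 != 0:
        raise ValueError("last dimension must have an even size")
    return [complex(row[i], row[i + 1]) for i in range(0, len(row), 2)]


def interleaved_complex_shape(shape):
    """Output shape as documented: [..., 2n] -> [..., n, 2] (hypothetical helper)."""
    *batch, last = shape
    if last % 2 != 0:
        raise ValueError("last dimension must have an even size")
    return tuple(batch) + (last // 2, 2)


print(interleaved_to_complex([1.0, 2.0, 3.0, -4.0]))  # [(1+2j), (3-4j)]
print(interleaved_complex_shape((5, 8)))  # (5, 4, 2)
```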
Constant
constant()
max.graph.ops.constant(value: ndarray | int | float | integer | floating, dtype: DType) → TensorValue
Adds a node representing a constant operation.
The value of this constant will have the type TensorType with the same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions.
The constant will be loaded with the specified dtype. If the constant does not fit within the specified dtype, an error is raised.
Warning: Loading the constant could result in precision loss. For example, loading 16777217 as a float32 will result in 16777216.0.
-
Parameters:
- value – The constant’s value.
- dtype – The constant tensor’s element type.
-
Returns:
A graph value containing the constant data as an attribute.
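The precision warning can be reproduced with the standard library alone: round-tripping a number through a 4-byte IEEE 754 float with the struct module shows what float32 can represent. This is independent of how ops.constant() converts values internally; `round_to_float32` is a hypothetical helper.

```python
import struct


def round_to_float32(value):
    """Round a Python number to the nearest IEEE 754 float32 (hypothetical helper)."""
    # Pack as a little-endian 4-byte float, then unpack to see the stored value.
    return struct.unpack("<f", struct.pack("<f", float(value)))[0]


print(round_to_float32(16777217))  # 16777216.0
```

16777217 is 2**24 + 1, the first integer that float32 cannot represent exactly, so it rounds to 2**24.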
Elementwise
An elementwise operation performs the same calculation on each element of an input tensor. These operations take tensors of compatible shapes and apply the specified operation to each element pair.
For example, the following demonstrates how to add two tensors using the add() function:
import numpy as np
from max import engine
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def main():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("simple_add_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # First addend
        y = graph.inputs[1]  # Second addend
        out = ops.add(x, y)
        graph.output(out)

    # Compile the graph and run it on concrete inputs.
    session = engine.InferenceSession()
    model = session.load(graph)

    input_0 = np.array([10.0, 8.0], dtype=np.float32)
    input_1 = np.array([2.0, 4.0], dtype=np.float32)
    ret = model.execute(input_0, input_1)

    print("\nAddition computation:")
    print("Result ", ret["output0"])


if __name__ == "__main__":
    main()
abs()
max.graph.ops.abs(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise absolute value of a symbolic tensor.
Creates a new op node to compute the elementwise absolute value of a symbolic tensor and adds it to the graph, returning the symbolic result.
The following demonstrates how to compute the absolute value using the abs() function:
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def abs_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("abs_graph", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.abs(x)
        graph.output(out)
-
Parameters:
x – The symbolic tensor to use as the input to the absolute value computation.
-
Returns:
A new symbolic tensor value representing the output of the absolute value computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
add()
max.graph.ops.add(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Adds two symbolic tensors.
Creates a new op node to compute the addition of two symbolic tensor values and adds it to the graph, returning the symbolic result.
The following shows an example of the add() function with two inputs:
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def add_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("add_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]
        y = graph.inputs[1]
        out = ops.add(x, y)
        graph.output(out)
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the addition.
- rhs – The symbol to use as the right side of the addition.
- location – An optional location for a more specific error message.
-
Returns:
A symbolic tensor value representing the output of the addition. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
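The result shape is the broadcast of the two input shapes. The sketch below assumes NumPy-style broadcasting rules (right-aligned dimensions, where a size of 1 stretches to match the other operand); the helper `broadcast_shapes` is hypothetical and computes shapes only.

```python
from itertools import zip_longest


def broadcast_shapes(lhs, rhs):
    """Broadcast two shapes under assumed NumPy-style rules (hypothetical helper).

    Shapes are aligned from the right; each pair of dimensions must be
    equal, or one of them must be 1, in which case the other one is used.
    """
    result = []
    # Missing leading dimensions behave like size 1.
    for a, b in zip_longest(reversed(lhs), reversed(rhs), fillvalue=1):
        if a != b and 1 not in (a, b):
            raise ValueError(f"incompatible dimensions {a} and {b}")
        result.append(b if a == 1 else a)
    return tuple(reversed(result))


print(broadcast_shapes((2, 1, 4), (3, 1)))  # (2, 3, 4)
```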
cos()
max.graph.ops.cos(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise cosine of a symbolic tensor.
Creates a new op node to compute the elementwise cosine of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the cosine computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the cosine computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
div()
max.graph.ops.div(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Divides two symbolic tensors.
Creates a new op node to compute the division of two symbolic tensor values and adds it to the graph, returning the symbolic result.
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the dividend (the left side of the division).
- rhs – The symbol to use as the divisor (the right side of the division).
- location – An optional location for a more specific error message.
-
Returns:
A symbolic tensor value representing the output of the division. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
equal()
max.graph.ops.equal(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise equality comparison between two symbolic tensors.
Creates a new op node to compute the equality comparison of two symbolic tensor values and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def equal_graph():
    input_type = TensorType(dtype=DType.float32, shape=(3,))
    with Graph("equal_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # First input
        y = graph.inputs[1]  # Second input
        out = ops.equal(x, y)
        graph.output(out)
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the comparison.
- rhs – The symbol to use as the right side of the comparison.
-
Returns:
A symbolic tensor value representing the output of the comparison. The result will have:
- element type bool, true where the corresponding elements of lhs and rhs are equal, false otherwise
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
erf()
max.graph.ops.erf(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise error function of a symbolic tensor.
Creates a new op node to compute the elementwise error function of a symbolic tensor and adds it to the graph, returning the symbolic result.
The error function erf is defined as $$\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt$$. For x >= 0, erf(x) is the probability that a normally distributed random variable with mean 0 and variance 1/2 falls within the range [-x, x].
-
Parameters:
x – The symbolic tensor to use as the input to the error function computation.
-
Returns:
A new symbolic tensor value representing the output of the error function computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
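Python's standard-library math.erf can serve as a reference when checking ops.erf() output. This sketch shows the relationship between erf and the standard normal distribution; it does not use MAX, and `standard_normal_cdf` is a hypothetical helper name.

```python
import math


def standard_normal_cdf(x):
    """Phi(x) for a standard normal variable, written in terms of erf."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


# P(-1 <= Z <= 1) for Z ~ N(0, 1) equals erf(1 / sqrt(2)), about 0.6827.
within_one_sigma = math.erf(1.0 / math.sqrt(2.0))
print(round(within_one_sigma, 4))  # 0.6827
print(standard_normal_cdf(0.0))  # 0.5
```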
exp()
max.graph.ops.exp(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise exp function of a symbolic tensor.
Creates a new op node to compute the elementwise exp function of a symbolic tensor and adds it to the graph, returning the symbolic result.
exp is defined as exp(x) = e^x, where e is Euler’s number.
-
Parameters:
x – The symbolic tensor to use as the input to the exp function computation.
-
Returns:
A new symbolic tensor value representing the output of the exp function computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
floor()
max.graph.ops.floor(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise floor of a symbolic tensor.
Creates a new op node to compute the elementwise floor of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the floor computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the floor computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
gelu()
max.graph.ops.gelu(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise gelu function of a symbolic tensor.
Creates a new op node to compute the elementwise gelu function of a symbolic tensor and adds it to the graph, returning the symbolic result.
gelu is defined as $$\mathrm{gelu}(x) = x \Phi(x)$$, where $$\Phi$$ is the cumulative distribution function of the standard Gaussian distribution.
-
Parameters:
x – The symbolic tensor to use as the input to the gelu function computation.
-
Returns:
A new symbolic tensor value representing the output of the gelu function computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
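A plain-Python reference for the definition above, writing Φ in terms of math.erf. Whether ops.gelu() computes this exact form or a tanh-based approximation is not stated here, so treat `gelu_exact` (a hypothetical helper) as a reference for the formula rather than for MAX's exact numerics.

```python
import math


def gelu_exact(x):
    """Exact GELU, x * Phi(x), with Phi expressed through erf."""
    phi = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))  # standard normal CDF
    return x * phi


for value in (-2.0, 0.0, 1.0):
    print(value, gelu_exact(value))
```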
greater()
max.graph.ops.greater(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise greater than comparison between two symbolic tensors.
Creates a new op node to compute the greater-than comparison of two symbolic tensor values and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def greater_than_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("greater_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Left-hand side
        y = graph.inputs[1]  # Right-hand side
        out = ops.greater(x, y)
        graph.output(out)
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the comparison.
- rhs – The symbol to use as the right side of the comparison.
-
Returns:
A symbolic tensor value representing the output of the comparison. The result will have:
- element type bool, true where the lhs element is greater than the corresponding rhs element, false otherwise
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
greater_equal()
max.graph.ops.greater_equal(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise greater-or-equal comparison between two symbolic tensors.
Creates a new op node to compute the greater-or-equal comparison of two symbolic tensor values and adds it to the graph, returning the symbolic result.
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the comparison.
- rhs – The symbol to use as the right side of the comparison.
-
Returns:
A symbolic tensor value representing the output of the comparison. The result will have:
- element type bool, true where the lhs element is greater than or equal to the corresponding rhs element, false otherwise
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
is_inf()
max.graph.ops.is_inf(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise is_inf of a symbolic tensor.
Creates a new op node to compute the elementwise is_inf of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the is_inf computation.
-
Returns:
The result will have:
- element type bool, true if the element at a given position is plus or minus infinity, false otherwise
- the same shape as the input value.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
is_nan()
max.graph.ops.is_nan(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise is_nan of a symbolic tensor.
Creates a new op node to compute the elementwise is_nan of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the is_nan computation.
-
Returns:
The result will have:
- element type bool, true if the element at a given position is NaN, false otherwise
- the same shape as the input value.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
log()
max.graph.ops.log(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise natural logarithm of a symbolic tensor.
Creates a new op node to compute the elementwise natural logarithm of a symbolic tensor and adds it to the graph, returning the symbolic result.
The natural logarithm function log is defined as the inverse of the exponential function exp(). In other words, it computes the value y in the equation x = e^y, where e is Euler’s number.
log(x) is undefined for x <= 0 for real numbers. Complex numbers are currently unsupported.
-
Parameters:
x – The symbolic tensor to use as the input to the natural logarithm computation.
-
Returns:
A new symbolic tensor value representing the output of the natural logarithm computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
log1p()
max.graph.ops.log1p(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise logarithm of 1 plus a symbolic tensor.
Creates a new op node to compute the elementwise log1p of a symbolic tensor and adds it to the graph, returning the symbolic result.
The log1p function is defined as log1p(x) = log(1 + x), where log() is the natural logarithm.
Using log1p(x) rather than computing log(1 + x) can give more numerically precise results when x is close to 0.
log1p(x) is undefined for x <= -1 for real numbers. Complex numbers are currently unsupported.
-
Parameters:
x – The symbolic tensor to use as the input to the log1p computation.
-
Returns:
A new symbolic tensor value representing the output of the log1p computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
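The precision benefit of log1p() is visible with ordinary Python floats. This sketch uses math.log and math.log1p from the standard library, not MAX, to show why computing log(1 + x) directly fails for tiny x.

```python
import math

x = 1e-16
naive = math.log(1 + x)  # 1 + 1e-16 rounds to exactly 1.0 in float64, so this is 0.0
precise = math.log1p(x)  # keeps the small value, approximately 1e-16
print(naive, precise)
```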
logical_not()
max.graph.ops.logical_not(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise logical_not of a symbolic tensor.
Creates a new op node to compute the elementwise logical_not of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the logical_not computation.
-
Returns:
The result will have:
- element type bool, true where the input element is false, and false where it is true
- the same shape as the input value.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
logsoftmax()
max.graph.ops.logsoftmax(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the logsoftmax of a symbolic tensor.
Creates a new op node to compute the logsoftmax of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the logsoftmax computation.
-
Returns:
A new symbolic tensor value representing the output of the logsoftmax computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
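logsoftmax(x) equals x - log(sum(exp(x))) along the normalized axis. The plain-Python sketch below computes it for a single vector, subtracting the maximum first for numerical stability. It assumes the op normalizes over one axis (commonly the last), which is not confirmed in this section, and `log_softmax` is a hypothetical helper.

```python
import math


def log_softmax(values):
    """log(softmax(values)) over one vector, computed stably (hypothetical helper).

    Subtracting the maximum before exponentiating avoids overflow; the
    result still equals values - log(sum(exp(values))).
    """
    m = max(values)
    log_sum = m + math.log(sum(math.exp(v - m) for v in values))
    return [v - log_sum for v in values]


out = log_softmax([1.0, 2.0, 3.0])
# Exponentiating the outputs recovers a probability distribution.
print(sum(math.exp(v) for v in out))  # approximately 1.0
```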
max()
max.graph.ops.max(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise maximum of two symbolic tensors.
Creates a new op node to compute the maximum of two symbolic tensor values and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def maximum_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("maximum_graph", input_types=(input_type, input_type)) as graph:
        out = ops.max(graph.inputs[0], graph.inputs[1])
        graph.output(out)
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The first symbol to compare.
- rhs – The second symbol to compare.
-
Returns:
A symbolic tensor value representing the elementwise maximum. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
min()
max.graph.ops.min(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise minimum of two symbolic tensors.
Creates a new op node to compute the minimum of two symbolic tensor values and adds it to the graph, returning the symbolic result.
from max.dtype import DType
from max.graph import Graph, TensorType, ops


def min_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("min_graph", input_types=(input_type, input_type)) as graph:
        out = ops.min(graph.inputs[0], graph.inputs[1])
        graph.output(out)
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The first symbol to compare.
- rhs – The second symbol to compare.
-
Returns:
A symbolic tensor value representing the elementwise minimum. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
mod()
max.graph.ops.mod(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise modulus of two symbolic tensors.
Creates a new op node to compute the modulus of two symbolic tensor values and adds it to the graph, returning the symbolic result.
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the dividend (the left side of the modulus operation).
- rhs – The symbol to use as the divisor (the right side of the modulus operation).
-
Returns:
A symbolic tensor value representing the output of the modulus operation. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
mul()
max.graph.ops.mul(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise multiplication of two symbolic tensors.
Creates a new op node to compute the multiplication of two symbolic tensor values and adds it to the graph, returning the symbolic result.
-
- If lhs and rhs have different dtypes, they are promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they are broadcast to the same shape according to the broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the multiplication.
- rhs – The symbol to use as the right side of the multiplication.
-
Returns:
A symbolic tensor value representing the output of the multiplication. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
negate()
max.graph.ops.negate(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise negation of a symbolic tensor.
Creates a new op node to compute the elementwise negation of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the negation.
-
Returns:
A new symbolic tensor value representing the negated input. The result will have the same dtype and shape as the input value.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
not_equal()
max.graph.ops.not_equal(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise inequality comparison between two symbolic tensors.
Creates a new op node to compute the inequality comparison of two symbolic tensor values and adds it to the graph, returning the symbolic result.
For example:
from max.dtype import DType
from max.graph import Graph, TensorType, ops

def not_equal_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("not_equal_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Left-hand side
        y = graph.inputs[1]  # Right-hand side
        out = ops.not_equal(x, y)  # Elementwise x != y
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side of the comparison.
- rhs – The symbol to use as the right side of the comparison.
-
Returns:
A symbolic tensor value representing the output of the inequality comparison. The result will have:
- a bool dtype
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
pow()
max.graph.ops.pow(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise exponentiation of two symbolic tensors.
Creates a new op node to compute the exponentiation of two symbolic tensor values and adds it to the graph, returning the symbolic result.
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the base of the exponentiation.
- rhs – The symbol to use as the exponent.
-
Returns:
A symbolic tensor value representing the output of the exponentiation. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
relu()
max.graph.ops.relu(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise relu of a symbolic tensor.
Creates a new op node to compute the elementwise relu of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the relu computation.
-
Returns:
A new symbolic tensor value representing the output of the relu computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
round()
max.graph.ops.round(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise round of a symbolic tensor.
Creates a new op node to compute the elementwise round of a symbolic tensor and adds it to the graph, returning the symbolic result.
For example, if the model has one input tensor:
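from max.dtype import DType
from max.graph import Graph, TensorType, ops

def round_graph():
    input_type = TensorType(dtype=DType.float32, shape=(4,))
    with Graph("round_graph_example", input_types=(input_type,)) as graph:
        x = graph.inputs[0]
        out = ops.round(x)  # Elementwise rounding of a float32 input
        graph.output(out)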
-
Parameters:
x – The symbolic tensor to use as the input to the round computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the round computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
roundeven()
max.graph.ops.roundeven(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise roundeven of a symbolic tensor.
Creates a new op node to compute the elementwise roundeven of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the roundeven computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the roundeven computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
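round() and roundeven() differ only on exact halves. The plain-Python sketch below contrasts the two tie-breaking rules. It assumes, from the conventional meaning of these names (as in C's round and roundeven), that round() sends halves away from zero and roundeven() sends them to the nearest even integer; this page does not state the tie rule for round(), and neither helper is part of max.graph.ops.

```python
import math


def round_half_away(x: float) -> float:
    """Round to nearest, ties away from zero (assumed behavior of round())."""
    whole = math.floor(abs(x))
    frac = abs(x) - whole  # exact for finite floats
    mag = whole + 1 if frac >= 0.5 else whole
    return math.copysign(mag, x)


def round_half_even(x: float) -> float:
    """Round to nearest, ties to the even neighbor (roundeven semantics).

    Python's built-in round() already uses ties-to-even.
    """
    return float(round(x))


print(round_half_away(2.5), round_half_even(2.5))    # 3.0 2.0
print(round_half_away(-2.5), round_half_even(-2.5))  # -3.0 -2.0
```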
rsqrt()
max.graph.ops.rsqrt(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise inverse-square-root of a symbolic tensor.
Creates a new op node to compute the elementwise rsqrt of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the rsqrt computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the rsqrt computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sigmoid()
max.graph.ops.sigmoid(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sigmoid of a symbolic tensor.
Creates a new op node to compute the elementwise sigmoid of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sigmoid computation.
-
Returns:
A new symbolic tensor value representing the output of the sigmoid computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
silu()
max.graph.ops.silu(x: TensorValue)
Computes the elementwise silu of a symbolic tensor.
Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result.
silu is defined as silu(x) = x * sigmoid(x).
-
Parameters:
x – The symbolic tensor to use as the input to the silu computation.
-
Returns:
A new symbolic tensor value representing the output of the silu computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
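The definition silu(x) = x * sigmoid(x) can be checked against a scalar reference in plain Python. This is a minimal sketch, not the MAX kernel; the two-branch sigmoid() is one common way to keep math.exp from overflowing for large-magnitude inputs.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + exp(-x)), written to avoid exp overflow."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, exp(x) is small, so use the algebraically equal form.
    e = math.exp(x)
    return e / (1.0 + e)


def silu(x: float) -> float:
    """SiLU: x * sigmoid(x)."""
    return x * sigmoid(x)


print(silu(0.0))            # 0.0
print(round(silu(2.0), 4))  # 1.7616
```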
sin()
max.graph.ops.sin(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sine of a symbolic tensor.
Creates a new op node to compute the elementwise sine of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sin computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the sine computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
softmax()
max.graph.ops.softmax(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the softmax of a symbolic tensor.
Creates a new op node to compute the softmax of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the softmax computation.
-
Returns:
A new symbolic tensor value representing the output of the softmax computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
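Unlike the unary ops around it, softmax couples values: each output depends on every value in the slice being normalized. A numerically stable reference for a single 1-D row is sketched below in plain Python. It assumes the op normalizes along the last axis, which the signature above does not state; the softmax helper here is hypothetical.

```python
import math


def softmax(row: list[float]) -> list[float]:
    """Stable softmax of one row (assumed to be the last axis).

    Computes exp(x - max) / sum(exp(x - max)); subtracting the max keeps
    exp() from overflowing without changing the result.
    """
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([1.0, 2.0, 3.0])
print([round(p, 4) for p in probs])  # [0.09, 0.2447, 0.6652]
```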
sqrt()
max.graph.ops.sqrt(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise sqrt of a symbolic tensor.
Creates a new op node to compute the elementwise sqrt of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the sqrt computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the sqrt computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
sub()
max.graph.ops.sub(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise subtraction of two symbolic tensors.
Creates a new op node to compute the subtraction of two symbolic tensor values and adds it to the graph, returning the symbolic result.
For example:
from max.dtype import DType
from max.graph import Graph, TensorType, ops

def sub_graph():
    input_type = TensorType(dtype=DType.float32, shape=(2,))
    with Graph("sub_graph", input_types=(input_type, input_type)) as graph:
        x = graph.inputs[0]  # Minuend (number being subtracted from)
        y = graph.inputs[1]  # Subtrahend (number being subtracted)
        out = ops.sub(x, y)  # Elementwise x - y
        graph.output(out)
- If lhs and rhs have different dtypes, they will be promoted according to the dtype promotion rules before the operation.
- If lhs and rhs have different shapes, they will be broadcast to the same shape according to broadcasting rules before the operation.
-
Parameters:
- lhs – The symbol to use as the left side (minuend) of the subtraction.
- rhs – The symbol to use as the right side (subtrahend) of the subtraction.
-
Returns:
A symbolic tensor value representing the output of the subtraction. The result will have:
- the same dtype as the type-promotion of the two input dtypes
- the same shape as the broadcast of the two input shapes.
-
Raises:
- Error – If the input values’ shapes are not compatible for broadcasting.
- Error – If one of the input values has an unsupported dtype.
- Error – If the two symbols are parts of different graphs.
tanh()
max.graph.ops.tanh(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise tanh of a symbolic tensor.
Creates a new op node to compute the elementwise tanh of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the tanh computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the tanh computation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
trunc()
max.graph.ops.trunc(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the elementwise truncation of a symbolic tensor.
Creates a new op node to compute the elementwise truncation of a symbolic tensor and adds it to the graph, returning the symbolic result.
-
Parameters:
x – The symbolic tensor to use as the input to the truncation computation. If it’s not a floating-point DType, an exception will be raised.
-
Returns:
A new symbolic tensor value representing the output of the truncation.
-
Raises:
Error – If the symbol doesn’t represent a tensor value.
Linalg
band_part()
max.graph.ops.band_part(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, num_lower: int, num_upper: int, exclude: bool = False) → TensorValue
Masks out everything except a diagonal band of an input matrix.
Copies a tensor, setting everything outside the central diagonal band of the matrices to zero, where all but the last two axes are effectively batches, and the last two axes define submatrices.
If the input has dimensions [I, J, …, M, N], the output tensor has the same shape as the input, and its values are given by
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
with the indicator function:
in_band(m, n) = (num_lower < 0 || (m - n) <= num_lower) &&
                (num_upper < 0 || (n - m) <= num_upper)
-
Parameters:
- x – The input to mask out.
- num_lower – The number of diagonal bands to include below the central diagonal. If -1, include the entire lower triangle.
- num_upper – The number of diagonal bands to include above the central diagonal. If -1, include the entire upper triangle.
- exclude – If true, invert the selection of elements to mask. Elements in the band are set to zero.
-
Returns:
A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.
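The in_band indicator above translates directly into plain Python. This sketch handles a single 2-D matrix given as nested lists (batch axes are omitted for brevity), and band_part here is a hypothetical reference model, not the MAX op.

```python
def band_part(matrix, num_lower, num_upper, exclude=False):
    """Reference band_part for one 2-D matrix (a list of rows).

    Keeps element (m, n) when in_band(m, n) is true and zeroes the rest;
    exclude=True inverts the selection, so band elements are zeroed.
    """
    def in_band(m, n):
        return ((num_lower < 0 or (m - n) <= num_lower)
                and (num_upper < 0 or (n - m) <= num_upper))

    return [
        [value if in_band(m, n) != exclude else 0
         for n, value in enumerate(row)]
        for m, row in enumerate(matrix)
    ]


a = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]
print(band_part(a, 0, -1))  # [[1, 2, 3], [0, 5, 6], [0, 0, 9]]
print(band_part(a, 0, 0))   # [[1, 0, 0], [0, 5, 0], [0, 0, 9]]
```

Passing exclude=True keeps exactly the elements that the same call without it zeroes.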
layer_norm()
max.graph.ops.layer_norm(input: TensorValue, gamma: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, beta: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, epsilon: float) → TensorValue
Performs layer normalization.
-
Parameters:
- input – The input tensor to normalize.
- gamma – The scale (gamma) applied to the normalized values.
- beta – The offset (beta) added after scaling.
- epsilon – A small constant added to the variance to avoid division by zero.
-
Returns:
A graph tensor value with the normalization applied.
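For reference, the usual layer-normalization formula, (x - mean) / sqrt(var + epsilon) * gamma + beta, is sketched below for a single row in plain Python. It assumes normalization over the last axis with the population (biased) variance and per-element gamma and beta; this page does not spell those details out, so check them against the MAX implementation before relying on them.

```python
import math


def layer_norm(row, gamma, beta, epsilon):
    """Normalize one row to zero mean and unit variance, then scale and shift.

    Assumes the last axis is normalized, using the population variance.
    """
    n = len(row)
    mean = sum(row) / n
    var = sum((v - mean) ** 2 for v in row) / n
    inv_std = 1.0 / math.sqrt(var + epsilon)
    return [(v - mean) * inv_std * g + b for v, g, b in zip(row, gamma, beta)]


out = layer_norm([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0], 1e-5)
print([round(v, 3) for v in out])  # [-1.225, 0.0, 1.225]
```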
matmul()
max.graph.ops.matmul(lhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, rhs: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Computes the matrix multiplication of two tensor graph values.
Performs general matrix multiplication with broadcasting.
If the lhs is 1D, it will be reshaped to 1xD.
If the rhs is 1D, it will be reshaped to Dx1.
In both cases, the additional 1 dimensions will be removed from the output shape.
For the multiplication, the innermost (rightmost) 2 dimensions are treated as a matrix. The lhs matrix will have the shape MxK, the rhs matrix will have the shape KxN, and the output will have the shape MxN. The K dimensions must be equivalent in both matrices.
The remaining outer dimensions will be broadcast.
-
Parameters:
- lhs – The left-hand-side of the matmul.
- rhs – The right-hand-side of the matmul.
- location – An optional location for a more specific error message.
-
Returns:
A tensor graph value representing the result of broadcasting the two matrices together and then performing a matrix multiply along the innermost two dimensions of each tensor.
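The shape rules above can be summarized as a small function over static shapes. This is a sketch for illustration: matmul_shape is hypothetical (not a MAX API), it follows the 1-D promotion rule described above, and it assumes NumPy-style broadcasting of the outer dimensions.

```python
from itertools import zip_longest


def matmul_shape(lhs, rhs):
    """Output shape of a broadcasting matmul over static integer shapes.

    Hypothetical helper; outer dimensions are assumed to broadcast
    NumPy-style.
    """
    lhs_1d, rhs_1d = len(lhs) == 1, len(rhs) == 1
    if lhs_1d:
        lhs = [1] + list(lhs)  # 1-D lhs of size D is treated as 1xD
    if rhs_1d:
        rhs = list(rhs) + [1]  # 1-D rhs of size D is treated as Dx1
    (m, k_lhs), (k_rhs, n) = lhs[-2:], rhs[-2:]
    if k_lhs != k_rhs:
        raise ValueError(f"K dimensions differ: {k_lhs} vs {k_rhs}")
    # Broadcast the outer (batch) dimensions, aligned from the right.
    batch = []
    for a, b in zip_longest(reversed(lhs[:-2]), reversed(rhs[:-2]), fillvalue=1):
        if a != b and 1 not in (a, b):
            raise ValueError(f"cannot broadcast {a} and {b}")
        batch.append(b if a == 1 else a)
    out = list(reversed(batch)) + [m, n]
    # Remove the 1 dimensions that were added for 1-D inputs.
    if lhs_1d:
        out.pop(-2)
    if rhs_1d:
        out.pop(-1)
    return out


print(matmul_shape([2, 3, 4], [4, 5]))        # [2, 3, 5]
print(matmul_shape([8, 1, 3, 4], [5, 4, 6]))  # [8, 5, 3, 6]
print(matmul_shape([4], [4, 5]))              # [5]
```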
Quantized
dequantize()
max.graph.ops.dequantize(encoding: QuantizationEncoding, quantized: TensorValue) → TensorValue
Dequantizes a quantized tensor to floating point.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
- encoding – The quantization encoding to use.
- quantized – The quantized tensor to dequantize.
-
Returns:
The dequantized result (a floating point tensor).
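To make dequantization concrete, the sketch below decodes one Q4_0 block in plain Python. It assumes the GGML/GGUF Q4_0 block layout (32 weights per 18-byte block: a float16 scale followed by 16 bytes of packed 4-bit values); this page does not describe the byte layout MAX uses, and the helper is hypothetical, not a way to call ops.dequantize.

```python
import struct


def dequantize_q4_0_block(block: bytes) -> list[float]:
    """Decode one 18-byte Q4_0 block into 32 floats.

    Hypothetical helper assuming the GGML Q4_0 layout: a little-endian
    float16 scale d, then 16 bytes of packed 4-bit values. Low nibbles
    hold elements 0-15, high nibbles hold elements 16-31, and each
    element decodes to (q - 8) * d.
    """
    if len(block) != 18:
        raise ValueError("a Q4_0 block is exactly 18 bytes")
    (d,) = struct.unpack("<e", block[:2])
    packed = block[2:]
    low = [((byte & 0x0F) - 8) * d for byte in packed]
    high = [((byte >> 4) - 8) * d for byte in packed]
    return low + high


# Scale 0.5; every byte 0x9A packs a low nibble of 10 and a high nibble of 9.
block = struct.pack("<e", 0.5) + bytes([0x9A] * 16)
values = dequantize_q4_0_block(block)
print(len(values), values[0], values[16])  # 32 1.0 0.5
```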
qmatmul()
max.graph.ops.qmatmul(encoding: QuantizationEncoding, lhs: TensorValue, rhs: TensorValue) → TensorValue
Performs matrix multiplication between floating point and quantized tensors.
This quantizes the lhs floating point value to match the encoding of the rhs quantized value, performs matmul, and then dequantizes the result.
Beware that, compared to a regular matmul op, this one expects the rhs value to be transposed. For example, if the lhs shape is [32, 64], and the quantized rhs shape is also [32, 64], then the output shape is [32, 32].
That is, this function returns the result from:
dequantize(quantize(lhs) @ transpose(rhs))
The last two dimensions in lhs are treated as matrices and multiplied by rhs (which must be a 2D tensor). Any remaining dimensions in lhs are broadcast dimensions.
NOTE: Currently this supports Q4_0, Q4_K, and Q6_K encodings only.
-
Parameters:
- encoding – The quantization encoding to use.
- lhs – The non-quantized, left-hand-side of the matmul.
- rhs – The transposed and quantized right-hand-side of the matmul. Must be rank 2 (a 2D tensor/matrix) and in a supported quantization encoding.
-
Returns:
The dequantized result (a floating point tensor).
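The identity dequantize(quantize(lhs) @ transpose(rhs)) means each row of rhs is dotted with each row of lhs. The plain-Python sketch below leaves quantization out entirely (plain floats stand in for both operands) to show only the transposed contraction and the resulting [32, 32] shape from the example above; matmul_transposed_rhs is a hypothetical helper.

```python
def matmul_transposed_rhs(lhs, rhs):
    """Compute lhs @ transpose(rhs) for 2-D nested lists.

    lhs has shape [M, K] and rhs has shape [N, K]; the result is [M, N].
    Quantization is omitted: this only models the operand layout.
    """
    k = len(lhs[0])
    if any(len(row) != k for row in rhs):
        raise ValueError("lhs and rhs must share the last dimension K")
    return [[sum(a * b for a, b in zip(l_row, r_row)) for r_row in rhs]
            for l_row in lhs]


lhs = [[1.0] * 64 for _ in range(32)]  # shape [32, 64]
rhs = [[0.5] * 64 for _ in range(32)]  # shape [32, 64], already transposed
out = matmul_transposed_rhs(lhs, rhs)
print(len(out), len(out[0]))  # 32 32
```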
Reduction
mean()
max.graph.ops.mean(x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, axis=-1) → TensorValue
Reduces a symbolic tensor using a mean operation.
-
Parameters:
- x – The input tensor for the operation.
- axis – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.
-
Returns:
A symbolic tensor representing the result of the mean operation. The tensor will have the same rank as the input tensor, and the same shape except along the axis dimension, which will have size 1.
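For the default axis=-1, the reduction and its kept size-1 dimension look like this in plain Python. This is a reference sketch for a 2-D input only; mean_last_axis is hypothetical.

```python
def mean_last_axis(matrix):
    """Mean over the last axis of a 2-D nested list, keeping that axis.

    Hypothetical reference for axis=-1: an [R, C] input gives [R, 1].
    """
    return [[sum(row) / len(row)] for row in matrix]


x = [[1.0, 2.0, 3.0],
     [4.0, 6.0, 8.0]]
print(mean_last_axis(x))  # [[2.0], [6.0]]
```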
Slicing
concat()
max.graph.ops.concat(original_vals: Iterable[Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray], axis: int = 0) → TensorValue
Concatenates a list of symbolic tensors along an axis.
-
Parameters:
- original_vals – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than axis.
- axis – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, concat(vs, -1) will concat along the last dimension.
-
Returns:
A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will be the same as each input tensor’s for each dimension other than axis, which will have size equal to the sum of the input tensors’ sizes along that dimension.
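The shape rule above can be checked with a small helper over static shapes. concat_shape is hypothetical (not a MAX API) and only validates the constraints listed under Parameters.

```python
def concat_shape(shapes, axis=0):
    """Result shape of concatenating tensors with the given static shapes."""
    rank = len(shapes[0])
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # negative axes count from the end
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all inputs must have the same rank")
        if any(shape[d] != shapes[0][d] for d in range(rank) if d != axis):
            raise ValueError("dimensions other than axis must match")
    out = list(shapes[0])
    out[axis] = sum(shape[axis] for shape in shapes)
    return out


print(concat_shape([[2, 3], [2, 5]], axis=-1))  # [2, 8]
print(concat_shape([[1, 4], [3, 4]]))           # [4, 4]
```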
gather()
max.graph.ops.gather(input: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, indices: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, axis: int = -1) → TensorValue
Selects elements out of an input tensor by index.
-
Parameters:
- input – The input symbolic tensor to select elements from.
- indices – A symbolic tensor of index values to use for selection.
- axis – The dimension of input that indices indexes into. If negative, indexes relative to the end of the input tensor. For instance, gather(input, indices, axis=-1) will index against the last dimension of input.
-
Returns:
A new symbolic tensor representing the result of the gather operation.
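As a rough illustration, for a 2-D input and a 1-D list of indices, gathering along the last axis picks the same columns from every row. This plain-Python sketch assumes NumPy take-style semantics (output shape input.shape[:axis] + indices.shape + input.shape[axis + 1:]), which this page does not spell out; gather_last_axis is hypothetical.

```python
def gather_last_axis(matrix, indices):
    """out[i][j] = matrix[i][indices[j]] for a 2-D input and 1-D indices.

    Hypothetical reference for axis=-1; indices are not validated.
    """
    return [[row[idx] for idx in indices] for row in matrix]


x = [[10, 11, 12],
     [20, 21, 22]]
print(gather_last_axis(x, [2, 0, 2]))  # [[12, 10, 12], [22, 20, 22]]
```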
select()
max.graph.ops.select(cond: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, x: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray, y: Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray) → TensorValue
Returns cond ? x : y (elementwise), where cond, x, and y are input tensors.
-
Parameters:
- cond – The condition tensor to use for selecting elementwise values.
- x – If the condition is true at a position, the value from the same position in this tensor will be selected.
- y – If the condition is false at a position, the value from the same position in this tensor will be selected.
-
Returns:
A new symbolic tensor holding values from either x or y, based on the elements in cond.
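The cond ? x : y rule is easy to model for equal-length 1-D inputs. select_ref below is a hypothetical plain-Python reference; it does not broadcast, and whether ops.select broadcasts its inputs is not stated on this page.

```python
def select_ref(cond, x, y):
    """Elementwise cond ? x : y over equal-length 1-D lists (no broadcasting)."""
    if not len(cond) == len(x) == len(y):
        raise ValueError("cond, x, and y must have the same length")
    return [a if c else b for c, a, b in zip(cond, x, y)]


print(select_ref([True, False, True], [1, 2, 3], [-1, -2, -3]))  # [1, -2, 3]
```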
stack()
max.graph.ops.stack(vals: Iterable[Value | TensorValue | Weight | Shape | Dim | int | float | integer | floating | ndarray], axis: int = 0) → TensorValue
Stacks a list of tensors along a new axis.
-
Parameters:
- vals – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension.
- axis – The position of the new axis to stack along. If negative, indexes relative to the end of the tensor shape plus 1. For instance, stack(vs, -1) will create and stack along a new axis as the last dimension, and stack(vs, -2) will create and stack along a new dimension which is inserted immediately before the last dimension.
-
Returns:
A new symbolic tensor representing the result of the stack. It will have rank n+1, where n is the rank of each input tensor. Its size on each dimension other than axis will be the same as each input tensor’s, with the new axis inserted. Along the new dimension it will have size len(vals).
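The negative-axis rule for stack() differs from concat(): axis counts from the end of the output shape, which has one more dimension than the inputs. stack_shape is a hypothetical helper over static shapes that follows the description above.

```python
def stack_shape(shape, count, axis=0):
    """Result shape of stacking `count` tensors that all have `shape`.

    The new axis may be anywhere in [-(rank + 1), rank]; negative values
    count from the end of the output shape, which has rank + 1 dims.
    """
    rank = len(shape)
    if not -(rank + 1) <= axis <= rank:
        raise ValueError(f"axis {axis} out of range")
    if axis < 0:
        axis += rank + 1
    return list(shape[:axis]) + [count] + list(shape[axis:])


print(stack_shape([2, 3], count=4, axis=0))   # [4, 2, 3]
print(stack_shape([2, 3], count=4, axis=-1))  # [2, 3, 4]
print(stack_shape([2, 3], count=4, axis=-2))  # [2, 4, 3]
```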