# MAX Python API Documentation

> The MAX Python API reference.

This file contains all documentation content in a single document following the llmtxt.org standard.

## driver

Exposes APIs for interacting with hardware, such as allocating tensors on a GPU and moving tensors between the CPU and GPU. It provides interfaces for memory management, device properties, and hardware monitoring. Through these APIs, you can control data placement, track resource utilization, and configure device settings for optimal performance.

For example, you can use the following code to use an accelerator if one is available, otherwise use the CPU:

```python
from max import driver

device = driver.CPU() if driver.accelerator_count() == 0 else driver.Accelerator()
print(f"Using {device} device")
```

## `Accelerator` {#max.driver.Accelerator}

> class max.driver.Accelerator(self, id: [int](https://docs.python.org/3/library/functions.html#int) = -1)

Creates an accelerator device with the specified ID. Provides access to GPU or other hardware accelerators in the system.

```python
from max import driver

device = driver.Accelerator()

# Or specify GPU id
device = driver.Accelerator(id=0)  # First GPU
device = driver.Accelerator(id=1)  # Second GPU

# Get device id
device_id = device.id
```

**Parameters:**

**id** ([`int`](https://docs.python.org/3/library/functions.html#int) `,` `optional` ) – The device ID to use. Defaults to -1, which selects the first available accelerator.

**Returns:** A new Accelerator device object.

**Return type:** [Accelerator](#max.driver.Accelerator)

## `CPU` {#max.driver.CPU}

> class max.driver.CPU(self, id: [int](https://docs.python.org/3/library/functions.html#int) = -1)

Creates a CPU device.

```python
from max import driver

# Create default CPU device
device = driver.CPU()

# Device id is always 0 for CPU devices
device_id = device.id
```

**Parameters:**

**id** ([`int`](https://docs.python.org/3/library/functions.html#int) `,` `optional` ) – The device ID to use. Defaults to -1.

**Returns:** A new CPU device object.

**Return type:** [CPU](#max.driver.CPU)

## `DLPackArray` {#max.driver.DLPackArray}

> class max.driver.DLPackArray(\*args, \*\*kwargs)

## `Device` {#max.driver.Device}

> class max.driver.Device

### `api` {#max.driver.Device.api}

> property api

Returns the API used to program the device.

Possible values are:

* `cpu` for host devices.
* `cuda` for NVIDIA GPUs.
* `hip` for AMD GPUs.

```python
from max import driver

device = driver.CPU()
device.api
```

### `can_access` {#max.driver.Device.can_access}

> can\_access

Checks if this device can directly access memory of another device.

```python
from max import driver

gpu0 = driver.Accelerator(id=0)
gpu1 = driver.Accelerator(id=1)

if gpu0.can_access(gpu1):
    print("GPU0 can directly access GPU1 memory.")
```

**Parameters:**

**other** ([`Device`](#max.driver.Device) ) – The other device to check peer access against.

**Returns:** True if peer access is possible, False otherwise.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `cpu` {#max.driver.Device.cpu}

> cpu

### `default_stream` {#max.driver.Device.default_stream}

> property default\_stream

Returns the default stream for this device. The default stream is initialized when the device object is created.

**Returns:** The default execution stream for this device.

**Return type:** DeviceStream

### `id` {#max.driver.Device.id}

> property id

Returns a zero-based device id. For a CPU device this is always 0.
For GPU accelerators this is the id of the device relative to this host. Along with the `label`, an id can uniquely identify a device, e.g. `gpu:0`, `gpu:1`.

```python
from max import driver

device = driver.Accelerator()
device_id = device.id
```

**Returns:** The device ID.

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)

### `is_compatible` {#max.driver.Device.is_compatible}

> property is\_compatible

Returns whether this device is compatible with MAX.

**Returns:** True if the device is compatible with MAX, False otherwise.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `is_host` {#max.driver.Device.is_host}

> property is\_host

Whether this device is the CPU (host) device.

```python
from max import driver

device = driver.CPU()
device.is_host
```

### `label` {#max.driver.Device.label}

> property label

Returns device label.

Possible values are:

* `cpu` for host devices.
* `gpu` for accelerators.

```python
from max import driver

device = driver.CPU()
device.label
```

### `stats` {#max.driver.Device.stats}

> property stats

Returns utilization data for the device.

```python
from max import driver

device = driver.CPU()
stats = device.stats
```

**Returns:** A dictionary containing device utilization statistics.

**Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)

### `synchronize` {#max.driver.Device.synchronize}

> synchronize

Ensures all operations on this device complete before returning.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If any enqueued operations had an internal error.

## `DeviceSpec` {#max.driver.DeviceSpec}

> class max.driver.DeviceSpec(id, device\_type='cpu')

Specification for a device, containing its ID and type. This class provides a way to specify device parameters like ID and type (CPU/GPU) for creating Device instances.

**Parameters:**

* **id** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **device\_type** ([`Literal`](https://docs.python.org/3/library/typing.html#typing.Literal) `[` `'cpu'` `,` `'gpu'` `]` )

### `accelerator()` {#max.driver.DeviceSpec.accelerator}

> static accelerator(id=0)

Creates an accelerator (GPU) device specification.

**Parameters:**

**id** ([`int`](https://docs.python.org/3/library/functions.html#int) )

### `cpu()` {#max.driver.DeviceSpec.cpu}

> static cpu(id=-1)

Creates a CPU device specification.

**Parameters:**

**id** ([`int`](https://docs.python.org/3/library/functions.html#int) )

### `device_type` {#max.driver.DeviceSpec.device_type}

> device\_type: [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['cpu', 'gpu'] = 'cpu'

Type of specified device.

### `id` {#max.driver.DeviceSpec.id}

> id: [int](https://docs.python.org/3/library/functions.html#int)

Provided id for this device.
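The following example shows how to describe devices declaratively and materialize them later with `devices_exist()` and `load_devices()` (documented below); the GPU spec assumes an accelerator is present:

```python
from max import driver

# Describe the desired devices without initializing them
specs = [driver.DeviceSpec.cpu(), driver.DeviceSpec.accelerator(id=0)]

# Materialize Device objects once the specs are validated
if driver.devices_exist(specs):
    devices = driver.load_devices(specs)
    print(devices)
```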
## `Tensor` {#max.driver.Tensor}

> class max.driver.Tensor(self, dtype: [max.\_core.dtype.DType](dtype.md#max.dtype.DType), shape: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)], device: [max.\_core.driver.Device](#max.driver.Device) | [None](https://docs.python.org/3/library/constants.html#None) = None, pinned: [bool](https://docs.python.org/3/library/functions.html#bool) = False)

> class max.driver.Tensor(self, dtype: [max.\_core.dtype.DType](dtype.md#max.dtype.DType), shape: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)], stream: max.\_core.driver.DeviceStream, pinned: [bool](https://docs.python.org/3/library/functions.html#bool) = False)

> class max.driver.Tensor(self, shape: ndarray\[writable=False], device: max.\_core.driver.Device)

> class max.driver.Tensor(self, other: [max.\_core.driver.Tensor](#max.driver.Tensor))

Device-resident tensor representation. Allocates memory onto a given device with the provided shape and dtype. Tensors can be sliced to provide strided views of the underlying memory, but any tensors input into model execution must be contiguous. Supports numpy-style slicing but does not currently support setting items across multiple indices.

```python
from max import driver
from max.dtype import DType

# Create a tensor on CPU
cpu_tensor = driver.Tensor(shape=[2, 3], dtype=DType.float32)

# Create a tensor on GPU
gpu = driver.Accelerator()
gpu_tensor = driver.Tensor(shape=[2, 3], dtype=DType.float32, device=gpu)
```

**Parameters:**

* **dtype** ([`DType`](dtype.md#max.dtype.DType) ) – Data type of tensor elements.
* **shape** (`Sequence` `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Tuple of positive, non-zero integers denoting the tensor shape.
* **device** ([`Device`](#max.driver.Device) `,` `optional` ) – Device to allocate tensor onto. Defaults to the CPU.
* **pinned** ([`bool`](https://docs.python.org/3/library/functions.html#bool) `,` `optional` ) – If True, memory is page-locked (pinned). Defaults to False.
* **stream** (`DeviceStream` `,` `optional` ) – Stream to associate the tensor with.

Overloaded function.

1. `__init__(self, dtype: max._core.dtype.DType, shape: collections.abc.Sequence[int], device: max._core.driver.Device | None = None, pinned: bool = False) -> None`
2. `__init__(self, dtype: max._core.dtype.DType, shape: collections.abc.Sequence[int], stream: max._core.driver.DeviceStream, pinned: bool = False) -> None`
3. `__init__(self, shape: ndarray[writable=False], device: max._core.driver.Device) -> None`
4. `__init__(self, other: max._core.driver.Tensor) -> None`

> Moves the internals from an existing Tensor object into a new Tensor object.
> Primarily used for initializing subclasses with existing Tensors.

### `contiguous()` {#max.driver.Tensor.contiguous}

> contiguous()

Creates a contiguous copy of the parent tensor.

**Return type:** [*Tensor*](#max.driver.Tensor)

### `copy` {#max.driver.Tensor.copy}

> copy

Overloaded function.

1. `copy(self, stream: max._core.driver.DeviceStream) -> max._core.driver.Tensor`

> Creates a deep copy on the device associated with the stream.

> Args:
> : stream (DeviceStream): The stream to associate the new tensor with.

> Returns:
> : Tensor: A new tensor that is a copy of this tensor.
2. `copy(self, device: max._core.driver.Device | None = None) -> max._core.driver.Tensor`

> Creates a deep copy on an optionally given device.
> If device is None (default), a copy is created on the same device.
>
> ```python
> from max import driver
> from max.dtype import DType
>
> cpu_tensor = driver.Tensor(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU())
> cpu_copy = cpu_tensor.copy()
>
> # Copy to GPU
> gpu = driver.Accelerator()
> gpu_copy = cpu_tensor.copy(device=gpu)
> ```

> Args:
> : device (Device, optional): The device to create the copy on.
> : Defaults to None (same device).

> Returns:
> : Tensor: A new tensor that is a copy of this tensor.

### `device` {#max.driver.Tensor.device}

> property device

Device on which tensor is resident.

### `dtype` {#max.driver.Tensor.dtype}

> property dtype

DType of constituent elements in tensor.

### `element_size` {#max.driver.Tensor.element_size}

> property element\_size

Return the size of the element type in bytes.

### `from_dlpack()` {#max.driver.Tensor.from_dlpack}

> from\_dlpack(array, \*, copy=None)

Creates a tensor from an object implementing the dlpack protocol. This usually does not result in a copy, and the producer of the object retains ownership of the underlying memory.

**Parameters:**

* **array** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) )
* **copy** ([`bool`](https://docs.python.org/3/library/functions.html#bool) `|` `None` )

**Return type:** [*Tensor*](#max.driver.Tensor)

### `from_numpy()` {#max.driver.Tensor.from_numpy}

> from\_numpy(arr)

Creates a tensor from a provided numpy array on the host device. The underlying data is not copied unless the array is noncontiguous. If it is, a contiguous copy will be returned.

**Parameters:**

**arr** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*Tensor*](#max.driver.Tensor)

### `inplace_copy_from()` {#max.driver.Tensor.inplace_copy_from}

> inplace\_copy\_from(src)

Copy the contents of another tensor into this one. These tensors may be on different devices. Requires that both tensors are contiguous and have the same size.

**Parameters:**

**src** ([`Tensor`](#max.driver.Tensor) )

**Return type:** None

### `is_contiguous` {#max.driver.Tensor.is_contiguous}

> property is\_contiguous

Whether or not tensor is contiguously allocated in memory. Returns false if the tensor is a non-contiguous slice.

Currently, we consider certain situations that are contiguous as non-contiguous for the purposes of our engine, such as when a tensor has negative steps.

### `is_host` {#max.driver.Tensor.is_host}

> property is\_host

Whether or not tensor is host-resident. Returns false for GPU tensors, true for CPU tensors.

```python
from max import driver
from max.dtype import DType

cpu_tensor = driver.Tensor(shape=[2, 3], dtype=DType.bfloat16, device=driver.CPU())
print(cpu_tensor.is_host)
```

### `item` {#max.driver.Tensor.item}

> item

Returns the scalar value at a given location. Currently implemented only for zero-rank tensors. The return type is converted to a Python built-in type.
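The following sketch reads a scalar back out of a rank-0 tensor; calling `item()` with no arguments is an assumption based on the zero-rank restriction above:

```python
import numpy as np

from max import driver

# Wrap a zero-rank numpy array as a host tensor (no copy for contiguous input)
scalar_tensor = driver.Tensor.from_numpy(np.array(3.14, dtype=np.float32))
value = scalar_tensor.item()  # assumed zero-argument call on a rank-0 tensor
print(value)  # 3.14... as a Python float
```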
### `mmap()` {#max.driver.Tensor.mmap}

> mmap(filename, dtype, shape, mode='copyonwrite', offset=0)

Creates a tensor backed by a memory-mapped file, following `numpy.memmap` semantics.

**Parameters:**

* **filename** (`PathLike` `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) )
* **dtype** ([`DType`](dtype.md#max.dtype.DType) )
* **shape** (`ShapeType` `|` [`int`](https://docs.python.org/3/library/functions.html#int) )
* **mode** (`np._MemMapModeKind` )

### `num_elements` {#max.driver.Tensor.num_elements}

> property num\_elements

Returns the number of elements in this tensor. Rank-0 tensors have 1 element by convention.

### `pinned` {#max.driver.Tensor.pinned}

> property pinned

Whether or not the underlying memory is pinned (page-locked).

### `rank` {#max.driver.Tensor.rank}

> property rank

Tensor rank.

### `scalar` {#max.driver.Tensor.scalar}

> scalar

### `shape` {#max.driver.Tensor.shape}

> property shape

Shape of tensor.

### `stream` {#max.driver.Tensor.stream}

> property stream

Stream to which tensor is bound.

### `to` {#max.driver.Tensor.to}

> to

Overloaded function.

1. `to(self, device: max._core.driver.Device) -> Tensor`

> Return a tensor that’s guaranteed to be on the given device. The tensor is only copied if the requested device is different from the device upon which the tensor is already resident.

2. `to(self, device: max._core.driver.DeviceStream) -> Tensor`

> Return a tensor that’s guaranteed to be on the given device and associated with the given stream. The tensor is only copied if the requested device is different from the device upon which the tensor is already resident.

### `to_numpy()` {#max.driver.Tensor.to_numpy}

> to\_numpy()

Converts the tensor to a numpy array. If the tensor is not on the host, an exception is raised.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `view()` {#max.driver.Tensor.view}

> view(dtype, shape=None)

Return a new tensor with the given type and shape that shares the underlying memory. If the shape is not given, it will be deduced if possible, or a ValueError is raised.

**Parameters:**

* **dtype** ([`DType`](dtype.md#max.dtype.DType) )
* **shape** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None` )

**Return type:** [*Tensor*](#max.driver.Tensor)

### `zeros` {#max.driver.Tensor.zeros}

> zeros

## `accelerator_api()` {#max.driver.accelerator_api}

> max.driver.accelerator\_api()

Returns the API used to program the accelerator.

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

## `devices_exist()` {#max.driver.devices_exist}

> max.driver.devices\_exist(devices)

Checks whether the given devices exist.

**Parameters:**

**devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceSpec`](#max.driver.DeviceSpec) `]` )

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

## `load_devices()` {#max.driver.load_devices}

> max.driver.load\_devices(device\_specs)

Initializes and returns a list of devices, given a list of device specs.

**Parameters:**

**device\_specs** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceSpec`](#max.driver.DeviceSpec) `]` )

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Device*](#max.driver.Device)]

## `scan_available_devices()` {#max.driver.scan_available_devices}

> max.driver.scan\_available\_devices()

Returns specs for all available accelerators, or for the CPU if no accelerator is available.
**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*DeviceSpec*](#max.driver.DeviceSpec)]

---

## dtype

Provides data type definitions for tensors in MAX Engine. These data types are essential for defining the precision and memory layout of tensor data when working with machine learning models.

This module defines the [`DType`](#max.dtype.DType) enum, which represents all supported tensor data types in MAX Engine, including:

* Integer types (signed and unsigned): `int8` | `uint8` | `int16` | `uint16` | `int32` | `uint32` | `int64` | `uint64`
* Floating-point types: `float8` variants | `float16` | `bfloat16` | `float32` | `float64`
* Boolean type

The module also provides utilities for converting between MAX Engine data types and [NumPy dtypes](https://numpy.org/doc/stable/user/basics.types.html), making it easy to interoperate with the NumPy ecosystem.

```python
import numpy as np
from max.dtype import DType

tensor = np.zeros((2, 3), dtype=DType.float32.to_numpy())

# Convert NumPy dtype to MAX DType
array = np.ones((4, 4), dtype=np.float16)
max_dtype = DType.from_numpy(array.dtype)

# Check properties of data types
is_float = DType.float32.is_float()  # True
is_int = DType.int64.is_integral()  # True
size = DType.float64.size_in_bytes  # 8
```

## `DType` {#max.dtype.DType}

> class max.dtype.DType(value, names=<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

The tensor data type.

### `align` {#max.dtype.DType.align}

> property align

Returns the alignment requirement of the data type in bytes. The alignment specifies the memory boundary that values of this data type must be aligned to for optimal performance and correctness.

### `bfloat16` {#max.dtype.DType.bfloat16}

> bfloat16 = 71

### `bool` {#max.dtype.DType.bool}

> bool = 1

### `float16` {#max.dtype.DType.float16}

> float16 = 70

### `float32` {#max.dtype.DType.float32}

> float32 = 72

### `float64` {#max.dtype.DType.float64}

> float64 = 73

### `float8_e4m3fn` {#max.dtype.DType.float8_e4m3fn}

> float8\_e4m3fn = 66

### `float8_e4m3fnuz` {#max.dtype.DType.float8_e4m3fnuz}

> float8\_e4m3fnuz = 67

### `float8_e5m2` {#max.dtype.DType.float8_e5m2}

> float8\_e5m2 = 68

### `float8_e5m2fnuz` {#max.dtype.DType.float8_e5m2fnuz}

> float8\_e5m2fnuz = 69

### `from_numpy()` {#max.dtype.DType.from_numpy}

> from\_numpy(dtype)

Converts a NumPy dtype to the corresponding DType.

**Parameters:**

**dtype** (`np.dtype` ) – The NumPy dtype to convert.

**Returns:** The corresponding DType enum value.

**Return type:** [DType](#max.dtype.DType)

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input dtype is not supported.

### `from_torch()` {#max.dtype.DType.from_torch}

> from\_torch(dtype)

Converts a PyTorch dtype to the corresponding DType.

**Parameters:**

**dtype** (`dtype` )

**Return type:** [*DType*](#max.dtype.DType)

### `int16` {#max.dtype.DType.int16}

> int16 = 137

### `int32` {#max.dtype.DType.int32}

> int32 = 139

### `int64` {#max.dtype.DType.int64}

> int64 = 141

### `int8` {#max.dtype.DType.int8}

> int8 = 135

### `is_float` {#max.dtype.DType.is_float}

> is\_float

Checks if the data type is a floating-point type.

### `is_float8` {#max.dtype.DType.is_float8}

> is\_float8

Checks if the data type is an 8-bit floating-point type.

### `is_half` {#max.dtype.DType.is_half}

> is\_half

Checks if the data type is a half-precision floating-point type.

### `is_integral` {#max.dtype.DType.is_integral}

> is\_integral

Checks if the data type is an integer type.
### `is_signed_integral` {#max.dtype.DType.is_signed_integral}

> is\_signed\_integral

Checks if the data type is a signed integer type.

### `is_unsigned_integral` {#max.dtype.DType.is_unsigned_integral}

> is\_unsigned\_integral

Checks if the data type is an unsigned integer type.

### `size_in_bytes` {#max.dtype.DType.size_in_bytes}

> property size\_in\_bytes

Returns the size of the data type in bytes. This indicates how many bytes are required to store a single value of this data type in memory.

### `to_numpy()` {#max.dtype.DType.to_numpy}

> to\_numpy()

Converts this `DType` to the corresponding NumPy dtype.

**Returns:** The corresponding NumPy dtype object.

**Return type:** [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the dtype is not supported.

### `to_torch()` {#max.dtype.DType.to_torch}

> to\_torch(dtype)

Converts a `DType` to the corresponding PyTorch dtype.

**Parameters:**

**dtype** ([`DType`](#max.dtype.DType) )

**Return type:** *dtype*

### `uint16` {#max.dtype.DType.uint16}

> uint16 = 136

### `uint32` {#max.dtype.DType.uint32}

> uint32 = 138

### `uint64` {#max.dtype.DType.uint64}

> uint64 = 140

### `uint8` {#max.dtype.DType.uint8}

> uint8 = 134

---

## engine

The APIs in this module allow you to run inference with MAX Engine—a graph compiler and runtime that accelerates your AI models on a wide variety of hardware.

## `InferenceSession` {#max.engine.InferenceSession}

> class max.engine.InferenceSession(num\_threads=None, devices=None, \*, custom\_extensions=None)

Manages an inference session in which you can load and run models. You need an instance of this to load a model as a [`Model`](#max.engine.Model) object. For example:

```python
from pathlib import Path

from max import engine

session = engine.InferenceSession()
model_path = Path('bert-base-uncased')
model = session.load(model_path)
```

**Parameters:**

* **num\_threads** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – Number of threads to use for the inference session. This defaults to the number of physical cores on your machine.
* **devices** (`Iterable` `[` [`Device`](driver.md#max.driver.Device) `]` `|` `None` ) – A list of devices on which to run inference. Default is the host CPU only.
* **custom\_extensions** (`CustomExtensionsType` `|` `None` ) – The extensions to load for the model. Supports paths to a .mojopkg custom ops library or a .mojo source file.

### `devices` {#max.engine.InferenceSession.devices}

> property devices: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Device](driver.md#max.driver.Device)]

A list of available devices.

### `gpu_profiling()` {#max.engine.InferenceSession.gpu_profiling}

> gpu\_profiling(mode)

Configures end-to-end GPU profiling.

**Parameters:**

**mode** ([`GPUProfilingMode`](#max.engine.GPUProfilingMode) )

### `load()` {#max.engine.InferenceSession.load}

> load(model, \*, custom\_extensions=None, custom\_ops\_path=None, weights\_registry=None)

Loads a trained model and compiles it for inference.

**Parameters:**

* **model** (`Union` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `Path` `,` `Any` `]` ) – Path to a model.
* **custom\_extensions** (`CustomExtensionsType` `|` `None` ) – The extensions to load for the model. Supports paths to .mojopkg custom ops.
* **custom\_ops\_path** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) – The path to your custom ops Mojo package. Deprecated, use `custom_extensions` instead.
* **weights\_registry** (`Mapping` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `DLPackCompatible` `]` `|` `None` ) – A mapping from model weight names to their values. The values are currently expected to be dlpack arrays. If an array is a read-only numpy array, the user must ensure that its lifetime extends beyond the lifetime of the model.

**Returns:** The loaded model, compiled and ready to execute.

**Raises:** [**RuntimeError**](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the path provided is invalid.

**Return type:** [Model](#max.engine.Model)

### `set_mojo_assert_level()` {#max.engine.InferenceSession.set_mojo_assert_level}

> set\_mojo\_assert\_level(level)

Sets which Mojo asserts are kept in the compiled model.

**Parameters:**

**level** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `AssertLevel` )

### `set_mojo_log_level()` {#max.engine.InferenceSession.set_mojo_log_level}

> set\_mojo\_log\_level(level)

Sets the verbosity of Mojo logging in the compiled model.

**Parameters:**

**level** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`LogLevel`](#max.engine.LogLevel) )

### `set_split_k_reduction_precision()` {#max.engine.InferenceSession.set_split_k_reduction_precision}

> set\_split\_k\_reduction\_precision(precision)

Sets the accumulation precision for split-k reductions in large matmuls.

**Parameters:**

**precision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `SplitKReductionPrecision` )

## `Model` {#max.engine.Model}

> class max.engine.Model

A loaded model that you can execute. Do not instantiate this class directly. Instead, create it with [`InferenceSession`](#max.engine.InferenceSession).

### `__call__()` {#max.engine.Model.__call__}

> \_\_call\_\_(\*args, \*\*kwargs)

Call self as a function.
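For example, the following sketch runs a loaded model on a NumPy input; the model path and input shape are placeholders:

```python
import numpy as np

from max import engine

session = engine.InferenceSession()
model = session.load("path/to/model")  # placeholder path

# Inputs may be numpy arrays, max.driver tensors, or other DLPack arrays
input_array = np.zeros((1, 128), dtype=np.float32)  # placeholder shape
outputs = model(input_array)
print(outputs)  # a list of Tensor or MojoValue results
```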
**Parameters:**

* **self** ([`Model`](#max.engine.Model) )
* **args** ([`DLPackArray`](driver.md#max.driver.DLPackArray) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `[` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `,` [`dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype) `[` `\_ScalarType_co` `]` `]` `|` [`Tensor`](driver.md#max.driver.Tensor) `|` [`MojoValue`](#max.engine.MojoValue) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`bool`](https://docs.python.org/3/library/functions.html#bool) `|` [`generic`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic) )
* **kwargs** ([`DLPackArray`](driver.md#max.driver.DLPackArray) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `[` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `,` [`dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype) `[` `\_ScalarType_co` `]` `]` `|` [`Tensor`](driver.md#max.driver.Tensor) `|` [`MojoValue`](#max.engine.MojoValue) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`bool`](https://docs.python.org/3/library/functions.html#bool) `|` [`generic`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic) )

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Tensor*](driver.md#max.driver.Tensor) | [*MojoValue*](#max.engine.MojoValue)]

### `execute()` {#max.engine.Model.execute}

> execute(\*args)

Executes the model with the given inputs and returns its outputs.

**Parameters:**

* **self** ([`Model`](#max.engine.Model) )
* **args** ([`DLPackArray`](driver.md#max.driver.DLPackArray) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `[` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `,` [`dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype) `[` `\_ScalarType_co` `]` `]` `|` [`Tensor`](driver.md#max.driver.Tensor) `|` [`MojoValue`](#max.engine.MojoValue) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`bool`](https://docs.python.org/3/library/functions.html#bool) `|` [`generic`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.generic) )

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Tensor*](driver.md#max.driver.Tensor) | [*MojoValue*](#max.engine.MojoValue)]

### `input_metadata` {#max.engine.Model.input_metadata}

> property input\_metadata

Metadata about the model’s input tensors, as a list of [`TensorSpec`](#max.engine.TensorSpec) objects.

For example, you can print the input tensor names, shapes, and dtypes:

```python
for tensor in model.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```

### `output_metadata` {#max.engine.Model.output_metadata}

> property output\_metadata

Metadata about the model’s output tensors, as a list of [`TensorSpec`](#max.engine.TensorSpec) objects.
For example, you can print the output tensor names, shapes, and dtypes:

```python
for tensor in model.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
```

## `GPUProfilingMode` {#max.engine.GPUProfilingMode}

> class max.engine.GPUProfilingMode(value, names=<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

The supported modes for GPU profiling.

### `DETAILED` {#max.engine.GPUProfilingMode.DETAILED}

> DETAILED = 'detailed'

### `OFF` {#max.engine.GPUProfilingMode.OFF}

> OFF = 'off'

### `ON` {#max.engine.GPUProfilingMode.ON}

> ON = 'on'

## `LogLevel` {#max.engine.LogLevel}

> class max.engine.LogLevel(value, names=<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

Internal use.

### `CRITICAL` {#max.engine.LogLevel.CRITICAL}

> CRITICAL = 'critical'

### `DEBUG` {#max.engine.LogLevel.DEBUG}

> DEBUG = 'debug'

### `ERROR` {#max.engine.LogLevel.ERROR}

> ERROR = 'error'

### `INFO` {#max.engine.LogLevel.INFO}

> INFO = 'info'

### `NOTSET` {#max.engine.LogLevel.NOTSET}

> NOTSET = 'notset'

### `WARNING` {#max.engine.LogLevel.WARNING}

> WARNING = 'warning'

## `MojoValue` {#max.engine.MojoValue}

> class max.engine.MojoValue

This is work in progress and you should ignore it for now.

## `TensorSpec` {#max.engine.TensorSpec}

> class max.engine.TensorSpec(self, shape: [collections.abc.Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)] | [None](https://docs.python.org/3/library/constants.html#None), dtype: [max.\_core.dtype.DType](dtype.md#max.dtype.DType), name: [str](https://docs.python.org/3/library/stdtypes.html#str))

Defines the properties of a tensor, including its name, shape and data type.

For usage examples, see [`Model.input_metadata`](#max.engine.Model.input_metadata).

**Parameters:**

* **shape** – The tensor shape.
* **dtype** – The tensor data type.
* **name** – The tensor name.

### `dtype` {#max.engine.TensorSpec.dtype}

> property dtype

A tensor data type.

### `name` {#max.engine.TensorSpec.name}

> property name

A tensor name.

### `shape` {#max.engine.TensorSpec.shape}

> property shape

The shape of the tensor as a list of integers. If a dimension size is unknown/dynamic (such as the batch size), its value is `None`.

## `CustomExtensionsType` {#max.engine.CustomExtensionsType}

> max.engine.CustomExtensionsType

alias of [`list`](https://docs.python.org/3/library/stdtypes.html#list)\[[`str`](https://docs.python.org/3/library/stdtypes.html#str) | [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) | [`Any`](https://docs.python.org/3/library/typing.html#typing.Any)] | [`str`](https://docs.python.org/3/library/stdtypes.html#str) | [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) | [`Any`](https://docs.python.org/3/library/typing.html#typing.Any)

---

## entrypoints

## `LLM` {#max.entrypoints.llm.LLM}

> class max.entrypoints.llm.LLM(pipeline\_config)

A high-level interface for interacting with LLMs.

**Parameters:**

**pipeline\_config** ([`PipelineConfig`](pipelines/config.md#max.pipelines.lib.config.PipelineConfig) )

### `generate()` {#max.entrypoints.llm.LLM.generate}

> generate(prompts, max\_new\_tokens=100, use\_tqdm=True)

Generates text completions for the given prompts. This method is thread safe and may be used on the same LLM instance from multiple threads concurrently with no external synchronization.
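For example, the following minimal sketch generates a completion; the `PipelineConfig` import path and the model name are assumptions:

```python
from max.entrypoints.llm import LLM
from max.pipelines import PipelineConfig  # assumed import path

# Assumed model repository name, for illustration only
pipeline_config = PipelineConfig(model_path="modularai/Llama-3.1-8B-Instruct-GGUF")
llm = LLM(pipeline_config)

responses = llm.generate(["What is the capital of France?"], max_new_tokens=32)
print(responses[0])
```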
**Parameters:**

* **prompts** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) – The input string or list of strings to generate completions for.
* **max\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The maximum number of tokens to generate in the response.
* **use\_tqdm** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether to display a progress bar during generation.

**Returns:** A list of generated text completions corresponding to each input prompt.

**Raises:**

* [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If prompts is empty or contains invalid data.
* [**RuntimeError**](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If the model fails to generate completions.

**Return type:** [*Sequence*](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]

---

## BufferValue

## `BufferValue` {#max.graph.BufferValue}

> class max.graph.BufferValue(value)

Bases: [`Value`](Value.md#max.graph.Value)\[`BufferType`]

Represents a mutable semantic tensor within a Graph.

Initializes a [`BufferValue`](#max.graph.BufferValue) from another value.

**Parameters:**

**value** ([`Value`](Value.md#max.graph.Value) `|` `\_Value` `[` `mo.BufferType` `]` ) – The value to wrap, either an MLIR value of buffer type or another [`BufferValue`](#max.graph.BufferValue).

### `device` {#max.graph.BufferValue.device}

> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)

Returns the device of the BufferValue.

### `dtype` {#max.graph.BufferValue.dtype}

> property dtype: [DType](../dtype.md#max.dtype.DType)

Returns the tensor data type.

### `from_mlir()` {#max.graph.BufferValue.from_mlir}

> classmethod from\_mlir(value)

Creates a [`BufferValue`](#max.graph.BufferValue) from an MLIR buffer value.

**Parameters:**

**value** (`Value` `[` `BufferType` `]` ) – The MLIR buffer value to wrap.

**Return type:** [*BufferValue*](#max.graph.BufferValue)

### `print()` {#max.graph.BufferValue.print}

> print(label='debug\_buffer')

Prints detailed information about the buffer.

**Parameters:**

**label** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) )

### `rank` {#max.graph.BufferValue.rank}

> property rank: [int](https://docs.python.org/3/library/functions.html#int)

Returns the rank (number of dims) of the buffer.

### `shape` {#max.graph.BufferValue.shape}

> property shape: [Shape](type.md#max.graph.type.Shape)

Returns the shape of the BufferValue.

### `type` {#max.graph.BufferValue.type}

> property type: BufferType

Returns the type of the [`BufferValue`](#max.graph.BufferValue) as a `BufferType`.

---

## Graph

## `Graph` {#max.graph.Graph}

> class max.graph.Graph(name, forward=None, input\_types=(), path=None, \*args, custom\_extensions=\[], context=None, kernel\_library=None, module=None, \*\*kwargs)

Represents a single MAX graph.

A Graph is a callable routine in MAX Engine. Like functions, graphs have a name and signature. Unlike a function, which follows an imperative programming model, a Graph follows a dataflow programming model, using lazily-executed, parallel operations instead of sequential instructions.

When you instantiate a graph, you must specify the input shapes as one or more `TensorType` values.
Then, build a sequence of ops and set the graph output with [`output()`](#max.graph.Graph.output). For example:

```python
from dataclasses import dataclass

import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops

@dataclass
class Linear:
    weight: np.ndarray
    bias: np.ndarray

    def __call__(self, x: TensorValue) -> TensorValue:
        weight_tensor = ops.constant(self.weight, dtype=DType.float32, device=DeviceRef.CPU())
        bias_tensor = ops.constant(self.bias, dtype=DType.float32, device=DeviceRef.CPU())
        return ops.matmul(x, weight_tensor) + bias_tensor

linear_graph = Graph(
    "linear",
    Linear(np.ones((2, 2)), np.ones((2,))),
    input_types=[TensorType(DType.float32, (2,))]
)
```

You can’t call a Graph directly from Python. You must compile it and execute it with MAX Engine. For more detail, see the tutorial about how to [build a graph with MAX Graph](/max/tutorials/get-started-with-max-graph-in-python).

When creating a graph, a global sequence of chains is initialized and stored in Graph.\_current\_chain. Every side-effecting op, e.g. buffer\_load, store\_buffer, load\_slice\_buffer, store\_slice\_buffer, will use the current chain to perform the op and update Graph.\_current\_chain with a new chain. Currently, the input/output chains for mutable ops can be used at most once. The goal of this design choice is to prevent data races.

**Parameters:**

* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – A name for the graph.
* **forward** (`Optional` `[` `Callable` `]` ) – The sequence of graph ops for the forward pass (inference).
* **input\_types** (`Iterable` `[` [`Type`](type.md#max.graph.type.Type) `]` ) – The data type(s) for the input tensor(s).
* **path** (`Optional` `[` `Path` `]` ) – The path to a saved graph (internal use only).
* **custom\_extensions** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `Path` `]` ) – The extensions to load for the model. Supports paths to .mojopkg or .mojo sources with custom ops.
* **context** (`Optional` `[` `mlir.Context` `]` )
* **kernel\_library** (`Optional` `[` [`KernelLibrary`](KernelLibrary.md#max.graph.KernelLibrary) `]` )
* **module** (`Optional` `[` `mlir.Module` `]` )

### `add_subgraph()` {#max.graph.Graph.add_subgraph}

> add\_subgraph(name, forward=None, input\_types=(), path=None, custom\_extensions=\[])

Adds a named subgraph to this graph.

**Parameters:**

* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) )
* **forward** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) `|` `None` )
* **input\_types** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Type`](type.md#max.graph.type.Type) `]` )
* **path** ([`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) `|` `None` )
* **custom\_extensions** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) `]` )

**Return type:** [*Graph*](#max.graph.Graph)

### `add_weight()` {#max.graph.Graph.add_weight}

> add\_weight(weight, force\_initial\_weight\_on\_host=True)

Adds a weight to the graph. If the weight is in the graph already, return the existing value.

**Parameters:**

* **weight** ([`Weight`](Weight.md#max.graph.Weight) ) – The weight to add to the graph.
* **force\_initial\_weight\_on\_host** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If true, then forces weights to initially be allocated on host before being moved to the indicated device. This is needed as a stopgap until we have a more fleshed-out ownership model for external constants.

**Returns:** A [`TensorValue`](TensorValue.md#max.graph.TensorValue) that contains this weight.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If a weight with the same name already exists in the graph.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `current` {#max.graph.Graph.current}

> current

The graph currently being built (the innermost active `Graph` context).

### `inputs` {#max.graph.Graph.inputs}

> property inputs: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[Value](Value.md#max.graph.Value)]

The input values of the graph.

### `kernel_libraries_paths` {#max.graph.Graph.kernel_libraries_paths}

> property kernel\_libraries\_paths: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Path](https://docs.python.org/3/library/pathlib.html#pathlib.Path)]

Returns the list of extra kernel libraries paths for the custom ops.

### `local_weights_and_chain()` {#max.graph.Graph.local_weights_and_chain}

> local\_weights\_and\_chain()

### `output()` {#max.graph.Graph.output}

> output(\*outputs)

Sets the output nodes of the [`Graph`](#max.graph.Graph).

**Parameters:**

**outputs** ([`Value`](Value.md#max.graph.Value) )

**Return type:** None

### `output_types` {#max.graph.Graph.output_types}

> property output\_types: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[Type](type.md#max.graph.type.Type)]

View of the types of the graph output terminator.

---

## KernelLibrary

## `KernelLibrary` {#max.graph.KernelLibrary}

> class max.graph.KernelLibrary(context, paths=\[])

Manages custom kernel libraries and operations for a graph.

A kernel library provides access to custom operations and kernels that can be loaded from various sources including Mojo binary packages (`.mojopkg`) and Mojo source directories. The library handles verification and registration of custom operations within the MLIR context.

**Parameters:**

* **context** (`mlir.Context` )
* **paths** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `Path` `]` )

### `add_path()` {#max.graph.KernelLibrary.add_path}

> add\_path(path)

Adds a kernel library path to the analysis.

**Parameters:**

**path** ([`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) ) – The `Path` to the kernel library to be added to the current analysis.

### `library_paths()` {#max.graph.KernelLibrary.library_paths}

> library\_paths()

Returns the list of kernel library paths.

**Returns:** A list of `Path` objects representing the currently loaded kernel library paths.

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Path*](https://docs.python.org/3/library/pathlib.html#pathlib.Path)]

### `load_paths()` {#max.graph.KernelLibrary.load_paths}

> load\_paths(context, custom\_extensions)

Loads custom operations from provided library paths.

Performs additional "smart" library loading logic for custom operation libraries in other formats. The loading logic supports the following formats:

* Compiled Mojo binary packages with `.mojopkg` extension
* Mojo source directory with custom operations

The loaded libraries are added to the current kernel library.
**Parameters:**

* **context** (`Context` ) – The MLIR context for loading MLIR operations.
* **custom\_extensions** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) `]` ) – The file paths to the custom operation libraries.

### `verify_custom_op()` {#max.graph.KernelLibrary.verify_custom_op}

> verify\_custom\_op(custom\_op)

Verifies that a custom operation is valid within the current context.

**Parameters:**

**custom\_op** (`Operation` ) – The `mlir.Operation` to be verified against the current kernel library analysis.

---

## TensorValue

## `TensorValue` {#max.graph.TensorValue}

> class max.graph.TensorValue(value)

Bases: [`Value`](Value.md#max.graph.Value)\[`TensorType`]

Represents a value semantic tensor within a [`Graph`](Graph.md#max.graph.Graph). It provides various methods and properties to manipulate and query tensor attributes such as [`shape`](#max.graph.TensorValue.shape), data type ([`dtype`](#max.graph.TensorValue.dtype)), device placement ([`device`](#max.graph.TensorValue.device)), and more.

The following example demonstrates how to create and manipulate tensor values in a graph:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("tensor_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor properties
    print(f"Shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32

    # Perform operations on the tensor
    transposed = tensor.T
    doubled = tensor * 2

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Transposed shape: {transposed.shape}")  # Output: [2, 2]
```

Initializes a [`TensorValue`](#max.graph.TensorValue) from a tensor-like value.

**Parameters:**

**value** (`TensorValueLike` ) – The value to wrap. Can be an MLIR tensor value, another [`TensorValue`](#max.graph.TensorValue), a `Dim`, or a `Shape`.

### `T` {#max.graph.TensorValue.T}

> property T: [TensorValue](#max.graph.TensorValue)

Returns the transposed tensor. [`T`](#max.graph.TensorValue.T) is the shorthand notation for transposing. For more information, see [`transpose()`](#max.graph.TensorValue.transpose).

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.

### `broadcast_to()` {#max.graph.TensorValue.broadcast_to}

> broadcast\_to(shape)

Broadcasts the tensor to a new shape.
The following example demonstrates how to broadcast a tensor to a larger shape:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("broadcast_to_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Broadcast tensor to a 3x2x2 tensor (add a new dimension of size 3)
    broadcasted_tensor = tensor.broadcast_to((3, 2, 2))

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Broadcasted shape: {broadcasted_tensor.shape}")  # Output: [3, 2, 2]
```

**Parameters:**

**shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – An iterable of integers or symbolic dimensions.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with the broadcasted shape.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `cast()` {#max.graph.TensorValue.cast}

> cast(dtype)

Casts a symbolic tensor to a different data type.

The following example demonstrates how to cast a tensor from one data type to another:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("cast_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Cast tensor to integer type
    casted_tensor = tensor.cast(DType.int32)

    print(f"Original dtype: {tensor.dtype}")  # Output: DType.float32
    print(f"Casted dtype: {casted_tensor.dtype}")  # Output: DType.int32
```

**Parameters:**

**dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The target data type (e.g., `DType.int32`, `DType.float64`).

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with the casted data type.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `device` {#max.graph.TensorValue.device}

> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)

Returns the device of the TensorValue.

### `dtype` {#max.graph.TensorValue.dtype}

> property dtype: [DType](../dtype.md#max.dtype.DType)

Returns the tensor data type.

The following example demonstrates how to access the data type of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a matrix with float32 values
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor data type
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32
```

### `flatten()` {#max.graph.TensorValue.flatten}

> flatten(start\_dim=0, end\_dim=-1)

Flattens the specified dims of a symbolic tensor. The number and order of the elements in the tensor is unchanged. All dimensions from `start_dim` to `end_dim` (inclusive) are merged into a single output dim.
The following example demonstrates how to flatten a multi-dimensional tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("flatten_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Flatten the tensor to a 1D array
    flattened_tensor = tensor.flatten()

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Flattened shape: {flattened_tensor.shape}")  # Output: [4]
```

**Parameters:**

* **start\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The starting dimension to flatten. Defaults to `0`.
* **end\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The ending dimension to flatten. Defaults to `-1`.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with the flattened dimensions.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `from_mlir()` {#max.graph.TensorValue.from_mlir}

> classmethod from\_mlir(value)

Creates a [`TensorValue`](#max.graph.TensorValue) from an MLIR tensor value.

**Parameters:**

**value** (`Value` `[` `TensorType` `]` ) – The MLIR tensor value to wrap.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `permute()` {#max.graph.TensorValue.permute}

> permute(dims)

Permutes the tensor’s dimensions based on provided indices.

**Parameters:**

**dims** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – A list of integers specifying the new order of dimensions.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with permuted dimensions.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `print()` {#max.graph.TensorValue.print}

> print(label='debug\_tensor')

Prints detailed information about the tensor.

**Parameters:**

**label** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – A string label for the printed output. Defaults to `debug_tensor`.

### `rank` {#max.graph.TensorValue.rank}

> property rank: [int](https://docs.python.org/3/library/functions.html#int)

Returns the rank (number of dims) of the tensor.

The following example demonstrates how to access the rank of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix (2-dimensional array)
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("rank_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor rank (number of dimensions)
    print(f"Rank: {tensor.rank}")  # Output: 2
```

### `rebind()` {#max.graph.TensorValue.rebind}

> rebind(shape, message='')

Rebinds the tensor to a new shape with error handling.

**Parameters:**

* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The new shape as an iterable of integers or symbolic dimensions.
* **message** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – (optional) A message for logging or debugging.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with the updated shape.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `reshape()` {#max.graph.TensorValue.reshape}

> reshape(shape)

Creates a new tensor with the same data but reshaped.

The following example demonstrates how to reshape a tensor to change its dimensions:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("reshape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Reshape tensor to a 1x4 matrix
    reshaped_tensor = tensor.reshape((1, 4))

    print(f"Original shape: {tensor.shape}")  # Output: [2, 2]
    print(f"Reshaped shape: {reshaped_tensor.shape}")  # Output: [1, 4]
```

**Parameters:**

**shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The new shape as an iterable of integers or symbolic dimensions.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with the reshaped dimensions.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `shape` {#max.graph.TensorValue.shape}

> property shape: [Shape](type.md#max.graph.type.Shape)

Returns the shape of the [`TensorValue`](#max.graph.TensorValue).

The following example demonstrates how to access the shape of a tensor:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor shape
    print(f"Shape: {tensor.shape}")  # Shape: [Dim(2), Dim(2)]
```

### `to()` {#max.graph.TensorValue.to}

> to(device)

Transfers the tensor to a specified device without mutation.

The following example demonstrates how to move a tensor from one device to another:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

with Graph("to_device_example") as graph:
    # Create a tensor on the default device
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Move the tensor to a GPU device
    gpu_tensor = tensor.to(DeviceRef.GPU())

    print(f"Original device: {tensor.device}")  # Output depends on default device
    print(f"New device: {gpu_tensor.device}")  # Output: gpu:0
```

**Parameters:**

**device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – A `DeviceRef` object specifying the target device.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) on the specified device.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `transpose()` {#max.graph.TensorValue.transpose}

> transpose(dim\_1, dim\_2)

Swaps two dimensions of the tensor.
The following example demonstrates how to transpose a tensor by swapping its dimensions:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x3 matrix
matrix = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)

with Graph("transpose_demo") as graph:
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Transpose the tensor (swap dimensions 0 and 1)
    transposed_tensor = tensor.transpose(dim_1=0, dim_2=1)

    print(f"Original shape: {tensor.shape}")  # Output: [2, 3]
    print(f"Transposed shape: {transposed_tensor.shape}")  # Output: [3, 2]
```

**Parameters:**

* **dim\_1** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The first dimension to swap.
* **dim\_2** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The second dimension to swap.

**Returns:** A new [`TensorValue`](#max.graph.TensorValue) with swapped dimensions.

**Return type:** [*TensorValue*](#max.graph.TensorValue)

### `type` {#max.graph.TensorValue.type}

> property type: [TensorType](type.md#max.graph.type.TensorType)

Returns the type of the [`TensorValue`](#max.graph.TensorValue) as a `TensorType`.

---

## Value

## `Value` {#max.graph.Value}

> class max.graph.Value

Represents a symbolic value within a Graph.

A Value can represent the output of a node, the arguments of a Graph (as seen from within its body), and more generally any symbolic value available within the Graph. Other nodes receive Value values as inputs to form a computation graph.

A Value may also refer to an existing input or output of a node, and you can change it, such as by swapping in a new Value.

Conceptually, think of a Value as an edge in the dataflow graph, with the other end being the user of that value.

The following example shows how to work with Values in a graph to create a simple computation:

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, Value, ops

with Graph("value_example") as graph:
    # Create input values
    a = ops.constant(np.array([1, 2, 3]), dtype=DType.float32, device=DeviceRef.CPU())
    b = ops.constant(np.array([4, 5, 6]), dtype=DType.float32, device=DeviceRef.CPU())

    # Use values to perform operations
    c = a + b  # c is a Value representing the addition

    # Demonstrate that the result is a Value
    print(f"Type of c: {type(c)}")
    print(f"Is c a Value? {isinstance(c, Value)}")
```

Similar to a regular variable, a Value has a data type.

Value is abstract; it shouldn’t be constructed directly.

### `buffer` {#max.graph.Value.buffer}

> property buffer: [BufferValue](BufferValue.md#max.graph.BufferValue)

Returns the Value as a [`BufferValue`](BufferValue.md#max.graph.BufferValue). Raises an exception if the Value is not a BufferValue.

### `from_mlir()` {#max.graph.Value.from_mlir}

> classmethod from\_mlir(value)

Creates a [`Value`](#max.graph.Value) from an MLIR value.

**Parameters:**

**value** (`Value` `[` `MlirType` `]` ) – The MLIR value to wrap.

**Return type:** [*Value*](#max.graph.Value)

### `opaque` {#max.graph.Value.opaque}

> property opaque: \_OpaqueValue

Returns the Value as an `_OpaqueValue`. Raises an exception if the Value is not a \_OpaqueValue.

### `tensor` {#max.graph.Value.tensor}

> property tensor: [TensorValue](TensorValue.md#max.graph.TensorValue)

Returns the Value as a [`TensorValue`](TensorValue.md#max.graph.TensorValue). Raises an exception if the Value is not a TensorValue.
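For example, graph inputs arrive as generic `Value` objects and can be narrowed before use; this small sketch mirrors the graph-building examples above:

```python
from max.dtype import DType
from max.graph import Graph, TensorType

with Graph("accessor_demo", input_types=[TensorType(DType.float32, (2,))]) as graph:
    # Graph inputs are generic Values; .tensor narrows to a TensorValue
    x = graph.inputs[0].tensor
    graph.output(x * 2.0)
```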
### `type` {#max.graph.Value.type}

> property type: [Type](type.md#max.graph.type.Type)\[MlirType]

Returns the type of the [`Value`](#max.graph.Value) as a `Type`.

---

## Weight

## `Weight` {#max.graph.Weight}

> class max.graph.Weight(\*args, \*\*kwargs)

Bases: [`TensorValue`](TensorValue.md#max.graph.TensorValue)

Represents a value in a Graph that can be loaded at a later time.

Weights can be initialized outside of a Graph and are lazily added to the parent graph when used. If there is no parent graph when a weight is used, an error will be raised.

Initializes a [`TensorValue`](TensorValue.md#max.graph.TensorValue) from a tensor-like value.

**Parameters:** **value** – The value to wrap. Can be an MLIR tensor value, another [`TensorValue`](TensorValue.md#max.graph.TensorValue), a `Dim`, or a `Shape`.

### `align` {#max.graph.Weight.align}

> align: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)

### `device` {#max.graph.Weight.device}

> property device: [DeviceRef](type.md#max.graph.type.DeviceRef)

Returns the device of the TensorValue.

### `dtype` {#max.graph.Weight.dtype}

> property dtype: [DType](../dtype.md#max.dtype.DType)

Returns the tensor data type.

The following example demonstrates how to access the data type of a tensor:

```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor data type
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32
```

### `original_dtype_and_shape` {#max.graph.Weight.original_dtype_and_shape}

> property original\_dtype\_and\_shape: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[DType](../dtype.md#max.dtype.DType), [Shape](type.md#max.graph.type.Shape)]

The original dtype and shape of this weight. Use this property to retrieve the original weight's dtype and shape when a quantization encoding forces the weight to be loaded as uint8.

### `quantization_encoding` {#max.graph.Weight.quantization_encoding}

> quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None)

### `shape` {#max.graph.Weight.shape}

> property shape: [Shape](type.md#max.graph.type.Shape)

Returns the shape of the [`TensorValue`](TensorValue.md#max.graph.TensorValue).

The following example demonstrates how to access the shape of a tensor:

```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor shape
    print(f"Shape: {tensor.shape}")  # Shape: [Dim(2), Dim(2)]
```

### `shard()` {#max.graph.Weight.shard}

> shard(shard\_idx, device)

Gets a specific shard from the Weight.

This Weight must have `sharding_strategy` defined. The shard object returned is also a Weight object, but cannot be sharded further.

**Parameters:**

* **shard\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index of the shard.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – The device on which to place the shard.

**Returns:** The sharded weight.

**Return type:** [*Weight*](#max.graph.Weight)

### `shard_idx` {#max.graph.Weight.shard_idx}

> shard\_idx: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None)

### `sharding_strategy` {#max.graph.Weight.sharding_strategy}

> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)

Gets the weight sharding strategy.

---

## graph

APIs to build inference graphs for MAX Engine with Python.

## Classes

* [`BufferValue`](/max/api/python/graph/BufferValue): Represents a mutable semantic tensor within a Graph.
* [`Graph`](/max/api/python/graph/Graph): Represents a graph for MAX Engine.
* [`KernelLibrary`](/max/api/python/graph/KernelLibrary): Represents a library with custom ops.
* [`TensorValue`](/max/api/python/graph/TensorValue): Represents a value semantic tensor within a Graph.
* [`Value`](/max/api/python/graph/Value): Represents a symbolic value within a Graph.
* [`Weight`](/max/api/python/graph/Weight): Represents a weight value in a graph.

## Modules

* [`ops`](/max/api/python/graph/ops): Ops you can add when staging a graph.
* [`quantization`](/max/api/python/graph/quantization): APIs to quantize graph tensors.
* [`type`](/max/api/python/graph/type): APIs for graph value types.
* [`weights`](/max/api/python/graph/weights): APIs for loading weights into a graph.

---

## ops

Implements operations used when staging a graph.

This module provides operations for building computational graphs in MAX. These operations create, transform, and manipulate tensor values within the graph.

You can also use functions in [Graph](/max/api/python/graph/Graph) to add constant values to your graph with operations like [constant()](/max/api/python/graph/ops#max.graph.ops.constant).

The [TensorValue](/max/api/python/graph/TensorValue/) type (returned by most operations) implements various dunder methods to support operations between TensorValues, such as `+` for addition, `*` for multiplication, and `@` for matrix multiplication. It also provides convenience methods like [reshape()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.reshape) and [flatten()](/max/api/python/graph/TensorValue/#max.graph.TensorValue.flatten).

### `InterpolationMode` {#max.graph.ops.InterpolationMode}

> class max.graph.ops.InterpolationMode(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

Interpolation modes for image resize operations.

This enum defines the available interpolation methods that can be used when resizing tensors. Currently only BICUBIC is implemented, with BILINEAR and NEAREST planned for future support.
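For illustration, here is a minimal sketch of selecting an interpolation mode (remember that, per the above, only `BICUBIC` currently has an implementation):

```python
from max.graph.ops import InterpolationMode

# Pick the only mode that is currently implemented.
mode = InterpolationMode.BICUBIC
print(mode.value)  # 'bicubic'
```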
#### `BICUBIC` {#max.graph.ops.InterpolationMode.BICUBIC}

> BICUBIC = 'bicubic'

#### `BILINEAR` {#max.graph.ops.InterpolationMode.BILINEAR}

> BILINEAR = 'bilinear'

#### `NEAREST` {#max.graph.ops.InterpolationMode.NEAREST}

> NEAREST = 'nearest'

### `abs()` {#max.graph.ops.abs}

> max.graph.ops.abs(x)

Computes the elementwise absolute value of a symbolic tensor.

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `add()` {#max.graph.ops.add}

> max.graph.ops.add(lhs, rhs)

Adds two symbolic tensors elementwise.

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `allgather()` {#max.graph.ops.allgather}

> max.graph.ops.allgather(inputs, axis=0)

Collective allgather operation.

This op is a collective op which takes in tensors from different devices and outputs tensors on different devices. In particular, this operation gathers the inputs across devices and concatenates them along the specified dimension. The result is then broadcast back to the same devices that the inputs came from.

**Parameters:**

* **inputs** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `]` ) – The input tensors to gather.
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Dimension to concatenate the input tensors. Defaults to 0.

**Returns:** An iterable of output tensors, each holding the gathered result. Each output tensor contains the concatenation of all inputs along the specified dimension.
**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue)]

### `argmax()` {#max.graph.ops.argmax}

> max.graph.ops.argmax(x, axis=-1)

Reduces a symbolic tensor using an argmax operation.

When provided with a tensor in which all elements are identical, this returns the index of the first element on CPU, and an arbitrary index on GPU.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor for the operation.
* **axis** – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

**Returns:** A symbolic tensor representing the result of the argmax operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `argmin()` {#max.graph.ops.argmin}

> max.graph.ops.argmin(x, axis=-1)

Reduces a symbolic tensor using an argmin operation.

When provided with a tensor in which all elements are identical, this returns the index of the first element on CPU, and an arbitrary index on GPU.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor for the operation.
* **axis** – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

**Returns:** A symbolic tensor representing the result of the argmin operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `argsort()` {#max.graph.ops.argsort}

> max.graph.ops.argsort(x, ascending=True)

Returns the indices that would sort a tensor.

This function returns the indices that would sort the input tensor along its first dimension. The returned indices are of type int64.

**Parameters:**

* **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – Input tensor to be sorted.
* **ascending** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If True (default), sort in ascending order. If False, sort in descending order.
**Returns:** A tensor of indices of the same shape as the input tensor.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `as_interleaved_complex()` {#max.graph.ops.as_interleaved_complex}

> max.graph.ops.as\_interleaved\_complex(x)

Reshapes the input symbolic tensor as complex from alternating (real, imag).

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor representing complex numbers as alternating pairs of (real, imag) real-valued numbers. Its last dimension must have an even size.

**Returns:** A symbolic tensor representing the complex-valued tensor, but with the values pulled out as complex numbers. The result has the same dimensions for all dimensions except the last dimension, which is halved, and then a final dimension of size 2 representing the complex value.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `atanh()` {#max.graph.ops.atanh}

> max.graph.ops.atanh(x)

Computes the elementwise inverse hyperbolic tangent of a symbolic tensor.

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `band_part()` {#max.graph.ops.band_part}

> max.graph.ops.band\_part(x, num\_lower=None, num\_upper=None, exclude=False)

Masks out everything except a diagonal band of an input matrix.

Copies a tensor setting everything outside the central diagonal band of the matrices to zero, where all but the last two axes are effectively batches, and the last two axes define sub matrices. Assumes the input has dimensions \[I, J, …, M, N], then the output tensor has the same shape as the input, and the values are given by

```python
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]
```

with the indicator function:

```python
in_band(m, n) = ((num_lower is None) or (m - n <= num_lower)) and
                ((num_upper is None) or (n - m <= num_upper))
```

**Parameters:**

* **num\_lower** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The number of diagonal bands to include below the central diagonal. If None, include the entire lower triangle.
* **num\_upper** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The number of diagonal bands to include above the central diagonal. If None, include the entire upper triangle.
* **exclude** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If true, invert the selection of elements to mask. Elements in the band are set to zero.
* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input to mask out.

**Returns:** A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input tensor rank is less than 2, or if num\_lower/num\_upper are out of bounds for statically known dimensions.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `broadcast_to()` {#max.graph.ops.broadcast_to}

> max.graph.ops.broadcast\_to(x, shape, out\_dims=None)

Broadcasts a symbolic tensor.

Broadcasts the input tensor to the specified shape. Dimensions in the input must be one or match the target dimension.

**Parameters:**

* **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The input symbolic tensor to broadcast. This tensor may not contain any dynamic dimensions.
* **shape** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The new shape as a list of dimensions. Dynamic dimensions are not allowed.
* **out\_dims** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None` ) – Output dims, used only when `shape` is a tensor value.

**Returns:** A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as `shape`.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If a tensor-valued `shape` is passed without `out_dims`.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `buffer_load()` {#max.graph.ops.buffer_load}

> max.graph.ops.buffer\_load(x)

Loads the input buffer into a tensor.

It loads the in-place mutable tensor to an immutable tensor graph value. This is semantically equivalent to a copy from the mutable tensor x to the immutable value-semantic tensor output.

**Parameters:** **x** ([`BufferValue`](BufferValue.md#max.graph.BufferValue) ) – The buffer to be loaded to a tensor.

**Returns:** A tensor graph value representing a copy of the buffer loaded.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `buffer_store()` {#max.graph.ops.buffer_store}

> max.graph.ops.buffer\_store(destination, source)

Stores the input tensor into the inout buffer.

It stores the immutable input tensor source in the mutable buffer destination.
This is semantically equivalent to a copy from the source tensor to the destination buffer.

**Parameters:**

* **destination** ([`BufferValue`](BufferValue.md#max.graph.BufferValue) ) – The buffer to store the tensor in.
* **source** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The tensor to be stored in the buffer.

**Return type:** None

### `buffer_store_slice()` {#max.graph.ops.buffer_store_slice}

> max.graph.ops.buffer\_store\_slice(destination, source, indices)

Stores the input tensor into a slice of the input buffer.

It stores the immutable input tensor source in the mutable tensor destination. This is semantically equivalent to a copy from the source tensor to a slice in the destination buffer at the index specified by indices.

**Parameters:**

* **destination** ([`BufferValue`](BufferValue.md#max.graph.BufferValue) ) – The buffer to store the tensor in.
* **source** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The tensor to be stored in the buffer.
* **indices** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`slice`](https://docs.python.org/3/library/functions.html#slice) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`slice`](https://docs.python.org/3/library/functions.html#slice) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `EllipsisType` `]` ) – The index in the buffer where the tensor should be stored.

**Return type:** None

### `call()` {#max.graph.ops.call}

> max.graph.ops.call(graph, \*args, prefix='')

Calls a graph with the provided arguments and returns its results.

This function invokes a previously defined graph, passing in the provided arguments and the current chain value, and returns the results. The body of the graph is ultimately inlined into the caller, so the chain value is only used for serialization if the subgraph's body contains an operation that makes use of it in the first place.

The current advantage of using subgraphs is that it offers a way to improve compile times for operations that are used repeatedly in a model. As a secondary benefit, it also makes the IR more readable by allowing control flow to be expressed in a more natural way.

**Parameters:**

* **graph** ([`Graph`](Graph.md#max.graph.Graph) ) – The graph to call.
* **\*args** ([`Value`](Value.md#max.graph.Value) ) – Arguments to pass to the called graph.
* **prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Prefix to add to the names of any weights in the subgraph.

**Returns:** Either a single Value or a list of Values representing the graph outputs (excluding the chain value, which is handled internally).

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Value*](Value.md#max.graph.Value)]

### `cast()` {#max.graph.ops.cast}

> max.graph.ops.cast(x, dtype)

Casts a symbolic tensor to a different data type.

**Parameters:**

* **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The input tensor to cast.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The target dtype to which the tensor is cast.

**Returns:** A new symbolic tensor with the same shape as the input and the specified dtype.
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `chunk()` {#max.graph.ops.chunk}

> max.graph.ops.chunk(x, chunks, axis=0)

Chunks the tensor into an exact number of chunks along the specified axis.

```python
# Stage a length-4 constant, then split it into two equal chunks along axis 0.
a = ops.constant(np.array([1, 2, 3, 4]), dtype=DType.float32, device=DeviceRef.CPU())
first, second = ops.chunk(a, 2, axis=0)  # each chunk has shape [2]
```

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The tensor to chunk.
* **chunks** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of chunks to split the tensor into. chunks must statically evenly divide x.shape\[axis].
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis to split the tensor along.

**Returns:** A list of chunk tensors.

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue)]

### `concat()` {#max.graph.ops.concat}

> max.graph.ops.concat(original\_vals, axis=0)

Concatenates a list of symbolic tensors along an axis.

**Parameters:**

* **original\_vals** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` `Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `]` ) – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension other than `axis`.
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape. For instance, `concat(vs, -1)` will concat along the last dimension.

**Returns:** A new symbolic tensor representing the concatenation result. It will have the same rank as each input tensor, and its dimensions will be the same as each input tensor's for each dimension other than axis, which will have size equal to the sum of all tensors' sizes for that dimension.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `cond()` {#max.graph.ops.cond}

> max.graph.ops.cond(pred, out\_types, then\_fn, else\_fn)

Conditionally execute one of two branches based on a boolean predicate.

Both branches must return the same number and types of values as specified in `out_types`. Buffer mutations in branches are tracked automatically through the chain mechanism.

Examples:
1. Basic conditional with return values:

   ```python
   def then_fn():
       return ops.constant(1, DType.int32, device=DeviceRef.CPU())

   def else_fn():
       return ops.constant(0, DType.int32, device=DeviceRef.CPU())

   result = ops.cond(
       pred,
       [TensorType(DType.int32, [], device=device)],
       then_fn,
       else_fn
   )
   ```

2. Conditional with buffer mutations:

   ```python
   def then_fn():
       ops.inplace_custom("increment", device=buffer.device, values=[buffer])

   def else_fn():
       ops.inplace_custom("decrement", device=buffer.device, values=[buffer])

   ops.cond(pred, None, then_fn, else_fn)
   ```

**Parameters:**

* **pred** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Boolean scalar tensor of type `DType.bool` determining branch execution.
* **out\_types** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Type`](type.md#max.graph.type.Type) `]` `|` `None` ) – Expected output types for both branches. Use [`None`](https://docs.python.org/3/library/constants.html#None) for branches that don't return values.
* **then\_fn** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) ) – Callable executed when `pred` is True. Must return values matching `out_types` if `out_types` is not [`None`](https://docs.python.org/3/library/constants.html#None).
* **else\_fn** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) ) – Callable executed when `pred` is False. Must return values matching `out_types` if `out_types` is not [`None`](https://docs.python.org/3/library/constants.html#None).

**Returns:** List of output values from the executed branch. Returns an empty list when `out_types` is [`None`](https://docs.python.org/3/library/constants.html#None).

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If branches return different numbers of results or result types don't match `out_types`.

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue)]

##### NOTE

Buffer operations in branches automatically update the global chain state to maintain mutation ordering constraints.

### `constant()` {#max.graph.ops.constant}

> max.graph.ops.constant(value, dtype, device)

Adds a node representing a constant operation. The value of this constant will have the type TensorType with the same shape as value. If value is a scalar type, it will create a TensorType with 0 dimensions.

The constant will be loaded with the specified dtype. If the constant does not fit within the specified dtype, an error is raised.

Warning: Loading the constant could result in precision loss. For example, loading 16777217 as a float32 will result in 16777216.0.
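The following example stages both a scalar (0-dimensional) and a matrix constant; the graph name and values are illustrative, and the usage mirrors the other examples in this document:

```python
import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("constant_demo") as graph:
    # A scalar value produces a tensor with 0 dimensions.
    scalar = ops.constant(3.14, dtype=DType.float32, device=DeviceRef.CPU())
    # An ndarray keeps its shape.
    matrix = ops.constant(np.array([[1, 2], [3, 4]]), dtype=DType.float32, device=DeviceRef.CPU())
    print(scalar.shape)  # []
    print(matrix.shape)  # [Dim(2), Dim(2)]
```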
**Parameters:**

* **value** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) ) – The constant's value.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The constant tensor's element type.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – The device the constant lives on.

**Returns:** A graph value containing the constant data as an attribute.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `conv2d()` {#max.graph.ops.conv2d}

> max.graph.ops.conv2d(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)

Computes the 2-D convolution product of the input with the given filter, bias, strides, dilations, paddings, and groups.

The op supports 2-D convolution, with the following layout assumptions:

* input x has NHWC layout, i.e., (batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e., (height, width, in\_channels / num\_groups, out\_channels)
* bias has shape (out\_channels,)

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after…) and represent padding 0's before and after the indicated *spatial* dimensions in input. In 2-D convolution, dim1 here represents H and dim2 represents W. In Python-like syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] would yield:

```python
input = [
  [1, 2, 3],
  [4, 5, 6]
]  ## Shape is 2x3

padded_input = [
  [0, 0, 1, 2, 3, 0],
  [0, 0, 4, 5, 6, 0],
  [0, 0, 0, 0, 0, 0]
]  ## Shape is 3x6
```

This op currently only supports strides and padding on the input.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – An NHWC input tensor to perform the convolution upon.
* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The convolution filter in RSCF layout: (height, width, in\_channels / num\_groups, out\_channels).
* **stride** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The stride of the convolution operation.
* **dilation** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The spacing between the kernel points.
* **padding** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The amount of padding applied to the input.
* **groups** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) – tensor of shape (out\_channels,).
* **input\_layout** (`ConvInputLayout` )
* **filter\_layout** (`FilterLayout` )

**Returns:** A symbolic tensor value with the convolution applied.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `conv2d_transpose()` {#max.graph.ops.conv2d_transpose}

> max.graph.ops.conv2d\_transpose(x, filter, stride=(1, 1), dilation=(1, 1), padding=(0, 0, 0, 0), output\_paddings=(0, 0), bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.RSCF)

Computes the 2-D deconvolution of the input with the given filter, strides, dilations, paddings, and groups.

The op supports the transpose (gradient) of convolution, with the following layout assumptions (note that out\_channels is w\.r.t. the original convolution):

* input x has NHWC layout, i.e., (batch\_size, height, width, in\_channels)
* filter has layout RSCF, i.e., (kernel\_height, kernel\_width, out\_channels, in\_channels)
* bias has shape (out\_channels,)

This op effectively computes the gradient of a convolution with respect to its input (as if the original convolution operation had the same filter and hyperparameters as this op).

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after) and represent padding 0's before and after the indicated *spatial* dimensions in input. In 2-D ConvTranspose, dim1 here represents H\_out and dim2 represents W\_out.
In Python-like syntax, padding a 2x4 spatial output with \[0, 1, 2, 1] would yield:

```python
output = [
  [1, 2, 3, 4],
  [5, 6, 7, 8]
]  ## Shape is 2x4

padded_output = [
  [3],
]  ## Shape is 1x1
```

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – An NHWC input tensor to perform the convolution upon.
* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The convolution filter in RSCF layout: (height, width, out\_channels, in\_channels).
* **stride** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The stride of the sliding window for each dimension of input. If a single value is given, it is replicated in the H and W dimensions. Strides apply only to the spatial (H and W) dimensions.
* **dilation** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The spacing between the kernel points.
* **padding** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The amount of padding applied to the input.
* **output\_paddings** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – This argument resolves the ambiguity between multiple potential output shapes when any stride is greater than 1: output\_paddings\[i] zeros are added at the end of the output's i-th axis. Only output\_paddings = (0, 0) is currently supported.
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) – tensor of shape (out\_channels,).
* **input\_layout** (`ConvInputLayout` )
* **filter\_layout** (`FilterLayout` )

**Returns:** A symbolic tensor value with the convolution applied.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `conv3d()` {#max.graph.ops.conv3d}

> max.graph.ops.conv3d(x, filter, stride=(1, 1, 1), dilation=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), groups=1, bias=None, input\_layout=ConvInputLayout.NHWC, filter\_layout=FilterLayout.QRSCF)

Computes the 3-D convolution product of the input with the given filter, strides, dilations, paddings, and groups.

The op supports 3-D convolution, with the following layout assumptions:

* input has NDHWC layout, i.e., (batch\_size, depth, height, width, in\_channels)
* filter has layout QRSCF, i.e., (depth, height, width, in\_channels / num\_groups, out\_channels)

The padding values are expected to take the form (pad\_dim1\_before, pad\_dim1\_after, pad\_dim2\_before, pad\_dim2\_after…) and represent padding 0's before and after the indicated *spatial* dimensions in input. In 3-D convolution, dim1 here represents D, dim2 represents H, and dim3 represents W. In Python-like syntax, padding a 2x3 spatial input with \[0, 1, 2, 1] would yield:

```python
input = [
  [1, 2, 3],
  [4, 5, 6]
]  ## Shape is 2x3

padded_input = [
  [0, 0, 1, 2, 3, 0],
  [0, 0, 4, 5, 6, 0],
  [0, 0, 0, 0, 0, 0]
]  ## Shape is 3x6
```

This op currently only supports strides and padding on the input.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – An NDHWC input tensor to perform the convolution upon.
* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The convolution filter in QRSCF layout: (depth, height, width, in\_channels / num\_groups, out\_channels).
* **stride** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The stride of the convolution operation.
* **dilation** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The spacing between the kernel points.
* **padding** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The amount of padding applied to the input.
* **groups** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – When greater than 1, divides the convolution into multiple parallel convolutions. The number of input and output channels must both be divisible by the number of groups.
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **input\_layout** (`ConvInputLayout` )
* **filter\_layout** (`FilterLayout` )

**Returns:** A symbolic tensor value with the convolution applied. Output shape = (batch\_size, depth, height, width, out\_channels).
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `cos()` {#max.graph.ops.cos}

> max.graph.ops.cos(x)

Computes the elementwise cosine of a symbolic tensor.

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `cumsum()` {#max.graph.ops.cumsum}

> max.graph.ops.cumsum(x, axis=-1, exclusive=False, reverse=False)

Computes the cumulative sum of the input tensor along the given axis.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor to sum over.
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis along which to compute the sum. If negative, indexes from the last dimension. For example, a value of -1 will compute the sum along the last dimension.
* **exclusive** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If set, start at 0 and exclude the final element. Otherwise, start with the first element. Said another way, by default cumsum computes `[sum(x[..., :i+1, ...]) for i in range(x.shape[axis])]`; if exclusive is set, it computes `[sum(x[..., :i, ...]) for i in range(x.shape[axis])]`.
* **reverse** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If set, start from the end. In other words, the first element will be the total sum, with each element following counting downwards; or `[sum(x[..., i:, ...]) for i in range(x.shape[axis])]`.

**Returns:** A symbolic tensor representing the result of the cumsum operation. The tensor will have the same type as the input tensor. The computed values will be the cumulative sum of the values along the given axis, according to the specified parameters:

* if exclusive is set, the first value will be 0, and the last value will be excluded from the sum
* if reverse is set, the sum will be computed starting at the back of the axis back to the front, rather than front-to-back

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `custom()` {#max.graph.ops.custom}

> max.graph.ops.custom(name, device, values, out\_types, parameters=None)

Creates a node to execute a custom graph operation in the graph.

The custom op should be registered by annotating a function with the [@compiler.register](/max/api/mojo-decorators/compiler-register/) decorator.

**Parameters:**

* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The op name provided to `@compiler.register`.
* **values** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Value`](Value.md#max.graph.Value) `]` ) – The op function's arguments.
* **out\_types** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Type`](type.md#max.graph.type.Type) `]` ) – The list of the op function's return types.
* **parameters** ([`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`bool`](https://docs.python.org/3/library/functions.html#bool) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`DType`](../dtype.md#max.dtype.DType) `]` `|` `None` ) – Dictionary of extra parameters expected by the kernel.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – Device that the op is assigned to. This becomes a target parameter to the kernel.

**Returns:** Symbolic values representing the outputs of the op in the graph. These correspond 1:1 with the types passed as `out_types`.

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Value*](Value.md#max.graph.Value)]

### `dequantize()` {#max.graph.ops.dequantize}

> max.graph.ops.dequantize(encoding, quantized)

Dequantizes a quantized tensor to floating point.

NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.

**Parameters:**

* **encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) ) – The quantization encoding to use.
* **quantized** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The quantized tensor to dequantize.

**Returns:** The dequantized result (a floating point tensor).
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `div()` {#max.graph.ops.div}

> max.graph.ops.div(lhs, rhs)

Divides two symbolic tensors elementwise.

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `equal()` {#max.graph.ops.equal}

> max.graph.ops.equal(lhs, rhs)

Computes an elementwise equality comparison of two symbolic tensors.

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `erf()` {#max.graph.ops.erf}

> max.graph.ops.erf(x)

Computes the elementwise error function of a symbolic tensor.

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `exp()` {#max.graph.ops.exp}
> max.graph.ops.exp(x)

Computes the elementwise exponential of a symbolic tensor.

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `flatten()` {#max.graph.ops.flatten}

> max.graph.ops.flatten(x, start\_dim=0, end\_dim=-1)

Flattens the specified dims of a symbolic tensor.

The number and order of the elements in the tensor are unchanged. All dimensions from start\_dim to end\_dim (inclusive) are merged into a single output dim.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **start\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **end\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `fold()` {#max.graph.ops.fold}

> max.graph.ops.fold(input, output\_size, kernel\_size, stride=1, dilation=1, padding=0)

Combines an array of sliding blocks into a larger containing tensor.

The input tensor must have shape `(N, C * kernel_sizes, L)` where `N` is the batch dimension, `C` is the number of channels, `kernel_sizes` is the product of the kernel sizes, and `L` is the number of local blocks. The resulting output tensor will have shape `(N, C, output_shape[0], output_shape[1])`.

`L`, the number of blocks, must be equivalent to: `prod((output_size[d] + 2 * padding[d] - dilation[d] * (kernel_size[d] - 1) - 1) / stride[d] + 1)` where `d` is over all spatial dimensions.
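As a sanity check on the formula above, the following sketch computes `L` directly for assumed example sizes (a 4x4 output, a 2x2 kernel, stride 1, dilation 1, no padding; these values are illustrative):

```python
# Assumed example sizes, for illustration only.
output_size = (4, 4)
kernel_size = (2, 2)
stride = (1, 1)
dilation = (1, 1)
padding = (0, 0)

L = 1
for d in range(2):  # product over both spatial dimensions
    L *= (output_size[d] + 2 * padding[d]
          - dilation[d] * (kernel_size[d] - 1) - 1) // stride[d] + 1
print(L)  # 9, so the input to fold must have shape (N, C * 4, 9)
```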
**Parameters:**

* **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The 3D tensor to fold with shape `(N, C * kernel_sizes, L)`.
* **output\_size** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – Spatial dimensions of the output tensor. Must be a tuple of two ints.
* **kernel\_size** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The size of the sliding blocks. Must be a tuple of two ints.
* **stride** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The stride of the sliding blocks in the input dimension (can be an int or a tuple of two ints).
* **dilation** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The spacing between the kernel elements (can be an int or a tuple of two ints).
* **padding** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Zero-padding added on both sides of the inputs (can be an int or a tuple of two ints).

**Returns:** The folded 4D tensor with shape `(N, C, output_size[0], output_size[1])`.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `gather()` {#max.graph.ops.gather}

> max.graph.ops.gather(input, indices, axis=-1)

Selects elements out of an input tensor by index.
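For example, the following sketch (graph name and shapes are illustrative) gathers two rows from a `3x4` table along axis 0; the output shape is `indices.shape` followed by the remaining input dimensions, here `[2, 4]`.

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

table_type = TensorType(DType.float32, [3, 4], device=DeviceRef.CPU())
indices_type = TensorType(DType.int32, [2], device=DeviceRef.CPU())

with Graph("gather_example", input_types=[table_type, indices_type]) as graph:
    table, indices = graph.inputs
    # Select whole rows of the table by index.
    rows = ops.gather(table, indices, axis=0)
    print(rows.type)
    # Expected: TensorType(dtype=DType.float32, shape=[2, 4])
    graph.output(rows)
```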
**Parameters:**

* **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to select elements from.
* **indices** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor of index values to use for selection.
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimension which `indices` indexes from `input`. If negative, indexes relative to the end of the input tensor. For instance, `gather(input, indices, axis=-1)` will index against the last dimension of `input`.

**Returns:** A new symbolic tensor representing the result of the gather operation.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `gather_nd()` {#max.graph.ops.gather_nd}

> max.graph.ops.gather\_nd(input, indices, batch\_dims=0)

Selects elements out of an input tensor by N-dimensional index.

This operation performs N-dimensional indexing into `input` using `indices`. Unlike [`gather()`](#max.graph.ops.gather), which indexes along a single axis, `gather_nd()` allows indexing along multiple dimensions simultaneously.

```python
from max.dtype import DType
from max.graph import Graph, TensorType, ops

input_shape = ["a", "b", "c", "d", "e"]
indices_shape = ["a", "f", 3]
input_type = TensorType(DType.bfloat16, input_shape)
indices_type = TensorType(DType.int32, indices_shape)

with Graph("gather_nd", input_types=[input_type, indices_type]) as graph:
    input, indices = graph.inputs
    gathered = ops.gather_nd(input, indices, batch_dims=1)
    print(gathered.type)
    ## Output: TensorType(dtype=DType.bfloat16, shape=["a", "f", "e"])
```

In this example:

* `batch_dims` is 1, so there’s 1 shared dimension at the beginning.
* `indices` has an additional dimension “f” which becomes part of the output.
* The last dimension of `indices` is the index vector; values in this vector are interpreted to be indices into “b”, “c”, and “d”.
* Since `batch_dims` (1) + index size (3) is less than the input rank (5), the trailing dimension “e” is carried through as a slice dimension.

**Parameters:**

* **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to select elements from.
* **indices** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor of index values to use for selection. The last dimension of this tensor must be static. This dimension will be used to index or slice into `input` immediately following `batch_dims` initial dimensions. The size of this index dimension is the number of dimensions it specifies.
* **batch\_dims** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of leading batch dimensions shared by `input` and `indices`; 0 by default. `input` and `indices` must exactly match up to their first `batch_dims` dimensions. This function does not broadcast.

**Returns:** A new symbolic tensor representing the result of the gather operation. The output will have the same dtype as `input`, and will have shape depending on the inputs, in this order:

* `input.shape[:batch_dims]` – The “broadcast” dimensions (though note that this function does not broadcast). These dimensions must be identical between `input` and `indices`.
* `indices.shape[batch_dims:-1]` – The “gather” dimensions; this allows multi-dimensional tensors of indices. The last dimension is the index vector.
* `input.shape[batch_dims + indices.shape[-1]:]` – The “slice” dimensions. If `batch_dims + indices.shape[-1]` equals the rank of `input`, this is empty and individual elements are gathered rather than slices.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `gelu()` {#max.graph.ops.gelu}

> max.graph.ops.gelu(x, approximate='none')

Computes the elementwise gelu of a symbolic tensor.

Creates a new op node to compute the elementwise gelu of a symbolic tensor and adds it to the graph, returning the symbolic result.

For `approximate == "none"`, the exact gelu function is computed.

For `approximate == "tanh"`, the approximation:

$$ \text{gelu}(x) = 0.5 \cdot x \cdot \left(1.0 + \tanh\left(0.7978845608028654 \cdot \left(x + 0.044715 \cdot x^{3}\right)\right)\right) $$

is used.

For `approximate == "quick"`, the approximation:

$$ \text{gelu}(x) = \text{sigmoid}(1.702 \cdot x) \cdot x $$

is used.

**Parameters:**

* **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The symbolic tensor to use as the input to the gelu computation.
* **approximate** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The approximation method: `"none"` (the default), `"tanh"`, or `"quick"`.

**Returns:** A new symbolic tensor value representing the output of the gelu computation.
**Raises:**

* **Error** – If the symbol doesn’t represent a tensor value.
* [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the approximation method is invalid.

### `greater()` {#max.graph.ops.greater}

> max.graph.ops.greater(lhs, rhs)

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `greater_equal()` {#max.graph.ops.greater_equal}

> max.graph.ops.greater\_equal(lhs, rhs)

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `hann_window()` {#max.graph.ops.hann_window}

> max.graph.ops.hann\_window(window\_length, device, periodic=True, dtype=float32)

Calculates a Hann window for a given length.

Hann window function:

$$ H[n] = \frac{1}{2} \left[ 1 - \cos\left( \frac{2 \pi n}{N - 1} \right) \right] $$

where N is window\_length.

**Parameters:**

* **window\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The length of the window.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – The device to run the operation on.
* **periodic** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If True, trims off the last duplicate value from the symmetric window, making the result ready to be used as a periodic window with functions like stft(). Satisfies `hann_window(L, periodic=True) == hann_window(L + 1, periodic=False)[:-1]`.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The desired data type of the output tensor.

**Returns:** A 1-D tensor of size (window\_length,) containing the window.

**Raises:**

* [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If window\_length is negative.
* [**TypeError**](https://docs.python.org/3/library/exceptions.html#TypeError) – If window\_length is not an integer.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `inplace_custom()` {#max.graph.ops.inplace_custom}

> max.graph.ops.inplace\_custom(name, device, values, out\_types=None, parameters=None)

Creates a node to execute an in-place custom graph operation in the graph.

The custom op should be registered by annotating a function with the [@compiler.register](/max/api/mojo-decorators/compiler-register/) decorator.

**Parameters:**

* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The op name provided to `@compiler.register`.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – Device that the op is assigned to. This becomes a target parameter to the kernel.
* **values** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Value`](Value.md#max.graph.Value) `]` ) – The op function’s arguments.
* **parameters** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`bool`](https://docs.python.org/3/library/functions.html#bool) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`DType`](../dtype.md#max.dtype.DType) `]` `|` `None` ) – Dictionary of extra parameters expected by the kernel.
* **out\_types** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Type`](type.md#max.graph.type.Type) `]` `|` `None` )

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Value*](Value.md#max.graph.Value)]

### `irfft()` {#max.graph.ops.irfft}

> max.graph.ops.irfft(input\_tensor, n=None, axis=-1, normalization=Normalization.BACKWARD, input\_is\_complex=False)

Compute the inverse real FFT of the input tensor.

**Parameters:**

* **input\_tensor** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The input tensor to compute the inverse real FFT of.
* **n** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The size of the output tensor. Must be an int, and cannot be a symbolic Tensor. The input tensor will be padded or truncated to n // 2 + 1 along the specified axis.
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis to compute the inverse real FFT of.
* **normalization** (`Normalization` `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The normalization to apply to the output tensor. Can be “backward”, “ortho”, or “forward”. When “backward”, the output is divided by n. When “ortho”, the output is divided by sqrt(n). When “forward”, no normalization is applied.
* **input\_is\_complex** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether the input tensor is already interleaved complex. The last dimension of the input tensor must be 2, and is excluded from the dimension referred to by axis. **Returns:** The inverse real FFT of the input tensor. The shape of the output tensor is the same as the shape of the input tensor, except for the axis that the inverse real FFT is computed over, which is replaced by n. ### `is_inf()` {#max.graph.ops.is_inf} > max.graph.ops.is\_inf(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `is_nan()` {#max.graph.ops.is_nan} > max.graph.ops.is\_nan(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `layer_norm()` {#max.graph.ops.layer_norm} > max.graph.ops.layer\_norm(input, gamma, beta, epsilon) Performs layer normalization. **Parameters:** * **input** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The input tensor to normalize. * **gamma** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The gamma parameter of the normalization. * **beta** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The beta parameter of the normalization. 
* **epsilon** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – The epsilon parameter of the normalization. **Returns:** A graph tensor value with the normalization applied. **Raises:** * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If gamma size doesn’t match the last dimension of input. * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If beta size doesn’t match the last dimension of input. * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If epsilon is not positive. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `log()` {#max.graph.ops.log} > max.graph.ops.log(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `log1p()` {#max.graph.ops.log1p} > max.graph.ops.log1p(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `logical_and()` {#max.graph.ops.logical_and} > max.graph.ops.logical\_and(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `logical_not()` 
{#max.graph.ops.logical_not} > max.graph.ops.logical\_not(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `logical_or()` {#max.graph.ops.logical_or} > max.graph.ops.logical\_or(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `logical_xor()` {#max.graph.ops.logical_xor} > max.graph.ops.logical\_xor(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `logsoftmax()` {#max.graph.ops.logsoftmax} > 
max.graph.ops.logsoftmax(x)

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `masked_scatter()` {#max.graph.ops.masked_scatter}

> max.graph.ops.masked\_scatter(input, mask, updates, out\_dim)

Creates a new symbolic tensor in which the `updates` values are written to `input` at the positions where `mask` is true.

**Parameters:**

* **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to write elements to.
* **mask** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor of boolean values indicating which elements of `input` to update.
* **updates** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor of elements to write to input.
* **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) ) – The new data-dependent dimension.

**Returns:** A new symbolic tensor representing the result of the masked\_scatter operation.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `matmul()` {#max.graph.ops.matmul}

> max.graph.ops.matmul(lhs, rhs)

Computes the matrix multiplication of two tensor graph values.
Performs general matrix multiplication with broadcasting.

If the lhs is 1D, it will be reshaped to `1xD`. If the rhs is 1D, it will be reshaped to `Dx1`. In both cases, the additional 1 dimensions will be removed from the output shape.

For the multiplication, the innermost (rightmost) 2 dimensions are treated as a matrix. The lhs matrix will have the shape `MxK`. The rhs matrix will have the shape `KxN`. The output will have the shape `MxN`. The `K` dimensions must be equivalent in both matrices. The remaining outer dimensions will be broadcast.

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The left-hand-side of the matmul.
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The right-hand-side of the matmul.
* **location** – An optional location for a more specific error message.

**Returns:** A tensor graph value representing the result of broadcasting the two matrices together and then performing a matrix multiply along the innermost two dimensions of each tensor.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `max()` {#max.graph.ops.max}

> max.graph.ops.max(x, y=None, /, axis=None)

Overload for ops.elementwise.max and ops.reduction.max.

* If two tensors are provided, axis is ignored and the elementwise maximum is returned.
* If one tensor is provided, compute ops.reduction.max on the tensor and axis.
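A brief sketch of both overloads (the graph name and shapes are illustrative assumptions):

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

t = TensorType(DType.float32, [2, 3], device=DeviceRef.CPU())

with Graph("max_example", input_types=[t, t]) as graph:
    a, b = graph.inputs
    elementwise = ops.max(a, b)    # Elementwise maximum; shape [2, 3].
    reduced = ops.max(a, axis=-1)  # Reduction over the last axis; shape [2, 1].
    graph.output(elementwise, reduced)
```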
**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **y** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `mean()` {#max.graph.ops.mean}

> max.graph.ops.mean(x, axis=-1)

Reduces a symbolic tensor using a mean operation.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor for the operation.
* **axis** – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension.

**Returns:** A symbolic tensor representing the result of the mean operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `min()` {#max.graph.ops.min}

> max.graph.ops.min(x, y=None, /, axis=None)

Overload for ops.elementwise.min and ops.reduction.min.

* If two tensors are provided, axis is ignored and the elementwise minimum is returned.
* If one tensor is provided, compute ops.reduction.min on the tensor and axis.
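The reduction overload mirrors `max()` above; a minimal sketch (shapes are illustrative):

```python
from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops

t = TensorType(DType.float32, ["batch", 10], device=DeviceRef.CPU())

with Graph("min_example", input_types=[t]) as graph:
    (x,) = graph.inputs
    # Minimum over the last axis; the reduced axis is kept with size 1.
    smallest = ops.min(x, axis=-1)
    print(smallest.type)
    # Expected: TensorType(dtype=DType.float32, shape=["batch", 1])
    graph.output(smallest)
```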
**Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **y** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `mod()` {#max.graph.ops.mod} > max.graph.ops.mod(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `mul()` {#max.graph.ops.mul} > max.graph.ops.mul(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` 
[`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `negate()` {#max.graph.ops.negate}

> max.graph.ops.negate(x)

**Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `nonzero()` {#max.graph.ops.nonzero}

> max.graph.ops.nonzero(x, out\_dim)

Returns the indices of all nonzero elements in a tensor.

Returns a tensor of indices of the nonzero values in the given tensor. The return value is a 2D tensor of shape `[out_dim x rank_in]`, where out\_dim is the number of nonzero elements in the input tensor, and rank\_in is the rank of the input tensor. Indices are generated in row-major order.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor.
* **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) ) – The newly generated dimension that is sized for the number of nonzero elements.
**Returns:** A symbolic tensor of indices.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `not_equal()` {#max.graph.ops.not_equal}

> max.graph.ops.not\_equal(lhs, rhs)

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `outer()` {#max.graph.ops.outer}

> max.graph.ops.outer(lhs, rhs)

Computes the outer product of two symbolic vectors.

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The left side of the product. Whatever its shape, it will be flattened to a rank-1 vector.
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The right side of the product. Whatever its shape, it will be flattened to a rank-1 vector. Must have the same number of elements as lhs.

**Returns:** A symbolic tensor representing the [outer product](https://en.wikipedia.org/wiki/Outer_product) of the two input vectors. It will have rank 2, with the dimension sizes being the number of elements of lhs and rhs respectively.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `pad()` {#max.graph.ops.pad}

> max.graph.ops.pad(input, paddings, mode='constant', value=0)

Pads a tensor with constant values.
Adds padding to the input tensor using the specified padding values. Currently only constant padding mode is supported.

**Parameters:**

* **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor to pad.
* **paddings** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Sequence of padding values. The padding values are applied symmetrically to each dimension. For a tensor with rank N, paddings should contain 2\*N values: [pad\_before\_dim0, pad\_after\_dim0, pad\_before\_dim1, pad\_after\_dim1, …].
* **mode** ([`Literal`](https://docs.python.org/3/library/typing.html#typing.Literal) `[` `'constant'` `]` ) – The padding mode. Currently only “constant” is supported.
* **value** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The constant value to use for padding.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `permute()` {#max.graph.ops.permute}

> max.graph.ops.permute(x, dims)

Permutes all dimensions of a symbolic tensor.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to transpose.
* **dims** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The desired ordering of the dimensions in the output tensor.

**Returns:** A new symbolic tensor with the dimensions permuted to match the passed in order. It has the same elements and dtype, but the order of the elements is different according to the permutation.
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `pow()` {#max.graph.ops.pow}

> max.graph.ops.pow(lhs, rhs)

**Parameters:**

* **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `print()` {#max.graph.ops.print}

> max.graph.ops.print(value, label='debug\_tensor')

Prints the value of a tensor or a string during graph execution.

This function outputs the current value of a tensor and is primarily used for debugging within the context of the Max Engine and its graph execution framework. It is particularly useful for verifying that the intermediate results of your computations are as expected. By printing tensor values, you can visualize the data flowing through the graph, which helps in understanding how the operations transform the data.

The `label` argument tags the printed output, making it easier to identify which tensor’s value is being printed, especially when there are multiple print statements in a complex graph.

```python
from typing import Any

import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, TensorType, ops


def add_tensors(a: np.ndarray, b: np.ndarray) -> dict[str, Any]:
    input_type = TensorType(dtype=DType.float32, shape=(1,), device=DeviceRef.CPU())
    with Graph(
        "simple_add_graph", input_types=(input_type, input_type)
    ) as graph:
        lhs, rhs = graph.inputs
        out = ops.add(lhs, rhs)
        ops.print(out, label="addition_output")  # Pass the output tensor here
        graph.output(out)
        print("final graph:", graph)
```

**Parameters:**

* **value** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The value to print. Can be either a string or a TensorValue.
* **label** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – A label to identify the printed value. Defaults to `debug_tensor`.

### `qmatmul()` {#max.graph.ops.qmatmul}

> max.graph.ops.qmatmul(encoding, config, lhs, \*rhs)

Performs matrix multiplication between floating point and quantized tensors.

This quantizes the `lhs` floating point value to match the encoding of the `rhs` quantized value, performs matmul, and then dequantizes the result. Beware that, compared to a regular matmul op, this one expects the `rhs` value to be transposed. For example, if the `lhs` shape is `[32, 64]`, and the quantized `rhs` shape is also `[32, 64]`, then the output shape is `[32, 32]`.
That is, this function returns the result from:

> dequantize(quantize(lhs) @ transpose(rhs))

The last two dimensions in `lhs` are treated as matrices and multiplied by `rhs` (which must be a 2D tensor). Any remaining dimensions in `lhs` are broadcast dimensions.

NOTE: Currently this supports Q4\_0, Q4\_K, and Q6\_K encodings only.

**Parameters:**

* **encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) ) – The quantization encoding to use.
* **lhs** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The non-quantized, left-hand-side of the matmul.
* **\*rhs** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The transposed and quantized right-hand-side of the matmul and auxiliary tensor (if present). Must be rank 2 and in a supported [quantization encoding](/max/api/mojo/graph/quantization/).
* **config** ([`QuantizationConfig`](quantization.md#max.graph.quantization.QuantizationConfig) `|` `None` )

**Returns:** The dequantized result (a floating point tensor).

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `range()` {#max.graph.ops.range}

> max.graph.ops.range(start, stop, step, out\_dim=None, device=cpu:0, dtype=float32)

Creates a sequence of numbers. The sequence goes from start with increments of size step up to (but not including) stop. `start`, `stop`, and `step` are mandatory and must have the same element type.

Note the following restrictions on input values:

1. step must be non-zero
2. stop - start must be zero or have the same sign as step

**Parameters:**

* **start** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The start of the range to generate.
* **stop** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The range will be generated up to, but not including, this value.
* **step** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The step size for the range.
* **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` `None` ) – The expected output dimensions returned by the range op. These will be asserted to be correct at graph execution time.
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – Device of the result tensor.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) )

**Returns:** A symbolic tensor value containing the defined range of values.

**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

### `rebind()` {#max.graph.ops.rebind}

> max.graph.ops.rebind(x, shape, message='', layout=None)

Rebinds a symbolic tensor to a specified set of dimensions.

This does not mutate the symbolic tensor passed in, but instead adds a runtime assert that the input symbolic shape is equivalent to the `shape` argument. For example, if the input tensor shape has dynamic/unknown sizes, this will assert fixed sizes that may be required for a subsequent operation.

**Parameters:**

* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to rebind.
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The symbolic shape to assert for `x`, as a list of [`Dim`](/max/api/python/graph/type/Dim) values.
* **message** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The message printed if the rebind fails at runtime.
* **layout** (`FilterLayout` `|` `None` ) – A layout of the weights used by some operations like conv.

**Returns:** A symbolic tensor with the same elements and shape as the given tensor, but with the symbolic shape asserted to `shape`.
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `relu()` {#max.graph.ops.relu} > max.graph.ops.relu(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `repeat_interleave()` {#max.graph.ops.repeat_interleave} > max.graph.ops.repeat\_interleave(x, repeats, axis=None, out\_dim=None) Repeats elements of a tensor along the given dimension. Modeled after `torch.repeat_interleave`. For example, given `repeats=2` and the following input:

```python
## Input tensor with shape (2, 2)
input = TensorValue(x)  # Contains [[1.0, 2.0], [3.0, 4.0]]
```

`repeat_interleave` with `axis=0`:

```python
## Output tensor with shape (4, 2)
output = repeat_interleave(input, repeats=2, axis=0)
## Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]
```

`repeat_interleave` with `axis=1`:

```python
## Output tensor with shape (2, 4)
output = repeat_interleave(input, repeats=2, axis=1)
## Contains [[1.0, 1.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0]]
```

`repeat_interleave` with `axis=None` (the default):

```python
## Output tensor with shape (8,)
output = repeat_interleave(input, repeats=2)  # axis = None
## Contains [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
```

`repeat_interleave` with `repeats=[2, 3]` and `axis=0`:

```python
repeat_value = TensorValue([2, 3])

## Output tensor with shape (5, 2)
output = repeat_interleave(input, repeats=repeat_value, axis=0)
## Contains [[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
```

**Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor. * **repeats** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The number of repetitions for each element. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The dimension along which to repeat values. If axis is not specified or None (the default), flatten the input array and repeat the flattened values. * **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` `None` ) **Returns:** A symbolic tensor with the elements interleaved.
**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If `repeats` is non-positive or if `axis` is out of range. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `reshape()` {#max.graph.ops.reshape} > max.graph.ops.reshape(x, shape) Reshapes a symbolic tensor. The number and order of the elements in the tensor is unchanged. In other words, if you were to iterate over elements in the tensor by major dimension to minor dimension, the iteration order would stay the same. If a value of -1 is present in the shape, that dimension becomes an automatically calculated dimension collecting all unspecified dimensions. Its length is the number of elements in the original tensor divided by the product of the other dimensions in the new shape. **Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to reshape. This tensor may not contain any dynamic dimensions. * **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – The new shape as a list of dimensions. Dynamic dimensions are not allowed. A single dimension may be -1. **Returns:** A symbolic tensor with the same elements as the original tensor, but in a new shape. Its symbolic shape is the same as `shape`. **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – if the input and target shapes have different numbers of elements. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `resize()` {#max.graph.ops.resize} > max.graph.ops.resize(input, shape, interpolation=InterpolationMode.BILINEAR) Resize the input tensor to the given shape. This function resizes a tensor using the specified interpolation method. The tensor is expected to have NCHW format (batch, channels, height, width). **Parameters:** * **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor to resize. Must have rank 4 in NCHW format.
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – Desired output shape of length 4 corresponding to (N, C, H, W). * **interpolation** ([`InterpolationMode`](#max.graph.ops.InterpolationMode) ) – Desired interpolation enum defined by InterpolationMode. Defaults to InterpolationMode.BILINEAR; note that currently only BICUBIC is implemented, and other modes raise NotImplementedError. **Returns:** A resized tensor with the shape specified by the shape argument. **Raises:** * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the input doesn’t have rank 4, the shape has the wrong number of elements, or an unsupported interpolation mode is specified. * [**NotImplementedError**](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If a single integer size or a non-BICUBIC interpolation mode is specified. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `round()` {#max.graph.ops.round} > max.graph.ops.round(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `rsqrt()` {#max.graph.ops.rsqrt} > max.graph.ops.rsqrt(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `scatter()` {#max.graph.ops.scatter} > max.graph.ops.scatter(input, updates, indices, axis=-1) Creates a new symbolic tensor where the updates are written to input according to indices. **Parameters:** * **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to write elements to.
* **updates** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – A symbolic tensor of elements to write to input. * **indices** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The positions in input to update. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis along which indices indexes into. **Returns:** A new symbolic tensor representing the result of the scatter operation. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `shape_to_tensor()` {#max.graph.ops.shape_to_tensor} > max.graph.ops.shape\_to\_tensor(shape) Converts a shape to a tensor. This is useful for using a shape attribute in an op that expects a tensor value. **Parameters:** **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – the shape attribute of a tensor value. **Returns:** The TensorValue containing the same value as shape. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue)

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, ops

x = ops.constant(np.zeros((1,)), DType.int64, device=DeviceRef.CPU())
result = ops.stack([
    x,
    ops.shape_to_tensor(x.shape),
])
## TensorValue(dtype=int64, shape=[StaticDim(dim=2), StaticDim(dim=1)])
```

### `sigmoid()` {#max.graph.ops.sigmoid} > max.graph.ops.sigmoid(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `silu()` {#max.graph.ops.silu} > max.graph.ops.silu(x) Computes the elementwise silu of a symbolic tensor.
Creates a new op node to compute the elementwise silu of a symbolic tensor and adds it to the graph, returning the symbolic result. `silu` is defined as `silu(x) = x * sigmoid(x)`. **Parameters:** **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The symbolic tensor to use as the input to the silu computation. **Returns:** A new symbolic tensor value representing the output of the silu computation. **Raises:** **Error** – If the symbol doesn’t represent a tensor value. ### `sin()` {#max.graph.ops.sin} > max.graph.ops.sin(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `slice_tensor()` {#max.graph.ops.slice_tensor} > max.graph.ops.slice\_tensor(x, indices) Slices out a subtensor view of the input tensor based on indices. The semantics of [`slice_tensor()`](#max.graph.ops.slice_tensor) follow NumPy slicing semantics with the following restrictions: * Slice indices must not index out of `[-dim - 1, dim - 1]` for negative step, or `[-dim, dim]` for positive step.

```python
## Reverse a tensor.
slice_tensor(x, [slice(None, None, -1)])

## Unsqueeze the second last dimension of a tensor.
slice_tensor(x, [..., None, slice(None)])
```

**Returns:** The sliced subtensor of x. **Parameters:** * **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) * **indices** (`SliceIndices` ) **Return type:** [TensorValue](TensorValue.md#max.graph.TensorValue) ### `softmax()` {#max.graph.ops.softmax} > max.graph.ops.softmax(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `split()` {#max.graph.ops.split} > max.graph.ops.split(x, split\_sizes, axis=0) Splits the input tensor into multiple tensors along a given dimension.
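For example, here is a minimal sketch of splitting a constant along its first axis (the graph name and values are illustrative):

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("split_example") as g:
    # A (3, 2) constant to split into row groups.
    x = ops.constant(
        np.arange(6, dtype=np.float32).reshape(3, 2),
        DType.float32,
        device=DeviceRef.CPU(),
    )
    # Sizes [1, 2] sum to 3, the size of axis 0.
    top, bottom = ops.split(x, split_sizes=[1, 2], axis=0)
    # top has shape (1, 2); bottom has shape (2, 2).
```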
**Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to split. * **split\_sizes** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) – Sizes of each output tensor. Must sum to the size of the split dimension `axis`. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimension along which to split the input tensor. **Returns:** A list of tensors with the same length as split\_sizes, where each tensor has the same shape as the input except along the split dimension axis, where the size is given by the corresponding element in split\_sizes. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue)] ### `sqrt()` {#max.graph.ops.sqrt} > max.graph.ops.sqrt(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `squeeze()` {#max.graph.ops.squeeze} > max.graph.ops.squeeze(x, axis) Removes a size-1 dimension from a symbolic tensor. **Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to squeeze. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimension to remove from the input’s shape. If negative, this indexes from the end of the tensor. For example, `squeeze(v, -1)` squeezes the last dimension. **Returns:** A symbolic tensor with the same number of elements as the input tensor, and whose rank is 1 less than the rank of the input tensor.
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `stack()` {#max.graph.ops.stack} > max.graph.ops.stack(values, axis=0) Stacks a list of tensors along a new axis. **Parameters:** * **values** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` `Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `]` ) – A list of symbolic tensor values. Each tensor must have the same dtype and rank, and must have the same dimension size for each dimension. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis to concatenate along. If negative, indexes relative to the end of the tensor shape *plus 1*. For instance, `stack(vs, -1)` will create and stack along a new axis as the last dimension, and `stack(vs, -2)` will create and stack along a new dimension which is inserted immediately before the last dimension. **Returns:** A new symbolic tensor representing the result of the stack. It will have rank `n+1` where `n` is the rank of each input tensor. Its size on each dimension other than `axis` will be the same as each input tensor’s, with the new axis inserted. Along the new dimension it will have size `len(values)`. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `sub()` {#max.graph.ops.sub} > max.graph.ops.sub(lhs, rhs) **Parameters:** * **lhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **rhs** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `sum()` {#max.graph.ops.sum} > max.graph.ops.sum(x, axis=-1) Reduces a symbolic tensor using a sum operation.
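As a minimal sketch (the graph name and values are illustrative):

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("sum_example") as g:
    x = ops.constant(
        np.ones((2, 3), dtype=np.float32), DType.float32, device=DeviceRef.CPU()
    )
    # Reduce along the last axis; the result has shape (2, 1),
    # since the reduced axis is kept with size 1.
    total = ops.sum(x, axis=-1)
```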
**Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor for the operation. * **axis** – The axis along which to compute the reduction. If negative, indexes from the last dimension. For example, a value of -1 will compute the reduction along the last dimension. **Returns:** A symbolic tensor representing the result of the sum operation. The tensor will have the same rank as the input tensor, and the same shape except along the `axis` dimension which will have size 1. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `tanh()` {#max.graph.ops.tanh} > max.graph.ops.tanh(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `tile()` {#max.graph.ops.tile} > max.graph.ops.tile(x, repeats) Returns a new tensor that is the result of copying the input tensor N\_i times along each dimension i, where N\_i = repeats\[i]. The i-th dimension of the output shape is the i-th dimension of the input shape multiplied by N\_i. **Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **repeats** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `top_k()` {#max.graph.ops.top_k} > max.graph.ops.top\_k(input, k, axis=-1) Returns a tensor containing only the top K values along the given axis.
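For example, a minimal sketch of selecting the two largest logits per row (the values are illustrative):

```python
import numpy as np

from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph("top_k_example") as g:
    logits = ops.constant(
        np.array([[0.1, 0.7, 0.2]], dtype=np.float32),
        DType.float32,
        device=DeviceRef.CPU(),
    )
    # values and indices both have shape (1, 2).
    values, indices = ops.top_k(logits, k=2, axis=-1)
```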
**Parameters:** * **input** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input tensor from which to select top k. * **k** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of values to select from input. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The axis from which to select top k. **Returns:** Top K values, Top K indices **Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue), [*TensorValue*](TensorValue.md#max.graph.TensorValue)] ### `transfer_to()` {#max.graph.ops.transfer_to} > max.graph.ops.transfer\_to(x, device) Device-to-Device transfer operation. This op transfers the input tensor from its current device to another. A device represents a computation unit such as a CPU or GPU. This op is useful when working with accelerators, for instance when data must be moved from one GPU to another, or from a GPU to the CPU. **Parameters:** * **x** ([`TensorValue`](TensorValue.md#max.graph.TensorValue) ) – The input tensor to transfer. * **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef) ) – The device to transfer to. **Returns:** A tensor transferred to device specified. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `transpose()` {#max.graph.ops.transpose} > max.graph.ops.transpose(x, axis\_1, axis\_2) Transposes two axes of a symbolic tensor. For more information, see [`transpose()`](TensorValue.md#max.graph.TensorValue.transpose). **Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to transpose. * **axis\_1** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – One of the two axes to transpose. If negative, this indexes from the end of the tensor. For example, `transpose(v, -1, -2)` transposes the last two axes. * **axis\_2** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The other axis to transpose. May also be negative to index from the end of the tensor. **Returns:** A new symbolic tensor with the two specified axes transposed. It has the same elements and dtype, but the order of the elements is different according to the transposition.
**Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `trunc()` {#max.graph.ops.trunc} > max.graph.ops.trunc(x) **Parameters:** **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `unsqueeze()` {#max.graph.ops.unsqueeze} > max.graph.ops.unsqueeze(x, axis) Inserts a size-1 dimension into a symbolic tensor. **Parameters:** * **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The input symbolic tensor to unsqueeze. * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index at which to insert a new dimension into the input’s shape. Elements at that index or higher are shifted back. If negative, it indexes relative to *1 plus* the rank of the tensor. For example, `unsqueeze(v, -1)` adds a new dimension at the end, and `unsqueeze(v, -2)` inserts the dimension immediately before the last dimension. **Returns:** A symbolic tensor with the same number of elements as the input tensor, whose rank is 1 larger than the rank of the input tensor. The result’s shape at the `axis` dimension is a static dimension of size 1. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `where()` {#max.graph.ops.where} > max.graph.ops.where(condition, x, y) Returns `condition ? x : y` (element-wise), where `condition`, `x`, and `y` are input tensors. **Parameters:** * **condition** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The condition tensor to use for selecting elementwise values. This tensor must have a boolean dtype.
* **x** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – If the condition is true at a position, the value from the same position in this tensor will be selected. * **y** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `|` [`Shape`](type.md#max.graph.type.Shape) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – If the condition is false at a position, the value from the same position in this tensor will be selected. **Returns:** A new symbolic tensor holding values from either `x` or `y`, selected based on the elements in `condition`. **Return type:** [*TensorValue*](TensorValue.md#max.graph.TensorValue) ### `while_loop()` {#max.graph.ops.while_loop} > max.graph.ops.while\_loop(initial\_values, predicate, body) Execute a loop until the predicate evaluates to false. Both the predicate and body functions must take as arguments the same number and types of values as specified in the init\_args. The predicate function must return a boolean scalar tensor of type `DType.bool`. The body function must return a list of values matching the types of init\_args. The following example demonstrates a basic while loop with a single argument:

```python
from max.graph import DeviceRef, Graph, ops
from max.dtype import DType

with Graph("while_loop_example") as g:
    x = ops.constant(0, dtype=DType.int32, device=DeviceRef.CPU())

    def pred(x):
        return x < 10

    def body(x):
        return [x + 1]

    # results holds the loop values after the final iteration.
    results = ops.while_loop(x, pred, body)
```

**Parameters:** * **initial\_values** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Value`](Value.md#max.graph.Value) `]` `|` [`Value`](Value.md#max.graph.Value) ) – Initial values for loop arguments. Must be non-empty. * **predicate** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) `[` `[` `...` `]` `,` [`TensorValue`](TensorValue.md#max.graph.TensorValue) `]` ) – Callable that takes loop arguments and returns a boolean scalar tensor of type `DType.bool` determining loop continuation. * **body** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) `[` `[` `...` `]` `,` [`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Value`](Value.md#max.graph.Value) `]` `]` ) – Callable that takes loop arguments and returns updated values matching the types of init\_args. **Returns:** List of output values from the final loop iteration. **Raises:** * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If init\_args is empty.
* [**NotImplementedError**](https://docs.python.org/3/library/exceptions.html#NotImplementedError) – If any init\_arg is a `BufferValue`. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](TensorValue.md#max.graph.TensorValue)] ##### NOTE Buffer operations are currently not supported. --- ## quantization APIs to quantize graph tensors. This package includes a comprehensive set of tools for working with quantized models in MAX Graph. It defines supported quantization encodings, configuration parameters that control the quantization process, and block parameter specifications for different quantization formats. The module supports various quantization formats including 4-bit, 5-bit, and 6-bit precision with different encoding schemes. It also provides support for GGUF-compatible formats for interoperability with other frameworks. ## `BlockParameters` {#max.graph.quantization.BlockParameters} > class max.graph.quantization.BlockParameters(elements\_per\_block, block\_size) Parameters describing the structure of a quantization block. Block-based quantization stores elements in fixed-size blocks. Each block contains a specific number of elements in a compressed format. **Parameters:** * **elements\_per\_block** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **block\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `block_size` {#max.graph.quantization.BlockParameters.block_size} > block\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `elements_per_block` {#max.graph.quantization.BlockParameters.elements_per_block} > elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) ## `QuantizationConfig` {#max.graph.quantization.QuantizationConfig} > class max.graph.quantization.QuantizationConfig(quant\_method, bits, group\_size, desc\_act=False, sym=False) Configuration for specifying quantization parameters that affect inference. These parameters control how tensor values are quantized, including the method, bit precision, grouping, and other characteristics that affect the trade-off between model size, inference speed, and accuracy. 
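For example, a GPTQ-style configuration might look like the following sketch (the specific values are illustrative, not recommendations):

```python
from max.graph.quantization import QuantizationConfig

config = QuantizationConfig(
    quant_method="gptq",  # Quantization method name.
    bits=4,               # Bit width of the quantized weights.
    group_size=128,       # Number of weights sharing quantization parameters.
    desc_act=False,
    sym=False,
)
```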
**Parameters:** * **quant\_method** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **bits** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **group\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **desc\_act** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **sym** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) ### `bits` {#max.graph.quantization.QuantizationConfig.bits} > bits: [int](https://docs.python.org/3/library/functions.html#int) ### `desc_act` {#max.graph.quantization.QuantizationConfig.desc_act} > desc\_act: [bool](https://docs.python.org/3/library/functions.html#bool) = False ### `group_size` {#max.graph.quantization.QuantizationConfig.group_size} > group\_size: [int](https://docs.python.org/3/library/functions.html#int) ### `quant_method` {#max.graph.quantization.QuantizationConfig.quant_method} > quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `sym` {#max.graph.quantization.QuantizationConfig.sym} > sym: [bool](https://docs.python.org/3/library/functions.html#bool) = False ## `QuantizationEncoding` {#max.graph.quantization.QuantizationEncoding} > class max.graph.quantization.QuantizationEncoding(value, names=&lt;not given&gt;, \*values, module=None, qualname=None, type=None, start=1, boundary=None) Quantization encodings supported by MAX Graph. Each encoding represents a different method of quantizing model weights with specific trade-offs between compression ratio, accuracy, and computational efficiency. ### `GPTQ` {#max.graph.quantization.QuantizationEncoding.GPTQ} > GPTQ = 'GPTQ' ### `Q4_0` {#max.graph.quantization.QuantizationEncoding.Q4_0} > Q4\_0 = 'Q4\_0' ### `Q4_K` {#max.graph.quantization.QuantizationEncoding.Q4_K} > Q4\_K = 'Q4\_K' ### `Q5_K` {#max.graph.quantization.QuantizationEncoding.Q5_K} > Q5\_K = 'Q5\_K' ### `Q6_K` {#max.graph.quantization.QuantizationEncoding.Q6_K} > Q6\_K = 'Q6\_K' ### `block_parameters` {#max.graph.quantization.QuantizationEncoding.block_parameters} > property block\_parameters: [BlockParameters](#max.graph.quantization.BlockParameters) Gets the block parameters for this quantization encoding. **Returns:** The parameters describing how elements are organized and encoded in blocks for this quantization encoding. **Return type:** [BlockParameters](#max.graph.quantization.BlockParameters) ### `block_size` {#max.graph.quantization.QuantizationEncoding.block_size} > property block\_size: [int](https://docs.python.org/3/library/functions.html#int) Number of bytes in the encoded representation of a block. All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of bytes resulting from encoding a single block. **Returns:** Size in bytes of each encoded quantization block. **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `elements_per_block` {#max.graph.quantization.QuantizationEncoding.elements_per_block} > property elements\_per\_block: [int](https://docs.python.org/3/library/functions.html#int) Number of elements per block. All quantization types currently supported by MAX Graph are block-based: groups of a fixed number of elements are formed, and each group is quantized together into a fixed-size output block. This value is the number of elements gathered into a block. **Returns:** Number of original tensor elements in each quantized block.
**Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `is_gguf` {#max.graph.quantization.QuantizationEncoding.is_gguf} > property is\_gguf: [bool](https://docs.python.org/3/library/functions.html#bool) Checks if this quantization encoding is compatible with GGUF format. GGUF is a format for storing large language models and compatible quantized weights. **Returns:** True if this encoding is compatible with GGUF, False otherwise. **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `name` {#max.graph.quantization.QuantizationEncoding.name} > property name: [str](https://docs.python.org/3/library/stdtypes.html#str) Gets the lowercase name of the quantization encoding. **Returns:** Lowercase string representation of the quantization encoding. **Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str) --- ## type Library for graph value types. ## `AlgebraicDim` {#max.graph.type.AlgebraicDim} > class max.graph.type.AlgebraicDim(value) An algebraic tensor dimension to enable expressions over symbolic dimensions. That is, any expression over a symbolic dimension returns `AlgebraicDim`. Furthermore, algebraic dimensions automatically simplify into a canonical form. The following example demonstrates how to create and use algebraic dimensions with symbolic values:

```python
from max.graph import AlgebraicDim, Dim

isinstance(Dim("batch") * 5, AlgebraicDim)  # Returns True
print(Dim("batch") * 5)  # Outputs: batch * 5
-Dim("x") - 4 == -(Dim("x") + 4)  # Returns True
```

Converts valid input values to Dim. **Parameters:** **attr** (`ParamOperatorAttr` ) ### `apply()` {#max.graph.type.AlgebraicDim.apply} > classmethod apply(op, \*operands) **Parameters:** * **op** (`POC` ) * **operands** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) ) ### `attr` {#max.graph.type.AlgebraicDim.attr} > attr: ParamOperatorAttr ### `from_mlir()` {#max.graph.type.AlgebraicDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`. **Parameters:** * **dim\_attr** – The MLIR Attribute object to parse into a dimension. * **attr** (`TypedAttr` ) **Returns:** The dimension represented by the MLIR Attr value. **Return type:** [Dim](#max.graph.type.Dim) ### `to_mlir()` {#max.graph.type.AlgebraicDim.to_mlir} > to\_mlir() Creates an mlir.Attribute representing this dimension. This is used internally when constructing tensor MLIR types. **Returns:** An mlir.Attribute in the context representing the dimension. **Return type:** *ParamOperatorAttr* ## `DeviceKind` {#max.graph.type.DeviceKind} > class max.graph.type.DeviceKind(value, names=&lt;not given&gt;, \*values, module=None, qualname=None, type=None, start=1, boundary=None) A device type representation. ### `CPU` {#max.graph.type.DeviceKind.CPU} > CPU = 'cpu' ### `GPU` {#max.graph.type.DeviceKind.GPU} > GPU = 'gpu' ### `from_string()` {#max.graph.type.DeviceKind.from_string} > static from\_string(txt) **Return type:** [*DeviceKind*](#max.graph.type.DeviceKind) ## `DeviceRef` {#max.graph.type.DeviceRef} > class max.graph.type.DeviceRef(device\_type, id=0) A symbolic device representation. DeviceRef type representation consists of a DeviceKind and an id. This is a direct representation of the device attribute in mlir.
The following example demonstrates how to create and use device references:

```python
from max.graph import DeviceRef

gpu_device = DeviceRef.GPU()
print(gpu_device)  # Outputs: gpu:0

# Create a CPU device with specific id
cpu_device = DeviceRef.CPU(id=1)
print(cpu_device)  # Outputs: cpu:1
```

**Parameters:** * **device\_type** ([`DeviceKind`](#max.graph.type.DeviceKind) ) * **id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `CPU()` {#max.graph.type.DeviceRef.CPU} > static CPU(id=0) Static method for creating a CPU device. **Parameters:** **id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [*DeviceRef*](#max.graph.type.DeviceRef) ### `GPU()` {#max.graph.type.DeviceRef.GPU} > static GPU(id=0) Static method for creating a GPU device. **Parameters:** **id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [*DeviceRef*](#max.graph.type.DeviceRef) ### `device_type` {#max.graph.type.DeviceRef.device_type} > device\_type: [DeviceKind](#max.graph.type.DeviceKind) ### `from_device()` {#max.graph.type.DeviceRef.from_device} > static from\_device(device) **Parameters:** **device** ([`Device`](../driver.md#max.driver.Device) ) **Return type:** [*DeviceRef*](#max.graph.type.DeviceRef) ### `from_mlir()` {#max.graph.type.DeviceRef.from_mlir} > static from\_mlir(attr) Returns a device from an mlir attribute. **Parameters:** **attr** (`DeviceRefAttr` ) **Return type:** [*DeviceRef*](#max.graph.type.DeviceRef) ### `id` {#max.graph.type.DeviceRef.id} > id: [int](https://docs.python.org/3/library/functions.html#int) ### `is_cpu()` {#max.graph.type.DeviceRef.is_cpu} > is\_cpu() Returns true if the device is a CPU device. **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_gpu()` {#max.graph.type.DeviceRef.is_gpu} > is\_gpu() Returns true if the device is a GPU device. **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `to_device()` {#max.graph.type.DeviceRef.to_device} > to\_device() Convert device reference to a concrete driver Device. **Return type:** [*Device*](../driver.md#max.driver.Device) ### `to_mlir()` {#max.graph.type.DeviceRef.to_mlir} > to\_mlir() Returns an mlir attribute representing the device. **Return type:** *DeviceRefAttr* ## `Dim` {#max.graph.type.Dim} > class max.graph.type.Dim(value) A tensor dimension. Tensor dimensions can be one of three types: * **Static**: Known size * **Symbolic**: Unknown size but named * **Algebraic**: Unknown size has an algebraic expression In most cases, you don’t need to work with a `Dim` directly. Instead, use conversion constructors:

```python
from max.dtype import DType
from max.graph import Dim, TensorType, DeviceRef

tensor_type = TensorType(DType.int64, ("batch", 10), device=DeviceRef.CPU())
```

This creates a tensor type with two dimensions: * A symbolic “batch” dimension * A static dimension of size 10 For explicit dimension construction, use the following helpers:

```python
from max.graph import AlgebraicDim, Dim, StaticDim, SymbolicDim

some_dims = [
    SymbolicDim("batch"),
    StaticDim(5),
    AlgebraicDim(Dim("batch") + 1),
]
```

Constraining tensor dimensions is one important way to improve model performance. If tensors have unknown dimensions, we can’t optimize them as aggressively. Symbolic dims allow the compiler to learn constraints on a specific dimension (e.g., if two inputs have the same batch dimension), but static dims are the easiest to optimize and therefore the easiest to create and work with. Converts valid input values to Dim.
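For example, the conversion constructor turns integers into static dims and strings into symbolic dims (a minimal sketch):

```python
from max.graph import Dim

static = Dim(5)        # A static dimension of known size 5.
symbolic = Dim("batch")  # A named symbolic dimension.
```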
**Parameters:** **value** (`DimLike` ) ### `from_mlir()` {#max.graph.type.Dim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`. **Parameters:** * **dim\_attr** – The MLIR Attribute object to parse into a dimension. * **attr** (`TypedAttr` ) **Returns:** The dimension represented by the MLIR Attr value. **Return type:** [Dim](#max.graph.type.Dim) ### `to_mlir()` {#max.graph.type.Dim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types. **Returns:** An `mlir.Attribute` in the context representing the dimension. **Return type:** *TypedAttr* ## `Shape` {#max.graph.type.Shape} > class max.graph.type.Shape(dims=()) **Parameters:** **dims** (`ShapeLike` ) ### `from_mlir()` {#max.graph.type.Shape.from_mlir} > classmethod from\_mlir(attr) **Parameters:** **attr** (`TypedAttr` ) **Return type:** [*Shape*](#max.graph.type.Shape) ### `rank` {#max.graph.type.Shape.rank} > property rank ### `static_dims` {#max.graph.type.Shape.static_dims} > property static\_dims: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] Returns all static dims in the shape as a list of integers. ### `to_mlir()` {#max.graph.type.Shape.to_mlir} > to\_mlir() **Return type:** *ShapeAttr* ## `StaticDim` {#max.graph.type.StaticDim} > class max.graph.type.StaticDim(value) A static tensor dimension. Static tensor dimensions will always have exactly the same value, and are key to good model performance. The following example shows how static dimensions can be created implicitly:

```python
from max.graph import TensorType
from max.dtype import DType

tensor = TensorType(DType.int64, (4, 5))
# This creates a tensor with 2 static dimensions: 4 and 5 respectively
```

Converts valid input values to Dim. **Parameters:** **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `dim` {#max.graph.type.StaticDim.dim} > dim: [int](https://docs.python.org/3/library/functions.html#int) The size of the static dimension. ### `from_mlir()` {#max.graph.type.StaticDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`. **Parameters:** * **dim\_attr** – The MLIR Attribute object to parse into a dimension. * **attr** (`TypedAttr` ) **Returns:** The dimension represented by the MLIR Attr value. **Return type:** [*Dim*](#max.graph.type.Dim) ### `to_mlir()` {#max.graph.type.StaticDim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types. **Returns:** An `mlir.Attribute` in the context representing the dimension. **Return type:** *IntegerAttr* ## `SymbolicDim` {#max.graph.type.SymbolicDim} > class max.graph.type.SymbolicDim(value) A symbolic tensor dimension. Symbolic dimensions represent named dimensions in MO tensor types. Symbolic dimensions don’t have a static value, but they allow a readable name to understand what’s going on in the model IR better, and they also allow users to hint to the compiler that two dimensions will have the same value, which can often allow important speedups. In tensor type notation:

```default
!mo.tensor<[batch, x, 10], si32>
```

The first and second dimensions are named `batch` and `x` respectively.
Creating a `SymbolicDim`:

```python
from max.graph import SymbolicDim

dim = SymbolicDim("name")
```

Using `SymbolicDim` in a [`TensorType`](#max.graph.type.TensorType):

```python
from max.dtype import DType
from max.graph import SymbolicDim, TensorType

tensor_type = TensorType(DType.bool, (SymbolicDim("batch"), SymbolicDim("x"), 10))
```

Converts valid input values to Dim. **Parameters:** **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) ### `from_mlir()` {#max.graph.type.SymbolicDim.from_mlir} > static from\_mlir(attr) Constructs a dimension from an `mlir.Attribute`. **Parameters:** * **dim\_attr** – The MLIR Attribute object to parse into a dimension. * **attr** (`TypedAttr` ) **Returns:** The dimension represented by the MLIR Attr value. **Return type:** [Dim](#max.graph.type.Dim) ### `name` {#max.graph.type.SymbolicDim.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the dimension. ### `to_mlir()` {#max.graph.type.SymbolicDim.to_mlir} > to\_mlir() Creates an `mlir.Attribute` representing this dimension. This is used internally when constructing tensor MLIR types. **Returns:** An `mlir.Attribute` in the context representing the dimension. **Return type:** *ParamDeclRefAttr* ## `TensorType` {#max.graph.type.TensorType} > class max.graph.type.TensorType(dtype, shape, device) A symbolic [`TensorType`](#max.graph.type.TensorType). This is not an eager tensor type! This contains no actual data, but instead represents the type of a value at some point in time during model execution. Most internal values in a model will be tensors. This type represents their element type (`dtype`) and dimensions (`dims`) at a specific point during model computation. It allows us to do some optimistic optimizations and shape inference during graph construction, and to provide more detailed shape information to the compiler for further optimization passes. The following example shows how to create a tensor type with static dimensions and access its properties:

```python
from max.graph import TensorType
from max.dtype import DType

# Create a tensor type with float32 elements and static dimensions 2x3
tensor_type = TensorType(DType.float32, (2, 3))
print(tensor_type.dtype)  # Outputs: DType.float32
print(tensor_type.shape)  # Outputs: [2, 3]
```

It can also represent a fully dynamic rank tensor. The presence of dynamic rank tensors in a graph will often degrade performance dramatically and prevents many classes of optimizations. An optional device (`device`) can also be provided to indicate the explicit device the tensor is associated with. Constructs a tensor type. **Parameters:** * **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The element type of the tensor data. * **dims** – The shape dimensions of the tensor. The number of dims is the rank of the tensor. * **shape** ([`Shape`](#max.graph.type.Shape) ) * **device** ([`DeviceRef`](#max.graph.type.DeviceRef) ) ### `as_buffer()` {#max.graph.type.TensorType.as_buffer} > as\_buffer() Returns the analogous buffer type. **Return type:** *BufferType* ### `from_mlir()` {#max.graph.type.TensorType.from_mlir} > classmethod from\_mlir(type) Constructs a tensor type from an MLIR type. **Parameters:** * **t** – The MLIR Type object to parse into a tensor type. * **type** (`TensorType` ) **Returns:** The tensor type represented by the MLIR Type value. **Return type:** [*TensorType*](#max.graph.type.TensorType) ### `to_mlir()` {#max.graph.type.TensorType.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance. **Returns:** An `mlir.Type` in the specified Context.
**Return type:** *TensorType* ## `Type` {#max.graph.type.Type} > class max.graph.type.Type Represents any possible type for Graph values. Every Value in the Graph has a Type, and that type is represented by a Type. This type may be inspected to get finer-grained types and learn more about an individual Value. The following example shows how to work with types in a graph:

```python
from max.graph import Graph, TensorType
from max.dtype import DType

with Graph() as g:
    # Create a tensor constant with a specific type
    tensor_type = TensorType(DType.float32, [2, 3])
    # The type can be inspected to get information about the value
    print(f"Tensor element type: {tensor_type.dtype}")  # Outputs: DType.float32
    print(f"Tensor shape: {tensor_type.shape}")  # Outputs: [2, 3]
```

### `from_mlir()` {#max.graph.type.Type.from_mlir} > static from\_mlir(t) Constructs a type from an MLIR type. **Parameters:** **t** (`MlirType` ) – The MLIR Type object to parse into a type. **Returns:** The type represented by the MLIR Type value. **Return type:** [*Type*](#max.graph.type.Type) ### `to_mlir()` {#max.graph.type.Type.to_mlir} > to\_mlir() Converts to an `mlir.Type` instance. **Returns:** An `mlir.Type` in the specified Context. **Return type:** *MlirType* --- ## weights APIs for loading weights into a graph. ## `GGUFWeights` {#max.graph.weights.GGUFWeights} > class max.graph.weights.GGUFWeights(source, tensors=None, prefix='', allocated=None) Creates a GGUF weights reader. **Parameters:** * **source** (`Union` `[` `PathLike` `,` `gguf.GGUFReader` `]` ) – Path to a GGUF file or a GGUFReader object. * **tensors** – List of tensors in the GGUF checkpoint. * **prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Weight name or prefix. * **allocated** – Dictionary of allocated values. ### `allocate()` {#max.graph.weights.GGUFWeights.allocate} > allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0) Creates and optionally validates a new Weight. **Parameters:** * **dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None` ) * **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None` ) * **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` ) **Return type:** [*Weight*](Weight.md#max.graph.Weight) ### `allocated_weights` {#max.graph.weights.GGUFWeights.allocated_weights} > property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[\_ScalarType\_co]]] Gets the values of all weights that were allocated previously. ### `data()` {#max.graph.weights.GGUFWeights.data} > data() Returns data loaded from the weights at the current prefix.
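A minimal sketch of reading one tensor follows; the checkpoint path and the `token_embd.weight` name are hypothetical, and the attribute-style prefix navigation is an assumption based on the `prefix` parameter described above:

```python
from max.graph.weights import GGUFWeights

weights = GGUFWeights("model.gguf")  # Hypothetical GGUF checkpoint path.

# Assumption: attribute access narrows the weight name prefix,
# here to "token_embd.weight".
w = weights.token_embd.weight
if w.exists():
    print(w.name, w.data())  # data() raises KeyError if the prefix is absent.
```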
## `PytorchWeights` {#max.graph.weights.PytorchWeights}

> class max.graph.weights.PytorchWeights(filepath, tensor\_infos=None, prefix='', allocated=None)

Creates a PyTorch checkpoint weights reader.

**Parameters:**

* **filepath** (`PathLike`)
* **tensor\_infos** (`Optional` `[` [`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `Any` `]` `]`)
* **prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str))

### `allocate()` {#max.graph.weights.PytorchWeights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Creates and optionally validates a new Weight.

**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None`)
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None`)
* **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None`)
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef))

**Return type:** [*Weight*](Weight.md#max.graph.Weight)

### `allocated_weights` {#max.graph.weights.PytorchWeights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[\_ScalarType\_co]]]

Gets the values of all weights that were allocated previously.

### `data()` {#max.graph.weights.PytorchWeights.data}

> data()

Returns data loaded from the weights at the current prefix.

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

### `dtype` {#max.graph.weights.PytorchWeights.dtype}

> property dtype: [DType](../dtype.md#max.dtype.DType)

The current weight dtype, if this weight exists.

### `exists()` {#max.graph.weights.PytorchWeights.exists}

> exists()

Returns whether a weight with this exact name exists.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `items()` {#max.graph.weights.PytorchWeights.items}

> items()

Iterate through all allocable weights that start with the prefix.

### `name` {#max.graph.weights.PytorchWeights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

The current weight name or prefix.

### `quantization_encoding` {#max.graph.weights.PytorchWeights.quantization_encoding}

> property quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None)

The current weight quantization encoding, if this weight exists.

### `raw_tensor()` {#max.graph.weights.PytorchWeights.raw_tensor}

> raw\_tensor()

Returns the tensor corresponding to this weights object.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If this weights object isn't a tensor.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any)]]

### `shape` {#max.graph.weights.PytorchWeights.shape}

> property shape: [Shape](type.md#max.graph.type.Shape)

The current weight shape, if this weight exists.
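A short usage sketch; the checkpoint path and weight name below (`model.bin`, `lm_head.weight`) are illustrative assumptions:

```python
from max.graph.weights import PytorchWeights

# Hypothetical checkpoint path; replace with a real PyTorch checkpoint.
weights = PytorchWeights("model.bin")

# Inspect a weight before allocating it; dtype and shape are only
# available if the weight exists in the checkpoint.
if weights.lm_head.weight.exists():
    print(weights.lm_head.weight.dtype, weights.lm_head.weight.shape)
```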
## `RandomWeights` {#max.graph.weights.RandomWeights}

> class max.graph.weights.RandomWeights(\_allocated=\<factory\>, \_prefix='')

A class that mimics a Weights implementation with a checkpoint file.

Unlike checkpoint-backed weights, this doesn't carry a mapping from weight names to mmap'ed numpy arrays. Rather, when `.allocate` is called, this generates a backing NumPy array of the desired tensor spec on the fly and stores it. This is useful for generating weights for testing and using them in subcomponents that expect a weights implementation backed by a checkpoint.

**Parameters:**

* **\_allocated** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `]`)
* **\_prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str))

### `allocate()` {#max.graph.weights.RandomWeights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Creates a Weight that can be added to a graph.
**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None`)
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None`)
* **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None`)
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef))

**Return type:** [*Weight*](Weight.md#max.graph.Weight)

### `allocated_weights` {#max.graph.weights.RandomWeights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)]

Gets the values of all weights that were allocated previously.

### `data()` {#max.graph.weights.RandomWeights.data}

> data()

Returns data loaded from the weights at the current prefix.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If the current prefix isn't present in the checkpoint.

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

### `exists()` {#max.graph.weights.RandomWeights.exists}

> exists()

Returns whether a weight with this exact name exists.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `items()` {#max.graph.weights.RandomWeights.items}

> items()

Iterate through all allocable weights that start with the prefix.

### `name` {#max.graph.weights.RandomWeights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

The current weight name or prefix.

### `raw_tensor()` {#max.graph.weights.RandomWeights.raw_tensor}

> raw\_tensor()

Returns the numpy tensor corresponding to this weights object.

**Parameters:**

**dtype** – If specified, the returned array will be cast to the dtype before returning.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If this weights object isn't a tensor.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any)]]

## `SafetensorWeights` {#max.graph.weights.SafetensorWeights}

> class max.graph.weights.SafetensorWeights(filepaths, \*, tensors=None, tensors\_to\_file\_idx=None, prefix='', allocated=None, \_st\_weight\_map=None)

Helper for loading weights into a graph.

A weight (`max.graph.Weight`) is a tensor in a graph that is backed by an external buffer or mmap. Generally, weights are used to avoid recompiling the graph when new weights are used (such as after fine-tuning). For large enough constants, it might be worth using weights for faster compilation, though the graph may be less optimized.

Weight classes can be used to help with graph weight allocation and naming. This protocol defines the getter methods `__getattr__` and `__getitem__` to assist with defining names. For example, `weights.a.b[1].c.allocate(...)` creates a weight with the name `"a.b.1.c"`.
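For example, a hedged sketch of the naming protocol with `SafetensorWeights`; the shard paths and weight name are illustrative assumptions:

```python
from max.graph.weights import SafetensorWeights

# Hypothetical shard paths; replace with real safetensors files.
weights = SafetensorWeights([
    "model-00001-of-00002.safetensors",
    "model-00002-of-00002.safetensors",
])

# Attribute and index access build the weight name "layers.0.attn.wq",
# and allocate() creates the corresponding graph Weight.
wq = weights.layers[0].attn.wq.allocate()
```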
**Parameters:**

* **filepaths** (`Sequence` `[` `PathLike` `]`)
* **tensors** (`Optional` `[` `Set` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` `]`)
* **tensors\_to\_file\_idx** (`Mapping` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None`)
* **prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str))
* **\_st\_weight\_map** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `Tensor` `]`)

### `allocate()` {#max.graph.weights.SafetensorWeights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Creates a Weight that can be added to a graph.

**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None`)
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None`)
* **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None`)
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef))

**Return type:** [*Weight*](Weight.md#max.graph.Weight)

### `allocate_as_bytes()` {#max.graph.weights.SafetensorWeights.allocate_as_bytes}

> allocate\_as\_bytes(dtype=None)

Creates a Weight that can be added to the graph. The weight has a uint8 representation instead of the original data type. The last dimension of the shape is scaled by the number of bytes it takes to represent the original data type. For example, `[512, 256]` float32 weights become `[512, 1024]` uint8 weights. Scalar weights are interpreted as weights with shape `[1]`.

**Parameters:**

**dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None`)

**Return type:** [*Weight*](Weight.md#max.graph.Weight)

### `allocated_weights` {#max.graph.weights.SafetensorWeights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[\_ScalarType\_co]]]

Gets the values of all weights that were allocated previously.

### `data()` {#max.graph.weights.SafetensorWeights.data}

> data()

Returns data loaded from the weights at the current prefix.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If the current prefix isn't present in the checkpoint.

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

### `exists()` {#max.graph.weights.SafetensorWeights.exists}

> exists()

Returns whether a weight with this exact name exists.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `items()` {#max.graph.weights.SafetensorWeights.items}

> items()

Iterate through all allocable weights that start with the prefix.

### `name` {#max.graph.weights.SafetensorWeights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

The current weight name or prefix.

### `raw_tensor()` {#max.graph.weights.SafetensorWeights.raw_tensor}

> raw\_tensor()

Returns the numpy tensor corresponding to this weights object.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If this weights object isn't a tensor.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any)]]
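A hedged sketch of the byte-view allocation described in `allocate_as_bytes()` above; the checkpoint path and weight name are illustrative assumptions:

```python
from max.graph.weights import SafetensorWeights

# Hypothetical checkpoint path; replace with a real safetensors file.
weights = SafetensorWeights(["model.safetensors"])

# A [512, 256] float32 weight is exposed as a [512, 1024] uint8 weight,
# since float32 takes 4 bytes per element.
w_bytes = weights.ffn.w1.allocate_as_bytes()
```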
## `WeightData` {#max.graph.weights.WeightData}

> class max.graph.weights.WeightData(data, name, dtype, shape, quantization\_encoding=None)

Data loaded from a checkpoint.

**Parameters:**

* **data** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `[` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `,` [`dtype`](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype) `[` `_ScalarType_co` `]` `]`)
* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str))
* **dtype** ([`DType`](../dtype.md#max.dtype.DType))
* **shape** ([`Shape`](type.md#max.graph.type.Shape))
* **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None`)

### `astype()` {#max.graph.weights.WeightData.astype}

> astype(dtype)

Returns a new [`WeightData`](#max.graph.weights.WeightData) with the data converted to the given dtype.

**Parameters:**

**dtype** ([`DType`](../dtype.md#max.dtype.DType))

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

### `data` {#max.graph.weights.WeightData.data}

> data: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[\_ScalarType\_co]]

### `dtype` {#max.graph.weights.WeightData.dtype}

> dtype: [DType](../dtype.md#max.dtype.DType)

### `from_numpy()` {#max.graph.weights.WeightData.from_numpy}

> classmethod from\_numpy(arr, name)

Creates a [`WeightData`](#max.graph.weights.WeightData) from a numpy array with the given name.

### `name` {#max.graph.weights.WeightData.name}

> name: [str](https://docs.python.org/3/library/stdtypes.html#str)

### `quantization_encoding` {#max.graph.weights.WeightData.quantization_encoding}

> quantization\_encoding: [QuantizationEncoding](quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `shape` {#max.graph.weights.WeightData.shape}

> shape: [Shape](type.md#max.graph.type.Shape)

### `view()` {#max.graph.weights.WeightData.view}

> view(dtype)

Returns a new [`WeightData`](#max.graph.weights.WeightData) with the data viewed as the given dtype.

**Parameters:**

**dtype** ([`DType`](../dtype.md#max.dtype.DType))

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

## `Weights` {#max.graph.weights.Weights}

> class max.graph.weights.Weights(\*args, \*\*kwargs)

Helper for loading weights into a graph.

A weight (`max.graph.Weight`) is a tensor in a graph that is backed by an external buffer or mmap. Generally, weights are used to avoid recompiling the graph when new weights are used (such as after fine-tuning). For large enough constants, it might be worth using weights for faster compilation, though the graph may be less optimized.

Weight classes can be used to help with graph weight allocation and naming. This protocol defines the getter methods `__getattr__` and `__getitem__` to assist with defining names. For example, `weights.a.b[1].c.allocate(...)` creates a weight with the name `"a.b.1.c"`.
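Because `Weights` is a protocol, loading code can be written against it and then handed any concrete reader (GGUF, safetensors, PyTorch). A minimal sketch; the `"embeddings"` prefix is an illustrative assumption:

```python
from max.graph.weights import Weights

def embedding_weight_names(weights: Weights) -> list[str]:
    # Walk every allocable weight under the (assumed) "embeddings" prefix.
    return [name for name, _ in weights.embeddings.items()]
```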
### `allocate()` {#max.graph.weights.Weights.allocate}

> allocate(dtype=None, shape=None, quantization\_encoding=None, device=cpu:0)

Creates a Weight that can be added to a graph.

**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None`)
* **shape** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Dim`](type.md#max.graph.type.Dim) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `]` `|` `None`)
* **quantization\_encoding** ([`QuantizationEncoding`](quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None`)
* **device** ([`DeviceRef`](type.md#max.graph.type.DeviceRef))

**Return type:** [*Weight*](Weight.md#max.graph.Weight)

### `allocated_weights` {#max.graph.weights.Weights.allocated_weights}

> property allocated\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[Any](https://docs.python.org/3/library/typing.html#typing.Any), [dtype](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[\_ScalarType\_co]]]

Gets the values of all weights that were allocated previously.

### `data()` {#max.graph.weights.Weights.data}

> data()

Returns data loaded from the weights at the current prefix.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If the current prefix isn't present in the checkpoint.

**Return type:** [*WeightData*](#max.graph.weights.WeightData)

### `exists()` {#max.graph.weights.Weights.exists}

> exists()

Returns whether a weight with this exact name exists.

**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)

### `items()` {#max.graph.weights.Weights.items}

> items()

Iterate through all allocable weights that start with the prefix.

**Return type:** [*Iterator*](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterator)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), *\_Self*]]

### `name` {#max.graph.weights.Weights.name}

> property name: [str](https://docs.python.org/3/library/stdtypes.html#str)

The current weight name or prefix.

### `raw_tensor()` {#max.graph.weights.Weights.raw_tensor}

> raw\_tensor()

Returns the numpy tensor corresponding to this weights object.

**Parameters:**

**dtype** – If specified, the returned array will be cast to the dtype before returning.

**Raises:** [**KeyError**](https://docs.python.org/3/library/exceptions.html#KeyError) – If this weights object isn't a tensor.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any)]]
## `WeightsFormat` {#max.graph.weights.WeightsFormat}

> class max.graph.weights.WeightsFormat(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

An enumeration of the supported weights file formats.

### `gguf` {#max.graph.weights.WeightsFormat.gguf}

> gguf = 'gguf'

### `pytorch` {#max.graph.weights.WeightsFormat.pytorch}

> pytorch = 'pytorch'

### `safetensors` {#max.graph.weights.WeightsFormat.safetensors}

> safetensors = 'safetensors'

## `load_weights()` {#max.graph.weights.load_weights}

> max.graph.weights.load\_weights(paths)

Loads weight paths into a Weights object.

**Parameters:**

**paths** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) `]`) – Local paths of weight files to load.

**Returns:** A Weights object, with all of the associated weights loaded into a single object.

**Raises:**

* [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If an empty paths list is passed.
* [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If a path provided does not exist.

**Return type:** [*Weights*](#max.graph.weights.Weights)

## `weights_format()` {#max.graph.weights.weights_format}

> max.graph.weights.weights\_format(weight\_paths)

Retrieves the format of the weights files in the provided paths.

**Parameters:**

**weight\_paths** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) `]`) – A list of file paths, containing the weights for a single model.

**Returns:** A WeightsFormat enum, representing whether the weights are in gguf, safetensors, or pytorch format.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the weights format cannot be inferred from the paths.

**Return type:** [*WeightsFormat*](#max.graph.weights.WeightsFormat)
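A short sketch tying the two helpers together; the shard paths are illustrative assumptions:

```python
from pathlib import Path

from max.graph.weights import WeightsFormat, load_weights, weights_format

# Hypothetical shard paths; replace with real checkpoint files.
paths = [
    Path("model-00001-of-00002.safetensors"),
    Path("model-00002-of-00002.safetensors"),
]

# Infer the format from the file paths, then load them as one object.
fmt = weights_format(paths)
assert fmt == WeightsFormat.safetensors

weights = load_weights(paths)
print(weights.name)
```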
---

## max

The MAX Python API reference.

The MAX API provides a state-of-the-art graph compiler and runtime library that executes AI models with incredible speed on a wide range of hardware.

## Modules

* [`driver`](/max/api/python/driver): APIs to interact with devices.
* [`dtype`](/max/api/python/dtype): APIs to define data types.
* [`engine`](/max/api/python/engine): APIs to load and execute models.
* [`entrypoints`](/max/api/python/entrypoints): APIs to run MAX models.
* [`torch`](/max/api/python/torch): APIs to use custom ops with PyTorch.

## Packages

* [`graph`](/max/api/python/graph): APIs to build models (inference graphs).
* [`pipelines`](/max/api/python/pipelines): APIs to build model pipelines.
* [`nn`](/max/api/python/nn): APIs to build MAX NN models.

---

## attention_with_rope

An opaque, KV-cache-optimized attention mechanism with RoPE (rotary position embeddings).

## `AttentionWithRope` {#max.nn.attention.attention_with_rope.AttentionWithRope}

> class max.nn.attention.attention\_with\_rope.AttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=\<class 'max.nn.linear.Linear'\>, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None)

Implementation of attention that uses RoPE frequencies.

Initializes the attention layer.

**Parameters:**

* **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from.
* **num\_attention\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* **num\_key\_value\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* **hidden\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type.
* **dtype** ([`DType`](../../dtype.md#max.dtype.DType)) – DType of the QKV and output projection weights.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` `|` `None`) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used; use DistributedAttentionWithRope to use all devices during attention computation.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](../linear.md#max.nn.linear.Linear) `]`) – Linear class to use for the output dense layer.
* **stacked\_qkv** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether the weights are stacked together.
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – Value used to scale the results of the attention output.
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether to use an attention bias.
* **clip\_qkv** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – If provided, the QKV weights are clamped between `[-clip_qkv, clip_qkv]`.
* **float8\_config** ([`Float8Config`](../linear.md#max.nn.linear.Float8Config) `|` `None`)

### `qkv_input_scale` {#max.nn.attention.attention_with_rope.AttentionWithRope.qkv_input_scale}

> property qkv\_input\_scale: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)

The maximum of the q, k, and v input scale vectors.

### `qkv_weight_scale` {#max.nn.attention.attention_with_rope.AttentionWithRope.qkv_weight_scale}

> property qkv\_weight\_scale: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The maximum of the q, k, and v weight scale vectors.

### `rope` {#max.nn.attention.attention_with_rope.AttentionWithRope.rope}

> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)

### `wqkv` {#max.nn.attention.attention_with_rope.AttentionWithRope.wqkv}

> property wqkv: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The concatenation of q, k, and v weight vectors.
### `wqkv_bias` {#max.nn.attention.attention_with_rope.AttentionWithRope.wqkv_bias} > property wqkv\_bias: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) The concatenation of q, k, and v bias weight vectors. ## `AttentionWithRopeQKV` {#max.nn.attention.attention_with_rope.AttentionWithRopeQKV} > class max.nn.attention.attention\_with\_rope.AttentionWithRopeQKV(n\_heads: 'int', kv\_params: 'KVCacheParams', wq: 'TensorValueLike', wk: 'TensorValueLike', wv: 'TensorValueLike', wo: 'LinearV1', scale: 'float', rope: 'RotaryEmbedding') **Parameters:** * **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **wq** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **wk** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **wv** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **wo** ([`LinearV1`](../linear.md#max.nn.linear.LinearV1) ) * **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding) ) ### `rope` {#max.nn.attention.attention_with_rope.AttentionWithRopeQKV.rope} > rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding) ## `AttentionWithRopeV1` {#max.nn.attention.attention_with_rope.AttentionWithRopeV1} > class max.nn.attention.attention\_with\_rope.AttentionWithRopeV1(n\_heads, kv\_params, wqkv, wo, scale, rope, bias=None, perm\_idx=None, quantization\_config=None) Implementation of attention that uses the rope frequency. Deprecated: Use AttentionWithRope instead. 
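A minimal construction sketch of the replacement class, `AttentionWithRope`; the head counts, hidden size, and device list here are illustrative assumptions, not values from a real model:

```python
from max.graph import DeviceRef
from max.nn.attention.attention_with_rope import AttentionWithRope

def build_attention(rope, kv_params):
    # rope is a RotaryEmbedding and kv_params a KVCacheParams, both
    # configured elsewhere; the sizes below are placeholders.
    return AttentionWithRope(
        rope=rope,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_size=4096,
        kv_params=kv_params,
        devices=[DeviceRef.GPU()],
    )
```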
**Parameters:**

* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int))
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams))
* **wqkv** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue))
* **wo** ([`LinearV1`](../linear.md#max.nn.linear.LinearV1))
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float))
* **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding))
* **bias** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` `None`)
* **perm\_idx** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` `None`)
* **quantization\_config** ([`QuantizationConfig`](../../graph/quantization.md#max.graph.quantization.QuantizationConfig) `|` `None`)

### `bias` {#max.nn.attention.attention_with_rope.AttentionWithRopeV1.bias}

> bias: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `perm_idx` {#max.nn.attention.attention_with_rope.AttentionWithRopeV1.perm_idx}

> perm\_idx: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `quantization_config` {#max.nn.attention.attention_with_rope.AttentionWithRopeV1.quantization_config}

> quantization\_config: [QuantizationConfig](../../graph/quantization.md#max.graph.quantization.QuantizationConfig) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `rope` {#max.nn.attention.attention_with_rope.AttentionWithRopeV1.rope}

> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)

## `DistributedAttentionWithRope` {#max.nn.attention.attention_with_rope.DistributedAttentionWithRope}

> class max.nn.attention.attention\_with\_rope.DistributedAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=\<class 'max.nn.linear.Linear'\>, stacked\_qkv=False, scale=None, has\_bias=False, float8\_config=None, clip\_qkv=None)

A distributed, multi-device implementation of [`AttentionWithRope`](#max.nn.attention.attention_with_rope.AttentionWithRope).

Initializes the distributed attention layer.

**Parameters:**

* **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from.
* **num\_attention\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* **num\_key\_value\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* **hidden\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` `|` `None`) – Devices on which to place the weights and run the computation. At least two devices must be provided for distributed attention.
* **dtype** ([`DType`](../../dtype.md#max.dtype.DType)) – DType of the QKV and output projection weights.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](../linear.md#max.nn.linear.Linear) `]`) – Linear class to use for the output dense layer.
* **stacked\_qkv** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether the weights are stacked together.
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – Value used to scale the results of the attention output.
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether to use an attention bias.
* **float8\_config** ([`Float8Config`](../linear.md#max.nn.linear.Float8Config) `|` `None`) – Float8 configuration for quantization.
* **clip\_qkv** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – If provided, the QKV weights are clamped between `[-clip_qkv, clip_qkv]`.

## `GGUFQAttentionWithRope` {#max.nn.attention.attention_with_rope.GGUFQAttentionWithRope}

> class max.nn.attention.attention\_with\_rope.GGUFQAttentionWithRope(\*, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, dtype, quantization\_encoding, devices=None, linear\_cls=\<class 'max.nn.linear.Linear'\>, scale=None, has\_bias=False, clip\_qkv=None)

Implementation of attention with GGUF quantized weights.

Initializes the attention layer.

**Parameters:**

* **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from.
* **num\_attention\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* **num\_key\_value\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* **hidden\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type.
* **dtype** ([`DType`](../../dtype.md#max.dtype.DType)) – DType of the weights; should always be uint8.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` `|` `None`) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used; use DistributedAttentionWithRope to use all devices during attention computation.
* **quantization\_encoding** ([`QuantizationEncoding`](../../graph/quantization.md#max.graph.quantization.QuantizationEncoding)) – Quantization encoding of the weights.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](../linear.md#max.nn.linear.Linear) `]`) – Linear class to use for the output dense layer.
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – Value used to scale the results of the attention output.
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether to use an attention bias.
* **clip\_qkv** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – If provided, the QKV weights are clamped between `[-clip_qkv, clip_qkv]`.

### `rope` {#max.nn.attention.attention_with_rope.GGUFQAttentionWithRope.rope}

> rope: [RotaryEmbedding](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)

### `wqkv` {#max.nn.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv}

> property wqkv: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The concatenation of q, k, and v weight vectors.
### `wqkv_bias` {#max.nn.attention.attention_with_rope.GGUFQAttentionWithRope.wqkv_bias}

> property wqkv\_bias: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None)

The concatenation of q, k, and v bias weight vectors.

## `GPTQAttentionWithRope` {#max.nn.attention.attention_with_rope.GPTQAttentionWithRope}

> class max.nn.attention.attention\_with\_rope.GPTQAttentionWithRope(quantization\_config, rope, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, scale=None, linear\_cls=\<class 'max.nn.linear.Linear'\>)

Implementation of the GPTQ attention layer.

Initializes the attention layer.

**Parameters:**

* **rope** ([`RotaryEmbedding`](../rotary_embedding.md#max.nn.rotary_embedding.RotaryEmbedding)) – The rope layer to borrow the freqs\_cis value from.
* **num\_attention\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* **num\_key\_value\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* **hidden\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type.
* **dtype** ([`DType`](../../dtype.md#max.dtype.DType)) – DType of the QKV and output projection weights.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` `|` `None`) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used; use DistributedAttentionWithRope to use all devices during attention computation.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](../linear.md#max.nn.linear.Linear) `]`) – Linear class to use for the output dense layer.
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – Value used to scale the results of the attention output.
* **quantization\_config** ([`QuantizationConfig`](../../graph/quantization.md#max.graph.quantization.QuantizationConfig))

### `wqkv` {#max.nn.attention.attention_with_rope.GPTQAttentionWithRope.wqkv}

> property wqkv: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The concatenation of q, k, and v weight vectors.

## `distribute_value()` {#max.nn.attention.attention_with_rope.distribute_value}

> max.nn.attention.attention\_with\_rope.distribute\_value(v, devices)

Distributes a tensor value across the given devices, returning one value per device.

**Parameters:**

* **v** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue))
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]`)

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*TensorValue*](../../graph/TensorValue.md#max.graph.TensorValue)]
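A hedged sketch of using `distribute_value` to replicate a stacked QKV weight across devices; the surrounding graph construction is omitted, and the device list is an illustrative assumption:

```python
from max.graph import DeviceRef, TensorValue
from max.nn.attention.attention_with_rope import distribute_value

def replicate_wqkv(
    wqkv: TensorValue, devices: list[DeviceRef]
) -> list[TensorValue]:
    # One TensorValue per device, in the same order as `devices`.
    return distribute_value(wqkv, devices)
```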
---

## attention

## Modules

* [`attention_with_rope`](/max/api/python/nn/attention/attention_with_rope)
* [`ragged_attention`](/max/api/python/nn/attention/ragged_attention)
* [`interfaces`](/max/api/python/nn/attention/interfaces)

---

## interfaces

General interface for Attention.

## `AttentionImpl` {#max.nn.attention.interfaces.AttentionImpl}

> class max.nn.attention.interfaces.AttentionImpl(n\_heads, kv\_params, wqkv, wo, scale)

A generalized attention interface that will be used upstream by a general Transformer. We expect a separate subclass articulating each variation of Attention:

* AttentionWithRope
* AttentionWithAlibi
* VanillaAttentionWithCausalMask
* …

There is a series of shared attributes; however, more may be needed for each individual variant. For example, we may introduce a RotaryEmbedding class for the AttentionWithRope class:

```python
@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...
```

We expect the `__call__` abstractmethod to remain relatively consistent; however, the `**kwargs` argument is exposed, allowing you to pass additional arguments for each particular variant. For example, we may introduce a VanillaAttentionWithCausalMask class, which includes an attention mask:

```python
@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        ...
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # We can then use the attention mask downstream like so:
        op(attn_mask=kwargs["attn_mask"])
```

**Parameters:**

* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int))
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams))
* **wqkv** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue))
* **wo** ([`LinearV1`](../linear.md#max.nn.linear.LinearV1))
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float))

### `kv_params` {#max.nn.attention.interfaces.AttentionImpl.kv_params}

> kv\_params: [KVCacheParams](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)

KV Cache Params, including the number of kv heads, the head dim, and data type.

### `n_heads` {#max.nn.attention.interfaces.AttentionImpl.n_heads}

> n\_heads: [int](https://docs.python.org/3/library/functions.html#int)

The number of attention heads.

### `scale` {#max.nn.attention.interfaces.AttentionImpl.scale}

> scale: [float](https://docs.python.org/3/library/functions.html#float)

The scale factor for the attention.

### `wo` {#max.nn.attention.interfaces.AttentionImpl.wo}

> wo: [LinearV1](../linear.md#max.nn.linear.LinearV1)

A linear layer for the output projection.

### `wqkv` {#max.nn.attention.interfaces.AttentionImpl.wqkv}

> wqkv: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The concatenation of q, k, and v weight vectors.

## `AttentionImplQKV` {#max.nn.attention.interfaces.AttentionImplQKV}

> class max.nn.attention.interfaces.AttentionImplQKV(n\_heads, kv\_params, wq, wk, wv, wo, scale)

A generalized attention interface that will be used upstream by a general Transformer. We expect a separate subclass articulating each variation of Attention:

* AttentionWithRope
* AttentionWithAlibi
* VanillaAttentionWithCausalMask
* …

There is a series of shared attributes; however, more may be needed for each individual variant. For example, we may introduce a RotaryEmbedding class for the AttentionWithRope class:

```python
@dataclass
class AttentionWithRope(AttentionImpl):
    rope: RotaryEmbedding
    ...
```

We expect the `__call__` abstractmethod to remain relatively consistent; however, the `**kwargs` argument is exposed, allowing you to pass additional arguments for each particular variant. For example, we may introduce a VanillaAttentionWithCausalMask class, which includes an attention mask:

```python
@dataclass
class VanillaAttentionWithCausalMask(AttentionImpl):
    ...

    def __call__(
        self,
        x: TensorValueLike,
        kv_collection: ContinuousBatchingKVCacheCollection,
        valid_lengths: TensorValueLike,
        **kwargs,
    ) -> tuple[TensorValue, ContinuousBatchingKVCacheCollection]:
        ...
        if "attn_mask" not in kwargs:
            raise ValueError("attn_mask not provided to VanillaAttentionWithCausalMask")

        # We can then use the attention mask downstream like so:
        op(attn_mask=kwargs["attn_mask"])
```

**Parameters:**

* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int))
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams))
* **wq** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray))
* **wk** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray))
* **wv** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray))
* **wo** ([`LinearV1`](../linear.md#max.nn.linear.LinearV1))
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float))

### `kv_params` {#max.nn.attention.interfaces.AttentionImplQKV.kv_params}

> kv\_params: [KVCacheParams](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)

KV Cache Params, including the number of kv heads, the head dim, and data type.

### `n_heads` {#max.nn.attention.interfaces.AttentionImplQKV.n_heads}

> n\_heads: [int](https://docs.python.org/3/library/functions.html#int)

The number of attention heads.
### `scale` {#max.nn.attention.interfaces.AttentionImplQKV.scale}

> scale: [float](https://docs.python.org/3/library/functions.html#float)

The scale factor for the attention.

### `wk` {#max.nn.attention.interfaces.AttentionImplQKV.wk}

> wk: Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/type.md#max.graph.type.Shape) | [Dim](../../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

The k weight vector.

### `wo` {#max.nn.attention.interfaces.AttentionImplQKV.wo}

> wo: [LinearV1](../linear.md#max.nn.linear.LinearV1)

A linear layer for the output projection.

### `wq` {#max.nn.attention.interfaces.AttentionImplQKV.wq}

> wq: Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/type.md#max.graph.type.Shape) | [Dim](../../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

The q weight vector.

### `wv` {#max.nn.attention.interfaces.AttentionImplQKV.wv}

> wv: Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/type.md#max.graph.type.Shape) | [Dim](../../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

The v weight vector.

## `DistributedAttentionImpl` {#max.nn.attention.interfaces.DistributedAttentionImpl}

> class max.nn.attention.interfaces.DistributedAttentionImpl

A generalized distributed attention interface.

---

## ragged_attention

An opaque, KV-cache-optimized vanilla attention mechanism, with mask variants provided inside the kernel.

## `RaggedAttention` {#max.nn.attention.ragged_attention.RaggedAttention}

> class max.nn.attention.ragged\_attention.RaggedAttention(\*, mask\_variant, num\_attention\_heads, num\_key\_value\_heads, hidden\_size, kv\_params, devices=None, dtype=float32, linear\_cls=\<class 'max.nn.linear.Linear'\>, stacked\_qkv=False, scale=None, has\_bias=False, clip\_qkv=None)

Layer that computes the self-attention score for ragged inputs.

Initializes the attention layer.

**Parameters:**

* **num\_attention\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The number of attention heads.
* **num\_key\_value\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of key/value heads.
* **hidden\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – The dimension of the hidden states.
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams)) – KV Cache Params, including the number of kv heads, the head dim, and data type.
* **dtype** ([`DType`](../../dtype.md#max.dtype.DType)) – DType of the QKV and output projection weights.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` `|` `None`) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](../linear.md#max.nn.linear.Linear) `]`) – Linear class to use for the output dense layer.
* **stacked\_qkv** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether the weights are stacked together.
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – Value used to scale the results of the attention output.
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – Whether to use an attention bias.
* **clip\_qkv** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None`) – If provided, the QKV weights are clamped between `[-clip_qkv, clip_qkv]`.
* **mask\_variant** (`MHAMaskVariant`) – The mask variant provided inside the kernel.

### `wqkv` {#max.nn.attention.ragged_attention.RaggedAttention.wqkv}

> property wqkv: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue)

The concatenation of q, k, and v weight vectors.

---

## conv

The `conv` module provides classes for performing convolution operations in various dimensions (1D, 2D, and 3D) on tensor inputs. These convolution operations are core building blocks for neural networks, especially in computer vision and sequence processing tasks.

Here's an example demonstrating how to use a 1D convolution:

```python
import numpy as np

import max.nn as nn
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

with Graph(name="conv_example") as graph:
    # Define dimensions
    batch_size = 2
    seq_length = 10
    in_channels = 16
    out_channels = 32
    kernel_size = 3

    # Create input tensor [batch_size, sequence_length, channels]
    x_data = np.zeros((batch_size, seq_length, in_channels), dtype=np.float32)
    x = ops.constant(x_data, dtype=DType.float32, device=DeviceRef.CPU())

    # Create a Conv1D layer; the layer creates its own filter and bias
    # weights from the dimensions given here.
    conv1d = nn.Conv1D(
        kernel_size=kernel_size,
        in_channels=in_channels,
        out_channels=out_channels,
        dtype=DType.float32,
        stride=1,
        padding=1,
        has_bias=True,
        device=DeviceRef.CPU(),
    )

    # Apply the convolution
    output_1d = conv1d(x)
    print(f"Conv1D output shape: {output_1d.shape}")
    # Output: Conv1D output shape: [Dim(2), Dim(10), Dim(32)]
```

## `Conv1D` {#max.nn.conv.Conv1D}

> class max.nn.conv.Conv1D(kernel\_size, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None)

A 1D convolution over an input signal composed of several input planes.

## Example

```python
conv = nn.Conv1D(
    kernel_size=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=1,
    padding=0,
    has_bias=False,
    name="conv1d_weight",
    device=DeviceRef.GPU(),
)
```

Initializes the Conv1D layer with weights and optional bias.

**Parameters:**

* **kernel\_size** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Size of the convolving kernel.
* **in\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of channels in the input signal.
* **out\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of channels produced by the convolution.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType)) – The data type for both weights and bias.
* **stride** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Stride of the convolution. Default: 1
* **padding** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Padding added to both sides of the input. Default: 0
* **dilation** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Spacing between kernel elements. Default: 1
* **num\_groups** ([`int`](https://docs.python.org/3/library/functions.html#int)) – Number of blocked connections from input channels to output channels. Default: 1
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `|` `None`) – The target device for computation. Weights remain on CPU until moved during computation.
* **name** (`Union` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `None` `]`) – Base name for weights (appended with `.weight` and `.bias` if applicable).
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool)) – When [`True`](https://docs.python.org/3/library/constants.html#True), adds a bias vector to the layer. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
* **permute** ([`bool`](https://docs.python.org/3/library/functions.html#bool))

### `bias` {#max.nn.conv.Conv1D.bias}

> bias: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None

The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.conv.Conv1D.device) if present.

### `device` {#max.nn.conv.Conv1D.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)

The device where matrix operations are performed.

### `dilation` {#max.nn.conv.Conv1D.dilation}

> dilation: [int](https://docs.python.org/3/library/functions.html#int)

Controls the dilation rate.

### `filter` {#max.nn.conv.Conv1D.filter}

> filter: [Weight](../graph/Weight.md#max.graph.Weight)

The weight matrix stored on CPU with shape (kernel\_size, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.conv.Conv1D.device).

### `num_groups` {#max.nn.conv.Conv1D.num_groups}

> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)

Number of blocked connections from input channels to output channels.

### `padding` {#max.nn.conv.Conv1D.padding}

> padding: [int](https://docs.python.org/3/library/functions.html#int)

Controls the amount of padding applied before and after the input.

### `permute` {#max.nn.conv.Conv1D.permute}

> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False

Controls whether `self.filter` is permuted from PyTorch order to MAX order. PyTorch order is: (out\_channels, in\_channels / num\_groups, kernel\_size). MAX API order: (kernel\_size, in\_channels / num\_groups, out\_channels).

### `stride` {#max.nn.conv.Conv1D.stride}

> stride: [int](https://docs.python.org/3/library/functions.html#int)

Controls the stride for the cross-correlation.
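When loading weights exported from PyTorch, the `permute` flag described above keeps the filter layout conversion inside the layer. A minimal, hedged sketch; the weight name and channel dimensions are illustrative assumptions:

```python
import max.nn as nn
from max.dtype import DType
from max.graph import DeviceRef

# The filter stays in PyTorch order (out_channels, in_channels, kernel_size)
# in the checkpoint; permute=True converts it to MAX order internally.
conv = nn.Conv1D(
    kernel_size=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    permute=True,
    name="torch_conv1d",
    device=DeviceRef.CPU(),
)
```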
## `Conv1DV1` {#max.nn.conv.Conv1DV1}

> class max.nn.conv.Conv1DV1(filter, bias=None, stride=1, padding=0, dilation=1, groups=1)

A 1D convolution over an input signal composed of several input planes.

Deprecated: Use Conv1D instead.

## Example

```python
conv = nn.Conv1DV1(
    filter=filter_1d,
    bias=bias_1d,
    stride=1,
    padding=1
)
```

**Parameters:**

* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **stride** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **padding** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **dilation** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **groups** ([`int`](https://docs.python.org/3/library/functions.html#int) )

### `bias` {#max.nn.conv.Conv1DV1.bias}

> bias: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `dilation` {#max.nn.conv.Conv1DV1.dilation}

> dilation: [int](https://docs.python.org/3/library/functions.html#int) = 1

### `filter` {#max.nn.conv.Conv1DV1.filter}

> filter: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `groups` {#max.nn.conv.Conv1DV1.groups}

> groups: [int](https://docs.python.org/3/library/functions.html#int) = 1

### `padding` {#max.nn.conv.Conv1DV1.padding}

> padding: [int](https://docs.python.org/3/library/functions.html#int) = 0
### `stride` {#max.nn.conv.Conv1DV1.stride}

> stride: [int](https://docs.python.org/3/library/functions.html#int) = 1

## `Conv2D` {#max.nn.conv.Conv2D}

> class max.nn.conv.Conv2D(kernel\_size, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None)

A 2D convolution over an input signal composed of several input planes.

## Example

```python
conv = nn.Conv2D(
    kernel_size=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=1,
    padding=0,
    has_bias=False,
    name="conv2d_weight",
    device=DeviceRef.GPU(),
)
```

Initializes the Conv2D layer with weights and optional bias.

**Parameters:**

* **kernel\_size** (`Union` `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `]` ) – Size of the convolving kernel. Can be a single int or tuple (height, width).
* **in\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of channels in the input image.
* **out\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of channels produced by the convolution.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type for both weights and bias.
* **stride** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Stride of the convolution. Default: 1
* **padding** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Padding added to input. Can be int or tuple (pad\_top, pad\_bottom, pad\_left, pad\_right). Default: 0
* **dilation** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Spacing between kernel elements. Default: 1
* **num\_groups** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of blocked connections from input channels to output channels. Default: 1
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `|` `None` ) – The target device for computation. Weights remain on CPU until moved during computation.
* **name** (`Union` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `None` `]` ) – Base name for weights (appended with `.weight` and `.bias` if applicable).
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), adds a bias vector to the layer. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
* **permute** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), permutes weights from PyTorch format to Max format. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
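The scalar values in the example above also have per-dimension tuple forms, as the parameter list documents. Here is a minimal sketch (the sizes are hypothetical, chosen only for illustration) of a Conv2D configured with tuple-valued parameters:

```python
import max.nn as nn
from max.dtype import DType
from max.graph import DeviceRef

# Hypothetical Conv2D with per-dimension settings:
# a 3x5 kernel, stride 2 along height only, and asymmetric padding.
conv = nn.Conv2D(
    kernel_size=(3, 5),      # (height, width)
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=(2, 1),           # (stride_height, stride_width)
    padding=(1, 1, 2, 2),    # (pad_top, pad_bottom, pad_left, pad_right)
    dilation=(1, 1),
    device=DeviceRef.GPU(),
)
```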
### `bias` {#max.nn.conv.Conv2D.bias}

> bias: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None

The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.conv.Conv2D.device) if present.

### `device` {#max.nn.conv.Conv2D.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)

The device where matrix operations are performed.

### `dilation` {#max.nn.conv.Conv2D.dilation}

> dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Controls the dilation rate.

### `filter` {#max.nn.conv.Conv2D.filter}

> filter: [Weight](../graph/Weight.md#max.graph.Weight)

The weight matrix stored on CPU with shape (height, width, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.conv.Conv2D.device).

### `num_groups` {#max.nn.conv.Conv2D.num_groups}

> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)

Number of blocked connections from input channels to output channels.

### `padding` {#max.nn.conv.Conv2D.padding}

> padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Controls the amount of padding applied before and after the input for height and width dimensions.

### `permute` {#max.nn.conv.Conv2D.permute}

> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False

Whether `self.filter` is permuted from PyTorch order to Max order. PyTorch order: (out\_channels, in\_channels / num\_groups, height, width). Max API order: (height, width, in\_channels / num\_groups, out\_channels).

### `shard()` {#max.nn.conv.Conv2D.shard}

> shard(shard\_idx, device)

Creates a sharded view of this Conv2D layer for a specific device.

**Parameters:**

* **shard\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index of the shard (0 to num\_devices-1).
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The device where this shard should reside.

**Returns:** A sharded Conv2D instance.

**Return type:** [*Conv2D*](#max.nn.conv.Conv2D)

### `sharding_strategy` {#max.nn.conv.Conv2D.sharding_strategy}

> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)

Get the Conv2D sharding strategy.

### `stride` {#max.nn.conv.Conv2D.stride}

> stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Controls the stride for the cross-correlation.

## `Conv2DV1` {#max.nn.conv.Conv2DV1}

> class max.nn.conv.Conv2DV1(filter, bias=None, stride=(1, 1), padding=(0, 0, 0, 0), dilation=(1, 1), groups=1)

A 2D convolution over an input signal composed of several input planes.

Deprecated: Use Conv2D instead.
## Example

```python
conv = nn.Conv2DV1(
    filter=filter_2d,
    bias=bias_2d,
    stride=2,
    padding=1
)
output = conv(x)
```

**Parameters:**

* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **stride** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **padding** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **dilation** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **groups** ([`int`](https://docs.python.org/3/library/functions.html#int) )

### `bias` {#max.nn.conv.Conv2DV1.bias}

> bias: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) | [None](https://docs.python.org/3/library/constants.html#None) = None

### `dilation` {#max.nn.conv.Conv2DV1.dilation}

> dilation: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (1, 1)

### `filter` {#max.nn.conv.Conv2DV1.filter}

> filter: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)
### `groups` {#max.nn.conv.Conv2DV1.groups}

> groups: [int](https://docs.python.org/3/library/functions.html#int) = 1

### `padding` {#max.nn.conv.Conv2DV1.padding}

> padding: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (0, 0, 0, 0)

### `stride` {#max.nn.conv.Conv2DV1.stride}

> stride: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (1, 1)

## `Conv3D` {#max.nn.conv.Conv3D}

> class max.nn.conv.Conv3D(depth, height, width, in\_channels, out\_channels, dtype, stride=1, padding=0, dilation=1, num\_groups=1, device=None, has\_bias=False, permute=False, name=None)

A 3D convolution over an input signal composed of several input planes.

## Example

```python
# Kernel and channel sizes below are illustrative
conv = nn.Conv3D(
    depth=3,
    height=3,
    width=3,
    in_channels=64,
    out_channels=128,
    dtype=DType.float32,
    stride=1,
    padding=0,
    has_bias=False,
    name="conv3d_weight",
    device=DeviceRef.GPU(),
)
```

Initializes the Conv3D layer with weights and optional bias.

**Parameters:**

* **depth** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – kernel\_size\[0]
* **height** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – kernel\_size\[1]
* **width** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – kernel\_size\[2]
* **in\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of channels in the input image.
* **out\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Dimensionality of the output.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type for both weights and bias.
* **stride** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Stride of the convolution. Default: 1
* **padding** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Padding added to all six sides of the input. Default: 0
* **dilation** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Spacing between kernel elements. Default: 1
* **num\_groups** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of blocked connections from input channels to output channels. Default: 1.
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `|` `None` ) – The target device for computation. Weights remain on CPU until moved during computation.
* **name** (`Union` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `None` `]` ) – Base name for weights (appended with `.weight` and `.bias` if applicable).
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), adds a bias vector to the layer. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
* **permute** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), permutes weights from PyTorch format to Max format. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).

### `bias` {#max.nn.conv.Conv3D.bias}

> bias: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None

The optional bias vector stored on CPU with shape (out\_channels,). Model init moves the bias to [`device`](#max.nn.conv.Conv3D.device) if present.

### `device` {#max.nn.conv.Conv3D.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef) | [None](https://docs.python.org/3/library/constants.html#None)

The device where matrix operations are performed.

### `dilation` {#max.nn.conv.Conv3D.dilation}

> dilation: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Not implemented yet. Assuming dilation = 1 for now.

### `filter` {#max.nn.conv.Conv3D.filter}

> filter: [Weight](../graph/Weight.md#max.graph.Weight)

The weight matrix stored on CPU with shape (depth, height, width, in\_channels / num\_groups, out\_channels). Model init moves the weight to [`device`](#max.nn.conv.Conv3D.device).

### `num_groups` {#max.nn.conv.Conv3D.num_groups}

> num\_groups: [int](https://docs.python.org/3/library/functions.html#int)

Not implemented yet. Assuming num\_groups = 1 for now.

### `padding` {#max.nn.conv.Conv3D.padding}

> padding: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Controls the amount of padding applied before and after the input for depth, height, and width dimensions.

### `permute` {#max.nn.conv.Conv3D.permute}

> permute: [bool](https://docs.python.org/3/library/functions.html#bool) = False

Whether `self.filter` is permuted from PyTorch order to Max order. PyTorch order: (out\_channels, in\_channels / num\_groups, depth, height, width). Max API order: (depth, height, width, in\_channels / num\_groups, out\_channels).
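Because dilation and grouping are currently fixed at 1, output sizing follows the standard convolution arithmetic applied independently to the depth, height, and width dimensions. The helper below is a small illustrative sketch of that formula (it is not part of the MAX API):

```python
def conv_out_extent(size: int, kernel: int, stride: int = 1,
                    pad_before: int = 0, pad_after: int = 0,
                    dilation: int = 1) -> int:
    """Standard cross-correlation output size for one spatial dimension."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + pad_before + pad_after - effective_kernel) // stride + 1

# Example: a 3x3x3 kernel over a 16x32x32 volume, padding 1 on every side
dims = [conv_out_extent(s, 3, stride=1, pad_before=1, pad_after=1)
        for s in (16, 32, 32)]
print(dims)  # [16, 32, 32]
```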
### `stride` {#max.nn.conv.Conv3D.stride}

> stride: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)]

Controls the stride for the cross-correlation.

## `Conv3DV1` {#max.nn.conv.Conv3DV1}

> class max.nn.conv.Conv3DV1(filter, bias=None, stride=(1, 1, 1), padding=(0, 0, 0, 0, 0, 0), dilation=(1, 1, 1), groups=1)

A 3D convolution over an input signal composed of several input planes.

Deprecated: Use Conv3D instead.

## Example

```python
conv = nn.Conv3DV1(
    filter=filter_3d,
    bias=bias_3d,
    stride=1,
    padding=1
)
```

**Parameters:**

* **filter** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **stride** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **padding** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **dilation** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **groups** ([`int`](https://docs.python.org/3/library/functions.html#int) )

### `bias` {#max.nn.conv.Conv3DV1.bias}

> bias: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `dilation` {#max.nn.conv.Conv3DV1.dilation}

> dilation: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (1, 1, 1)

### `filter` {#max.nn.conv.Conv3DV1.filter}

> filter: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `groups` {#max.nn.conv.Conv3DV1.groups}

> groups: [int](https://docs.python.org/3/library/functions.html#int) = 1

### `padding` {#max.nn.conv.Conv3DV1.padding}

> padding: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (0, 0, 0, 0, 0, 0)

### `stride` {#max.nn.conv.Conv3DV1.stride}

> stride: [int](https://docs.python.org/3/library/functions.html#int) | [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int), [int](https://docs.python.org/3/library/functions.html#int)] = (1, 1, 1)

---

## embedding

The `embedding` module provides classes for mapping integer indices (like token IDs) to dense vector representations. These embedding operations are fundamental building blocks for natural language processing, recommendation systems, and other tasks involving discrete tokens.
* `Embedding`: Embedding lookup table that stores weights on CPU and moves them to the target device at model init
* `EmbeddingV1`: Deprecated predecessor of `Embedding`
* `VocabParallelEmbedding`: Distributed embedding that shards the vocabulary across multiple devices for large embedding tables

Here’s an example demonstrating how to use embeddings:

```python
import max.nn as nn
from max.graph import Graph, ops, DeviceRef
from max.dtype import DType
import numpy as np

with Graph(name="embedding_example") as graph:
    # Define dimensions
    batch_size = 4
    seq_length = 16
    vocab_size = 10000
    hidden_dim = 256

    # Create input tensor of token indices
    input_data = np.random.randint(0, vocab_size, (batch_size, seq_length), dtype=np.int32)
    input_indices = ops.constant(input_data, dtype=DType.int32, device=DeviceRef.CPU())

    # Create embedding layer
    embedding = nn.Embedding(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        dtype=DType.float32,
        device=DeviceRef.GPU(),
        name="token_embeddings"
    )

    # Look up embeddings for input indices
    embeddings = embedding(input_indices)
    print(f"Embedding output shape: {embeddings.shape}")
    # Embedding output shape: [Dim(4), Dim(16), Dim(256)]
```

## `Embedding` {#max.nn.embedding.Embedding}

> class max.nn.embedding.Embedding(vocab\_size, hidden\_dim, dtype, device, quantization\_encoding=None, name=None)

A lookup table for embedding integer indices into dense vectors.

This layer maps each integer index to a dense vector of fixed size. Embedding weights are stored on the CPU but are moved to the specified device during the model init phase.

Example:

```python
embedding_layer = Embedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="embeddings",
)
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
```

Initializes the embedding layer with the given arguments.

**Parameters:**

* **vocab\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of unique items in the vocabulary. Indices must be in the range `[0, vocab_size)`.
* **hidden\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of each embedding vector.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type of the embedding weights.
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The device where embedding lookups are executed. Model init transfers the initially CPU-resident weights to this device.
* **name** (`Optional` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) – The name identifier for the embedding weight matrix.
* **quantization\_encoding** (`Optional` `[` [`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `]` )

### `device` {#max.nn.embedding.Embedding.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef)

The device on which embedding lookup is performed.

### `weight` {#max.nn.embedding.Embedding.weight}

> weight: [Weight](../graph/Weight.md#max.graph.Weight)

The embedding weight matrix stored on the CPU. Model init moves weights to the device specified in [`device`](#max.nn.embedding.Embedding.device).

## `EmbeddingV1` {#max.nn.embedding.EmbeddingV1}

> class max.nn.embedding.EmbeddingV1(weights, device)

A lookup table for embedding integer indices into dense vectors.

Deprecated: Use Embedding instead.
**Parameters:**

* **weights** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )

### `device` {#max.nn.embedding.EmbeddingV1.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef)

### `weights` {#max.nn.embedding.EmbeddingV1.weights}

> weights: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

## `VocabParallelEmbedding` {#max.nn.embedding.VocabParallelEmbedding}

> class max.nn.embedding.VocabParallelEmbedding(vocab\_size, hidden\_dim, dtype, devices, quantization\_encoding=None, name=None)

A lookup table for embedding integer indices into dense vectors.

This layer works like nn.Embedding except the embedding table is sharded on the vocabulary dimension across all devices.

Example:

```python
embedding_layer = VocabParallelEmbedding(
    vocab_size=1000,
    hidden_dim=256,
    dtype=DType.float32,
    devices=[DeviceRef.GPU(0), DeviceRef.GPU(1)],
    name="embeddings",
)
# Token indices of shape: [batch, ..., num_indices].
token_indices: TensorValueLike
embeddings = embedding_layer(token_indices)
```

**Parameters:**

* **vocab\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of unique items in the vocabulary. Indices must be in the range `[0, vocab_size)`.
* **hidden\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of each embedding vector.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type of the embedding weights.
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `]` ) – The devices where embedding lookups are executed. Model init transfers the initially CPU-resident weights to these devices.
* **name** (`Optional` `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) – The name identifier for the embedding weight matrix.
* **quantization\_encoding** (`Optional` `[` [`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `]` )

---

## nn

APIs to build neural network components for deep learning models with Python.
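For example, you can compose layers from these modules into a small pipeline. The following is a minimal sketch built from the `Embedding` and `Conv1D` signatures documented above; the specific composition and shapes are illustrative:

```python
import numpy as np
import max.nn as nn
from max.dtype import DType
from max.graph import Graph, DeviceRef, ops

with Graph(name="nn_example") as graph:
    # Embed a batch of token ids, then mix across the sequence with a
    # 1D convolution. Shapes: tokens [2, 8] -> embeddings [2, 8, 64].
    tokens = ops.constant(
        np.zeros((2, 8), dtype=np.int32), dtype=DType.int32, device=DeviceRef.CPU()
    )
    embedding = nn.Embedding(
        vocab_size=1000, hidden_dim=64, dtype=DType.float32, device=DeviceRef.CPU()
    )
    conv = nn.Conv1D(
        kernel_size=3, in_channels=64, out_channels=64,
        dtype=DType.float32, padding=1, device=DeviceRef.CPU(),
    )
    hidden = conv(embedding(tokens))  # [2, 8, 64]
```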
## Modules

* [`conv`](/max/api/python/nn/conv)
* [`embedding`](/max/api/python/nn/embedding)
* [`kernels`](/max/api/python/nn/kernels)
* [`layer`](/max/api/python/nn/layer)
* [`linear`](/max/api/python/nn/linear)
* [`rotary_embedding`](/max/api/python/nn/rotary_embedding)
* [`sequential`](/max/api/python/nn/sequential)

## Packages

* [`attention`](/max/api/python/nn/attention)
* [`norm`](/max/api/python/nn/norm)
* [`transformer`](/max/api/python/nn/transformer)
* [`kv_cache`](/max/api/python/nn/kv_cache)

---

## kernels

Helper functions for wrapping custom kv cache/attention related ops.

## `apply_penalties_to_logits()` {#max.nn.kernels.apply_penalties_to_logits}

> max.nn.kernels.apply\_penalties\_to\_logits(logits\_buffer, frequency\_data, frequency\_offsets, \*, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0)

Applies penalties to the logits.

**Parameters:**

* **logits\_buffer** ([`BufferValue`](../graph/BufferValue.md#max.graph.BufferValue) ) – The buffer to apply penalties to.
* **frequency\_data** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – 2d tensor of shape \[unique\_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token.
* **frequency\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – 1d tensor of shape \[batch\_size + 1], indicating the start of each sequence’s data.
* **frequency\_penalty** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances.
* **presence\_penalty** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty.
* **repetition\_penalty** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the prompt and generated text at least once by dividing the logits by the repetition penalty.

**Return type:** None

## `cross_attention_ragged()` {#max.nn.kernels.cross_attention_ragged}

> max.nn.kernels.cross\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, kv\_input\_row\_offsets, q\_max\_seq\_len, scale, local\_window\_size=-1)

Computes cross attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel.

input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input. For cross attention, kv\_input\_row\_offsets represents the KV sequence length.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **mask\_variant** (`MHAMaskVariant` )
* **kv\_input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **q\_max\_seq\_len** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **local\_window\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `dynamic_scaled_matmul()` {#max.nn.kernels.dynamic_scaled_matmul}

> max.nn.kernels.dynamic\_scaled\_matmul(a, b, a\_scales, b\_scales, out\_type=bfloat16)

Performs a matmul of two tensors with scaling factors. Currently only supports channel-wise scaling for weights and per-token scaling for inputs.

**Parameters:**

* **a** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The first tensor to multiply.
* **b** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The second tensor to multiply, must be transposed.
* **a\_scales** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The scaling factors for the first tensor.
* **b\_scales** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The scaling factors for the second tensor.
* **out\_type** ([`DType`](../dtype.md#max.dtype.DType) )

**Returns:** The result of the matmul operation.
**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flare_mla_decode_ragged()` {#max.nn.kernels.flare_mla_decode_ragged}

> max.nn.kernels.flare\_mla\_decode\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64)

Computes flash (self) attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel.

input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.

Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross\_attention\_ragged.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** (`PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **mask\_variant** (`MHAMaskVariant` )
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **qk\_rope\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flare_mla_decompress_k_cache()` {#max.nn.kernels.flare_mla_decompress_k_cache}

> max.nn.kernels.flare\_mla\_decompress\_k\_cache(kv\_params, buffer\_row\_offsets\_1d, cache\_offsets\_1d, buffer\_length, weight, kv\_collection, layer\_idx, buffer\_size)

This kernel decompresses the key cache by up-projecting latent representations into the KV space using a weight matrix.

The process involves:

1. Copying buffer\_length latent vectors from the key cache into a contiguous buffer (k\_latent)
2. Computing k = k\_latent @ weight.T to obtain the decompressed keys

**Returns:** A tensor of shape \[buffer\_size, weight.shape\[0]] containing the decompressed keys. Note that only the first buffer\_length tokens are valid.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **buffer\_row\_offsets\_1d** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **cache\_offsets\_1d** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **buffer\_length** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** (`PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **buffer\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flare_mla_prefill_plan()` {#max.nn.kernels.flare_mla_prefill_plan}

> max.nn.kernels.flare\_mla\_prefill\_plan(kv\_params, input\_row\_offsets, kv\_collection, layer\_idx, buffer\_size, max\_chunks=16)

This kernel plans how to process a batch of sequences with varying lengths using a fixed-size buffer.

Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer\_size. For each chunk (iteration), it calculates:
1. Buffer offsets for each sequence in each chunk
2. Cache offsets for each sequence in each chunk
3. Total buffer lengths for each processing iteration

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** (`PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **buffer\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **max\_chunks** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue), [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue), [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)]

## `flare_mla_prefill_ragged()` {#max.nn.kernels.flare_mla_prefill_ragged}

> max.nn.kernels.flare\_mla\_prefill\_ragged(kv\_params, input, k, v, input\_row\_offsets, buffer\_row\_offsets, cache\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, qk\_rope\_dim=64, prev\_output=None, prev\_softmax\_info=None)

Performs MLA prefill. In the MLA prefill, we need to decompress the KV tensors, as we store the latent representations in the KV cache. We will decompress the KV tensors into a fixed size buffer to avoid out-of-memory errors. In case the total cache length is greater than the buffer size, we will process the attention calculation in chunks.

This MLA prefill kernel will return the output tensor for this iteration and the softmax info tensor for this iteration. Such tensors will be used by the next iteration of the MLA prefill kernel to continue the attention calculation.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) – KVCacheParams
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Input tensor
* **k** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Key tensor
* **v** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Value tensor
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Indicates where each batch starts and ends in input
* **buffer\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Indicates where each batch starts and ends in the buffer
* **cache\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Indicates where each batch starts and ends in the KV cache
* **kv\_collection** (`PagedKVCacheCollection` ) – KV collection
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Layer index tensor
* **mask\_variant** (`MHAMaskVariant` ) – Mask variant
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – Scale
* **qk\_rope\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – QK rope dimension
* **prev\_output** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) – Optional. Previous output tensor
* **prev\_softmax\_info** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) – Optional. Previous softmax info tensor
**Returns:**

* The first tensor is the output tensor for this iteration
* The second tensor is the softmax info tensor for this iteration

**Return type:** A tuple of two tensors

## `flash_attention()` {#max.nn.kernels.flash_attention}

> max.nn.kernels.flash\_attention(kv\_params, input, kv\_collection, layer\_idx, attention\_mask, valid\_lengths, scale)

Computes flash attention provided the mo.opaque KV Cache.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **attention\_mask** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **valid\_lengths** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flash_attention_gpu()` {#max.nn.kernels.flash_attention_gpu}

> max.nn.kernels.flash\_attention\_gpu(q, k, v, mask\_variant, scale, local\_window\_size=-1, valid\_length=None)

Computes flash attention using GPU-optimized kernel.

**Parameters:**

* **q** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Query tensor of shape \[batch, seq\_len, num\_heads, head\_dim]
* **k** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Key tensor of shape \[batch, seq\_len, num\_heads, head\_dim]
* **v** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Value tensor of shape \[batch, seq\_len, num\_heads, head\_dim]
* **mask\_variant** (`MHAMaskVariant` ) – The mask variant to use for attention
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – Scaling factor for attention scores
* **local\_window\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Local window size for sliding window attention
* **valid\_length** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) – Optional tensor of shape \[batch] with dtype uint32. When provided, uses the padded kernel variant that respects the valid sequence lengths for each batch element.

**Returns:** Output tensor of shape \[batch, seq\_len, num\_heads, head\_dim]

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flash_attention_ragged()` {#max.nn.kernels.flash_attention_ragged}

> max.nn.kernels.flash\_attention\_ragged(kv\_params, input, input\_row\_offsets, kv\_collection, layer\_idx, mask\_variant, scale, local\_window\_size=-1)

Computes flash (self) attention provided the !mo.opaque KV Cache.

Notably, this materializes the attention mask (dependent on MHAMaskVariant) within the kernel.

input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.

Note that this is self attention and the KV sequence length is assumed to be equal to the Q sequence length. For KV sequence length != Q sequence length, use cross\_attention\_ragged.
**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **mask\_variant** (`MHAMaskVariant` )
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **local\_window\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `flash_attention_with_causal_mask()` {#max.nn.kernels.flash_attention_with_causal_mask}

> max.nn.kernels.flash\_attention\_with\_causal\_mask(kv\_params, input, kv\_collection, layer\_idx, valid\_lengths, scale)

Computes flash attention provided the mo.opaque KV Cache. Notably, materializes the causal mask within the kernel.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **valid\_lengths** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **scale** ([`float`](https://docs.python.org/3/library/functions.html#float) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `fused_qk_ragged_rope()` {#max.nn.kernels.fused_qk_ragged_rope}

> max.nn.kernels.fused\_qk\_ragged\_rope(kv\_params, input, input\_row\_offsets, kv\_collection, freqs\_cis, layer\_idx, interleaved=True)

Computes fused query-key attention with rotary positional encodings and ragged inputs.

**Parameters:**

* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – \[batch\_size \* seq\_len, n\_heads, head\_dim]
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **freqs\_cis** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – tensor of shape (max\_seq\_len \* 2, head\_dim)
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **interleaved** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.

## `fused_qk_rope()` {#max.nn.kernels.fused_qk_rope}

> max.nn.kernels.fused\_qk\_rope(kv\_params, input, kv\_collection, freqs\_cis\_2d, layer\_idx, interleaved=True)

Computes fused query-key attention with rotary positional encodings.
**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) )
* **freqs\_cis\_2d** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **interleaved** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `fused_qkv_matmul()` {#max.nn.kernels.fused_qkv_matmul}

> max.nn.kernels.fused\_qkv\_matmul(kv\_params, input, wqkv, kv\_collection, layer\_idx, n\_heads)

Computes fused query, key and value projections.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **wqkv** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `fused_qkv_ragged_matmul()` {#max.nn.kernels.fused_qkv_ragged_matmul}

> max.nn.kernels.fused\_qkv\_ragged\_matmul(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, bias=None)

Computes fused query, key, and value projections with ragged input.

input and input\_row\_offsets are used together to implement the ragged tensor. input\_row\_offsets indicates where each batch starts and ends in input.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.

**Parameters:**

* **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **wqkv** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` )
* **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **bias** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` )

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

## `fused_qkv_ragged_matmul_quantized()` {#max.nn.kernels.fused_qkv_ragged_matmul_quantized}

> max.nn.kernels.fused\_qkv\_ragged\_matmul\_quantized(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, quantization\_config, perm\_idx=None, bias=None)

Computes fused query, key, and value projections with ragged input and quantized weight matrices. A quantization\_config must be provided.
input and input\_row\_offsets are used together to implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input. **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel. **Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **wqkv** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **quantization\_config** ([`QuantizationConfig`](../graph/quantization.md#max.graph.quantization.QuantizationConfig) ) * **perm\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) * **bias** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `fused_qkv_ragged_matmul_scaled_float8()` {#max.nn.kernels.fused_qkv_ragged_matmul_scaled_float8} > max.nn.kernels.fused\_qkv\_ragged\_matmul\_scaled\_float8(kv\_params, input, input\_row\_offsets, wqkv, kv\_collection, layer\_idx, n\_heads, input\_scale, weight\_scale, bias=None) Computes fused query, key, and value projections with ragged input. input and input\_row\_offsets are used together to implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input. **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel. **Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **wqkv** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **kv\_collection** (`PagedKVCacheCollection` ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **input\_scale** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight\_scale** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **bias** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `grouped_matmul_ragged()` {#max.nn.kernels.grouped_matmul_ragged} > max.nn.kernels.grouped\_matmul\_ragged(hidden\_states, weight, expert\_start\_indices, expert\_ids, expert\_usage\_stats\_host) Grouped matmul used in the MoE layer. hidden\_states and expert\_start\_indices are used together to implement the ragged tensor:
expert\_start\_indices indicates where each group starts and ends in hidden\_states. expert\_ids is the id of the expert for each group in hidden\_states. expert\_usage\_stats\_host holds the maximum number of tokens assigned to any expert, and the number of active experts. **Parameters:** * **hidden\_states** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **expert\_start\_indices** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **expert\_ids** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **expert\_usage\_stats\_host** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `kv_cache_get_max_seq_len()` {#max.nn.kernels.kv_cache_get_max_seq_len} > max.nn.kernels.kv\_cache\_get\_max\_seq\_len(kv\_collection) This kernel returns the maximum sequence length. **Parameters:** **kv\_collection** (`PagedKVCacheCollection` ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `matmul_k_cache_ragged()` {#max.nn.kernels.matmul_k_cache_ragged} > max.nn.kernels.matmul\_k\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx) Computes key projections with ragged input. hidden\_states and input\_row\_offsets are used together to implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in hidden\_states. **Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **hidden\_states** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **kv\_collection** (`PagedKVCacheCollection` ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) **Return type:** None
## `matmul_kv_cache_ragged()` {#max.nn.kernels.matmul_kv_cache_ragged} > max.nn.kernels.matmul\_kv\_cache\_ragged(kv\_params, hidden\_states, input\_row\_offsets, weight, kv\_collection, layer\_idx) Computes key and value projections with ragged input. hidden\_states and input\_row\_offsets are used together to implement the ragged tensor:
input\_row\_offsets indicates where each batch starts and ends in hidden\_states. **Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **hidden\_states** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **kv\_collection** (`PagedKVCacheCollection` ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) **Return type:** None
## `matmul_static_scaled_float8()` {#max.nn.kernels.matmul_static_scaled_float8} > max.nn.kernels.matmul\_static\_scaled\_float8(input, weight, input\_scale, weight\_scale) **Parameters:** * **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_scale** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight\_scale** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `merge_ragged_tensors()` {#max.nn.kernels.merge_ragged_tensors} > max.nn.kernels.merge\_ragged\_tensors(a, a\_row\_offsets, b, b\_row\_offsets) Merges two ragged tensors into a single ragged tensor. Both ragged tensors must have the same batch size (same number of row offsets). This function interleaves the rows from each tensor based on their row offsets. **Parameters:** * **a** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The first ragged tensor of shape \[total\_a\_rows, …]. * **a\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The row offsets of the first ragged tensor, indicating where each batch starts and ends in a. * **b** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The second ragged tensor of shape \[total\_b\_rows, …]. * **b\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The row offsets of the second ragged tensor, indicating where each batch starts and ends in b. **Returns:** * The merged ragged tensor with shape \[total\_a\_rows + total\_b\_rows, …]. * The merged row offsets with the same shape as input row offsets. **Return type:** A tuple of two tensors
## Example

```python
a = [1, 2, 3, 4, 5, 6]
a_row_offsets = [0, 2, 6]
b = [7, 8, 9, 10]
b_row_offsets = [0, 3, 4]

merged_tensor, merged_row_offsets = merge_ragged_tensors(
    a, a_row_offsets, b, b_row_offsets
)
# merged_tensor = [1, 2, 7, 8, 9, 3, 4, 5, 6, 10]
# merged_row_offsets = [0, 5, 10]
```

## `moe_create_indices()` {#max.nn.kernels.moe_create_indices} > max.nn.kernels.moe\_create\_indices(topk\_ids, num\_local\_experts) Creates indices for the MoE layer. **Parameters:** * **topk\_ids** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The expert assignments for each token from the router. * **num\_local\_experts** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of experts on this device. **Returns:** * token\_expert\_order: The reordered token indices, grouped by assigned expert. * expert\_start\_indices: The starting index for each expert’s token group in the reordered sequence. * restore\_token\_order: The indices to restore original token ordering after expert computation.
* expert\_ids: The ids of the active experts selected for tokens. * expert\_usage\_stats: The maximum number of tokens assigned to any expert, and the number of active experts. **Return type:** A tuple of five tensors
## `quantize_dynamic_scaled_float8()` {#max.nn.kernels.quantize_dynamic_scaled_float8} > max.nn.kernels.quantize\_dynamic\_scaled\_float8(input, scale\_ub=1200.0, group\_size\_or\_per\_token=-1, out\_type=float8\_e4m3fn, scales\_type=bfloat16) Dynamically quantizes the input tensor to fp8. **Parameters:** * **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The input tensor to quantize. * **scale\_ub** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – The upper bound of the scale factor. * **group\_size\_or\_per\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The group size for quantization. When set to -1, the quantization is column-wise. * **out\_type** ([`DType`](../dtype.md#max.dtype.DType) ) – The type of the output tensor. * **scales\_type** ([`DType`](../dtype.md#max.dtype.DType) ) – The type of the scales tensor. **Returns:** The quantized tensor and the scales. **Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue), [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)]
## `quantize_static_scaled_float8()` {#max.nn.kernels.quantize_static_scaled_float8} > max.nn.kernels.quantize\_static\_scaled\_float8(x, scale, scale\_is\_inverted=True) **Parameters:** * **x** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **scale** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **scale\_is\_inverted** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `rms_norm_key_cache()` {#max.nn.kernels.rms_norm_key_cache} > max.nn.kernels.rms\_norm\_key\_cache(kv\_params, kv\_collection, gamma, epsilon, layer\_idx, total\_seq\_len, input\_row\_offsets, weight\_offset, rms\_norm\_cols=None, multiply\_before\_cast=True, per\_head\_norm=True) This function applies RMSNorm to the *new* entries in the KVCache. When per\_head\_norm=True (default), RMSNorm is applied separately to each head. In this mode, gamma should have size \[head\_dim] and normalization occurs across the head\_dim dimensions within each head. When per\_head\_norm=False, RMSNorm is applied per token across all heads. In this mode, gamma should have size \[n\_kv\_heads \* head\_dim] and normalization occurs across all dimensions for each token. The size of the gamma tensor determines how many dimensions will be normalized. If gamma’s size doesn’t match the expected size based on the per\_head\_norm setting, rms\_norm\_cols must be explicitly specified to confirm the intention to normalize only a subset of dimensions. Currently, the KVCacheT class itself isn’t aware of new cache entries until the cache length is incremented, which happens after the model’s forward pass, so input\_row\_offsets is used to do this bookkeeping.
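As a quick check of the sizing rules above, the gamma size the kernel expects can be computed directly. The helper below is hypothetical and purely illustrative:

```python
def expected_gamma_size(n_kv_heads: int, head_dim: int, per_head_norm: bool) -> int:
    # per_head_norm=True normalizes within each head, so gamma spans head_dim;
    # per_head_norm=False normalizes across all heads for each token.
    return head_dim if per_head_norm else n_kv_heads * head_dim

print(expected_gamma_size(8, 128, per_head_norm=True))   # 128
print(expected_gamma_size(8, 128, per_head_norm=False))  # 1024
```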
**Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` ) * **gamma** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **epsilon** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **total\_seq\_len** ([`Dim`](../graph/type.md#max.graph.type.Dim) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **weight\_offset** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) ) * **rms\_norm\_cols** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **multiply\_before\_cast** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **per\_head\_norm** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) **Return type:** None
## `scatter_set_constant()` {#max.nn.kernels.scatter_set_constant} > max.nn.kernels.scatter\_set\_constant(data, indices, fill\_val) Scatters values into a tensor at specified indices. **Parameters:** * **data** ([`BufferValue`](../graph/BufferValue.md#max.graph.BufferValue) ) * **indices** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **fill\_val** ([`float`](https://docs.python.org/3/library/functions.html#float) ) **Return type:** None
## `swish_glu()` {#max.nn.kernels.swish_glu} > max.nn.kernels.swish\_glu(a, b0, b1) Computes swish(a @ b0.t()) \* (a @ b1.t()) **Parameters:** * **a** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **b0** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **b1** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|`
[`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue) ## `topk_fused_sampling()` {#max.nn.kernels.topk_fused_sampling} > max.nn.kernels.topk\_fused\_sampling(logits, top\_k, \*, temperature=1.0, max\_k=None, top\_p=1.0, seed=0) Performs top-k sampling with temperature scaling. **Parameters:** * **logits** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – Input logits tensor of shape \[batch\_size, vocab\_size]. * **top\_k** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Number of top tokens to consider for sampling. Can be a scalar (which will be expanded to batch\_size) or a tensor of shape \[batch\_size]. * **temperature** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Temperature for scaling logits before sampling. * **max\_k** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) – Maximum value of k across the batch. Required when top\_k is a tensor. 
* **top\_p** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Top-p (nucleus) sampling threshold. Can be a scalar or tensor. * **seed** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Seed for the random number generator. Can be a scalar or tensor. **Returns:** Sampled tokens tensor of shape \[batch\_size, 1]. **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If input validation fails. **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)
## `unfused_qkv_ragged_matmul_gguf_quantized()` {#max.nn.kernels.unfused_qkv_ragged_matmul_gguf_quantized} > max.nn.kernels.unfused\_qkv\_ragged\_matmul\_gguf\_quantized(kv\_params, input, input\_row\_offsets, n\_heads, q\_weight, k\_weight, v\_weight, quantization\_encoding\_q, quantization\_encoding\_k, quantization\_encoding\_v, kv\_collection, layer\_idx) Computes unfused query, key, and value projections with ragged input and GGUF-quantized weight matrices. A quantization encoding must be provided for each of the q, k, and v weights. input and input\_row\_offsets are used together to implement the ragged tensor: input\_row\_offsets indicates where each batch starts and ends in input. **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – on input shapes/dtypes that are invalid for the kernel.
**Parameters:** * **kv\_params** ([`KVCacheParams`](kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **input** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **input\_row\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **q\_weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **k\_weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **v\_weight** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) * **quantization\_encoding\_q** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) ) * **quantization\_encoding\_k** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) ) * **quantization\_encoding\_v** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) ) * **kv\_collection** ([`ContinuousBatchingKVCacheCollection`](kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection) `|` `PagedKVCacheCollection` ) * **layer\_idx** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) **Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue) ## `update_frequency_data()` {#max.nn.kernels.update_frequency_data} > max.nn.kernels.update\_frequency\_data(frequency\_data, frequency\_offsets, tokens) Updates the frequency data. **Parameters:** * **frequency\_data** ([`BufferValue`](../graph/BufferValue.md#max.graph.BufferValue) ) – 2d tensor of shape \[unique\_tokens, 2], where the first column indicates the token id and the second column indicates the frequency of the token. * **frequency\_offsets** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – 1d tensor of shape \[batch\_size + 1], indicating start of each sequence’s data. * **tokens** ([`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) ) – The tokens to update the frequency data with. 
**Return type:** None --- ## cache_params
## `KVCacheParams` {#max.nn.kv_cache.cache_params.KVCacheParams} > class max.nn.kv\_cache.cache\_params.KVCacheParams(dtype: max.\_core.dtype.DType, n\_kv\_heads: int, head\_dim: int, enable\_prefix\_caching: bool = False, enable\_kvcache\_swapping\_to\_host: bool = False, host\_kvcache\_swap\_space\_gb: Optional\[float] = None, cache\_strategy: max.nn.kv\_cache.cache\_params.KVCacheStrategy = KVCacheStrategy.CONTINUOUS, page\_size: Optional\[int] = None, n\_devices: int = 1) **Parameters:** * **dtype** ([`DType`](../../dtype.md#max.dtype.DType) ) * **n\_kv\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **head\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **enable\_prefix\_caching** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **enable\_kvcache\_swapping\_to\_host** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **host\_kvcache\_swap\_space\_gb** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None` ) * **cache\_strategy** ([`KVCacheStrategy`](#max.nn.kv_cache.cache_params.KVCacheStrategy) ) * **page\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **n\_devices** ([`int`](https://docs.python.org/3/library/functions.html#int) )
### `cache_strategy` {#max.nn.kv_cache.cache_params.KVCacheParams.cache_strategy} > cache\_strategy: [KVCacheStrategy](#max.nn.kv_cache.cache_params.KVCacheStrategy) = 'continuous'
### `dtype` {#max.nn.kv_cache.cache_params.KVCacheParams.dtype} > dtype: [DType](../../dtype.md#max.dtype.DType)
### `dtype_shorthand` {#max.nn.kv_cache.cache_params.KVCacheParams.dtype_shorthand} > property dtype\_shorthand: [str](https://docs.python.org/3/library/stdtypes.html#str) The textual representation in shorthand of the dtype.
### `enable_kvcache_swapping_to_host` {#max.nn.kv_cache.cache_params.KVCacheParams.enable_kvcache_swapping_to_host} > enable\_kvcache\_swapping\_to\_host: [bool](https://docs.python.org/3/library/functions.html#bool) = False
### `enable_prefix_caching` {#max.nn.kv_cache.cache_params.KVCacheParams.enable_prefix_caching} > enable\_prefix\_caching: [bool](https://docs.python.org/3/library/functions.html#bool) = False
### `head_dim` {#max.nn.kv_cache.cache_params.KVCacheParams.head_dim} > head\_dim: [int](https://docs.python.org/3/library/functions.html#int)
### `host_kvcache_swap_space_gb` {#max.nn.kv_cache.cache_params.KVCacheParams.host_kvcache_swap_space_gb} > host\_kvcache\_swap\_space\_gb: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `n_devices` {#max.nn.kv_cache.cache_params.KVCacheParams.n_devices} > n\_devices: [int](https://docs.python.org/3/library/functions.html#int) = 1
### `n_kv_heads` {#max.nn.kv_cache.cache_params.KVCacheParams.n_kv_heads} > n\_kv\_heads: [int](https://docs.python.org/3/library/functions.html#int)
### `page_size` {#max.nn.kv_cache.cache_params.KVCacheParams.page_size} > page\_size: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None
### `static_cache_shape` {#max.nn.kv_cache.cache_params.KVCacheParams.static_cache_shape} > property static\_cache\_shape: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)]
## `KVCacheStrategy` {#max.nn.kv_cache.cache_params.KVCacheStrategy} > class max.nn.kv\_cache.cache\_params.KVCacheStrategy(value, names=&lt;not given&gt;, \*values, module=None, qualname=None, type=None, start=1, boundary=None)
### `CONTINUOUS` {#max.nn.kv_cache.cache_params.KVCacheStrategy.CONTINUOUS} > CONTINUOUS = 'continuous'
### `MODEL_DEFAULT` {#max.nn.kv_cache.cache_params.KVCacheStrategy.MODEL_DEFAULT} > MODEL\_DEFAULT = 'model_default'
### `PAGED` {#max.nn.kv_cache.cache_params.KVCacheStrategy.PAGED} > PAGED = 'paged'
### `kernel_substring()` {#max.nn.kv_cache.cache_params.KVCacheStrategy.kernel_substring} > kernel\_substring() Returns the common substring that we include in the kernel name for this caching strategy. **Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)
### `uses_opaque()` {#max.nn.kv_cache.cache_params.KVCacheStrategy.uses_opaque} > uses\_opaque() **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) --- ## continuous_batching_cache Continuous-batching-enabled KV cache for the Transformer, leveraging the mo.opaque pattern.
## `ContinuousBatchingKVCache` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCache} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCache(value) Continuous Mojo KV cache graph value. Value is abstract; it shouldn’t be constructed directly.
**Parameters:** **value** ([`Value`](../../graph/Value.md#max.graph.Value) `|` `\_Value` `[` `mo.OpaqueType` `]` )
## `ContinuousBatchingKVCacheCollection` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollection} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCacheCollection(value) The graph value for a view of the KV cache. Value is abstract; it shouldn’t be constructed directly. **Parameters:** **value** ([`Value`](../../graph/Value.md#max.graph.Value) `|` `\_Value` `[` `mo.OpaqueType` `]` )
## `ContinuousBatchingKVCacheCollectionType` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheCollectionType} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCacheCollectionType The graph type for a “view” of the cache for the given sequences in the batch. This object does not own the underlying buffers in k\_cache and v\_cache; it borrows them from the BlockWrappers in our ContinuousKVCacheManager. It does own the Pointer\[NDBuffer\[type, 3]] and valid\_lengths buffer. Creates an opaque type containing a continuous batching KV cache collection.
## `ContinuousBatchingKVCacheInputSymbols` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCacheInputSymbols(kv\_blocks: 'TensorType', cache\_lengths: 'TensorType', lookup\_table: 'TensorType', max\_lengths: 'TensorType') **Parameters:** * **kv\_blocks** ([`TensorType`](../../graph/type.md#max.graph.type.TensorType) ) * **cache\_lengths** ([`TensorType`](../../graph/type.md#max.graph.type.TensorType) ) * **lookup\_table** ([`TensorType`](../../graph/type.md#max.graph.type.TensorType) ) * **max\_lengths** ([`TensorType`](../../graph/type.md#max.graph.type.TensorType) )
### `cache_lengths` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols.cache_lengths} > cache\_lengths: [TensorType](../../graph/type.md#max.graph.type.TensorType)
### `kv_blocks` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols.kv_blocks} > kv\_blocks: [TensorType](../../graph/type.md#max.graph.type.TensorType)
### `lookup_table` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols.lookup_table} > lookup\_table: [TensorType](../../graph/type.md#max.graph.type.TensorType)
### `max_lengths` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols.max_lengths} > max\_lengths: [TensorType](../../graph/type.md#max.graph.type.TensorType)
## `ContinuousBatchingKVCacheManager` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCacheManager(params, max\_batch\_size, max\_seq\_len, num\_layers, devices, session) **Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_batch\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** (`Sequence` `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **session** ([`InferenceSession`](../../engine.md#max.engine.InferenceSession) )
### `block_shape()` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager.block_shape} > block\_shape(n\_sequences) Returns the shape
of the KV cache blocks for the given number of sequences. Defines the 6-dimensional shape of the cache blocks used to store key and value tensors for transformer attention. The dimensions represent: \[n\_sequences, 2, num\_layers, max\_seq\_len, n\_kv\_heads\_per\_device, head\_dim] where 2 represents separate storage for keys and values. **Parameters:** **n\_sequences** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of sequences that will be cached. **Returns:** A list describing the shape of the cache blocks, with dimensions for sequences, key/value split, layers, sequence length, attention heads, and head dimension. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)
### `estimated_memory_size()` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager.estimated_memory_size} > classmethod estimated\_memory\_size(params, max\_batch\_size, max\_seq\_len, num\_layers, available\_cache\_memory, devices, \*\*kwargs) Returns the estimated total memory usage of the kv cache. **Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_batch\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **kwargs** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int)
### `fetch()` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager.fetch} > fetch(batch, num\_steps=1) Fetches the KV cache state for the given sequence IDs. This method retrieves the current cache state for a batch of sequences, including their cache lengths and lookup information. It’s used during token generation to access previously cached key/value pairs. **Parameters:** * **batch** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `T` `]` ) – List of KVCacheAwareContext for which to fetch cache state for. * **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of steps to run for multi-step scheduling. **Returns:** A list of tuples, one per device, each containing: * blocks: Tensor containing the KV cache blocks * cache\_lengths: Tensor of current cache lengths for each sequence * lookup\_table: Tensor mapping sequence IDs to cache positions * max\_lengths: Tensor containing \[max\_seq\_length, max\_cache\_length] **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If any seq\_id exceeds max\_batch\_size or doesn’t exist in cache
### `infer_optimal_batch_size()` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager.infer_optimal_batch_size} > classmethod infer\_optimal\_batch\_size(params, max\_seq\_len, num\_layers, available\_cache\_memory, devices, \*\*kwargs) Returns the estimated optimal batch size for the kv cache.
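A sizing pass might use this classmethod as sketched below. The values are placeholders, and a real configuration may require additional `KVCacheParams` fields; treat this as an assumption-laden sketch rather than a canonical recipe:

```python
from max import driver
from max.dtype import DType
from max.nn.kv_cache.cache_params import KVCacheParams
from max.nn.kv_cache.continuous_batching_cache import (
    ContinuousBatchingKVCacheManager,
)

# Placeholder cache parameters for a hypothetical model.
params = KVCacheParams(dtype=DType.bfloat16, n_kv_heads=8, head_dim=128)

# Ask how large a batch the available cache memory can support.
batch_size = ContinuousBatchingKVCacheManager.infer_optimal_batch_size(
    params,
    max_seq_len=4096,
    num_layers=32,
    available_cache_memory=8 * 1024**3,  # 8 GiB budget (placeholder)
    devices=[driver.CPU()],
)
print(batch_size)
```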
**Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **kwargs** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `input_symbols()` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheManager.input_symbols} > input\_symbols() Returns the expected input tensor types for fetch on each device. Defines the tensor specifications needed by the cache implementation, including shapes and data types. This is used for graph construction and validation. **Returns:** List of ContinuousBatchingKVCacheInputSymbols for each device containing TensorTypes for: * KV cache blocks: 6D tensor for storing keys and values * Cache lengths: 1D tensor tracking sequence lengths * Lookup table: 1D tensor mapping sequence IDs to cache positions * Maximum lengths: 2D tensor tracking maximum sequence and cache lengths per step. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*ContinuousBatchingKVCacheInputSymbols*](#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheInputSymbols)] ## `ContinuousBatchingKVCacheType` {#max.nn.kv_cache.continuous_batching_cache.ContinuousBatchingKVCacheType} > class max.nn.kv\_cache.continuous\_batching\_cache.ContinuousBatchingKVCacheType Continuous Mojo KV Cache graph type. Creates an opaque type containing a continuous batching KV Cache. 
## `FetchContinuousBatchingKVCacheCollection` {#max.nn.kv_cache.continuous_batching_cache.FetchContinuousBatchingKVCacheCollection} > class max.nn.kv\_cache.continuous\_batching\_cache.FetchContinuousBatchingKVCacheCollection(kv\_params, \*\*kwargs) **Parameters:** * **kv\_params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **kwargs** (`Any` ) --- ## hf
## `ContinuousHFStaticCache` {#max.nn.kv_cache.hf.ContinuousHFStaticCache} > class max.nn.kv\_cache.hf.ContinuousHFStaticCache(config, max\_batch\_size, max\_seq\_len, device, dtype=torch.float32, layer\_device\_map=None) **Parameters:** * **config** (`PretrainedConfig` ) * **max\_batch\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **device** (`device` ) * **dtype** (`dtype` ) * **layer\_device\_map** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `device` `|` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None` )
### `external_claim()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.external_claim} > external\_claim(seq\_ids) **Parameters:** **seq\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) **Return type:** None
### `get_attention_mask()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.get_attention_mask} > get\_attention\_mask(seq\_ids) **Parameters:** **seq\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) **Return type:** *Tensor*
### `release()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.release} > release(seq\_id) **Parameters:** **seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None
### `reset()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.reset} > reset() Resets the cache values while preserving the objects. **Return type:** None
### `set_active_slots()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.set_active_slots} > set\_active\_slots(seq\_ids) **Parameters:** **seq\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) **Return type:** None
### `set_cache_position()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.set_cache_position} > set\_cache\_position(cache\_position) **Parameters:** **cache\_position** (`Tensor` )
### `update()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.update} > update(key\_states, value\_states, layer\_idx, cache\_kwargs=None) Updates the cache with the new key\_states and value\_states for layer layer\_idx. It is VERY important to index using a tensor; otherwise you introduce a copy to the device. **Parameters:** * **key\_states** (torch.Tensor) – The new key states to cache. * **value\_states** (torch.Tensor) – The new value states to cache. * **layer\_idx** (int) – The index of the layer to cache the states for. * **cache\_kwargs** (Dict\[str, Any], optional) – Additional arguments for the cache subclass. The StaticCache needs the cache\_position input to know where to write in the cache. **Returns:** A tuple containing the updated key and value states.
**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[*Tensor*, *Tensor*]
### `update_attention_pattern()` {#max.nn.kv_cache.hf.ContinuousHFStaticCache.update_attention_pattern} > update\_attention\_pattern(seq\_id, attention\_mask) **Parameters:** * **seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **attention\_mask** (`Tensor` ) **Return type:** None --- ## kv_cache ## Modules * [`cache_params`](/max/api/python/nn/kv_cache/cache_params) * [`continuous_batching_cache`](/max/api/python/nn/kv_cache/continuous_batching_cache) * [`hf`](/max/api/python/nn/kv_cache/hf) * [`manager`](/max/api/python/nn/kv_cache/manager) --- ## manager Abstract base class for KV cache managers.
## `KVCacheInputSymbols` {#max.nn.kv_cache.manager.KVCacheInputSymbols} > class max.nn.kv\_cache.manager.KVCacheInputSymbols Base class for input symbols for KV cache managers. The derived class is responsible for defining the input symbols for the specific KV cache manager. For example, here’s a derived class for a text KV cache manager:

```python
@dataclass
class ContinuousBatchingKVCacheInputSymbols(KVCacheInputSymbols):
    kv_blocks: TensorType
    cache_lengths: TensorType
    lookup_table: TensorType
    max_lengths: TensorType
```

## `KVCacheInputs` {#max.nn.kv_cache.manager.KVCacheInputs} > class max.nn.kv\_cache.manager.KVCacheInputs A base class that holds KV cache related (Tensor) inputs. It is meant to be subclassed by concrete KV cache input types. For example, here’s a derived class for a text KV cache manager:

```python
@dataclass
class RaggedKVCacheInputs(KVCacheInputs):
    blocks: Tensor
    cache_lengths: Tensor
    lookup_table: Tensor
    max_lengths: Tensor
```

## `KVCacheInputsSequence` {#max.nn.kv_cache.manager.KVCacheInputsSequence} > class max.nn.kv\_cache.manager.KVCacheInputsSequence(kv\_cache\_inputs) `KVCacheInputsSequence` is a sequence of [`KVCacheInputs`](#max.nn.kv_cache.manager.KVCacheInputs). It is primarily used in our multistep execution to represent batched KVCacheInputs. **Parameters:** **kv\_cache\_inputs** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`KVCacheInputs`](#max.nn.kv_cache.manager.KVCacheInputs) `]` )
### `kv_cache_inputs` {#max.nn.kv_cache.manager.KVCacheInputsSequence.kv_cache_inputs} > kv\_cache\_inputs: [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[KVCacheInputs](#max.nn.kv_cache.manager.KVCacheInputs)]
## `KVCacheManager` {#max.nn.kv_cache.manager.KVCacheManager} > class max.nn.kv\_cache.manager.KVCacheManager(params, max\_batch\_size, max\_seq\_len, num\_layers, devices, session, is\_ragged=False) **Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_batch\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** (`Sequence` `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **session** ([`InferenceSession`](../../engine.md#max.engine.InferenceSession) ) * **is\_ragged** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
### `claim()` {#max.nn.kv_cache.manager.KVCacheManager.claim} > claim(n) Claims `n` blocks of memory in the cache for incoming requests. This returns a list of sequence ids, which identify a sequence’s location within the cache.
Each sequence id can then be passed to the fetch function to return the `ContinuousBatchingKVCacheCollection` for those sequences. **Parameters:** **n** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]
### `contains()` {#max.nn.kv_cache.manager.KVCacheManager.contains} > contains(seq\_id) **Parameters:** **seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool)
### `estimated_memory_size()` {#max.nn.kv_cache.manager.KVCacheManager.estimated_memory_size} > abstract classmethod estimated\_memory\_size(params, max\_batch\_size, max\_seq\_len, num\_layers, available\_cache\_memory, devices, \*\*kwargs) Returns the estimated total memory usage of the kv cache. **Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_batch\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **kwargs** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int)
### `external_claim()` {#max.nn.kv_cache.manager.KVCacheManager.external_claim} > external\_claim(seq\_ids) Variant of `claim` where sequence ids are reserved externally. **Parameters:** **seq\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) **Return type:** None
### `fetch()` {#max.nn.kv_cache.manager.KVCacheManager.fetch} > abstract fetch(batch, num\_steps=1) Returns blocks and other inputs to the KV cache kernel for the given sequence ids and prompts. **Parameters:** * **batch** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `T` `]` ) * **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*KVCacheInputs*](#max.nn.kv_cache.manager.KVCacheInputs)]
### `increment_cache_lengths()` {#max.nn.kv_cache.manager.KVCacheManager.increment_cache_lengths} > increment\_cache\_lengths(kv\_cache\_inputs, prev\_model\_inputs) Prepares the inputs for a multistep execution, generally by incrementing the cache lengths. This should not require a device synchronization, as this would defeat the purpose of multistep execution. This should also not update the cache lengths in our manager; this batch is still considered in-progress.
**Parameters:** * **kv\_cache\_inputs** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`RaggedKVCacheInputs`](#max.nn.kv_cache.manager.RaggedKVCacheInputs) `]` `|` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`PaddedKVCacheInputs`](#max.nn.kv_cache.manager.PaddedKVCacheInputs) `]` ) * **prev\_model\_inputs** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*RaggedKVCacheInputs*](#max.nn.kv_cache.manager.RaggedKVCacheInputs)] | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*PaddedKVCacheInputs*](#max.nn.kv_cache.manager.PaddedKVCacheInputs)]
### `infer_optimal_batch_size()` {#max.nn.kv_cache.manager.KVCacheManager.infer_optimal_batch_size} > abstract classmethod infer\_optimal\_batch\_size(params, max\_seq\_len, num\_layers, available\_cache\_memory, devices, \*\*kwargs) Returns the estimated optimal batch size for the kv cache. **Parameters:** * **params** ([`KVCacheParams`](cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ) * **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **num\_layers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Device`](../../driver.md#max.driver.Device) `]` ) * **kwargs** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int)
### `input_symbols()` {#max.nn.kv_cache.manager.KVCacheManager.input_symbols} > abstract input\_symbols() Returns the input symbols for the kv cache manager. **Return type:** [*Sequence*](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[*KVCacheInputSymbols*](#max.nn.kv_cache.manager.KVCacheInputSymbols)]
### `num_kv_inputs()` {#max.nn.kv_cache.manager.KVCacheManager.num_kv_inputs} > num\_kv\_inputs() Returns the default number of KV cache inputs for KV managers. Subclasses with a different number of KV cache inputs should override this method and [`increment_cache_lengths`](#max.nn.kv_cache.manager.KVCacheManager.increment_cache_lengths). **Return type:** [int](https://docs.python.org/3/library/functions.html#int)
### `release()` {#max.nn.kv_cache.manager.KVCacheManager.release} > release(seq\_id) Releases the provided `seq_id`, marking this sequence as complete. This returns the `seq_id` back to the available pool of cache memory, allowing it to be reused when a new sequence is claimed. **Parameters:** **seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None
### `slots_remaining` {#max.nn.kv_cache.manager.KVCacheManager.slots_remaining} > property slots\_remaining: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] The outstanding cache slots available.
### `step()` {#max.nn.kv_cache.manager.KVCacheManager.step} > step(batch) Commits the new tokens into the prefix cache. This is a no-op if prefix caching is disabled.
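Putting the manager methods together, a request lifecycle looks roughly like the sketch below. `manager` is assumed to be a concrete `KVCacheManager` subclass and `batch` a list of context objects; neither is constructed here:

```python
# Hypothetical lifecycle against a concrete KVCacheManager subclass.
seq_ids = manager.claim(2)        # reserve cache slots for two new sequences
kv_inputs = manager.fetch(batch)  # kernel inputs for the current batch
# ... run the model forward pass using kv_inputs ...
manager.step(batch)               # commit newly generated tokens (prefix cache)
for seq_id in seq_ids:
    manager.release(seq_id)       # return each slot to the available pool
```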
**Parameters:** **batch** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `T` `]` ) **Return type:** None
## `PaddedKVCacheInputs` {#max.nn.kv_cache.manager.PaddedKVCacheInputs} > class max.nn.kv\_cache.manager.PaddedKVCacheInputs(k\_cache, v\_cache, start\_pos, null\_op) `PaddedKVCacheInputs` is a class that holds the inputs for the KV cache when used together with padded tensors. **Parameters:** * **k\_cache** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **v\_cache** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **start\_pos** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **null\_op** ([`Tensor`](../../driver.md#max.driver.Tensor) )
### `k_cache` {#max.nn.kv_cache.manager.PaddedKVCacheInputs.k_cache} > k\_cache: [Tensor](../../driver.md#max.driver.Tensor)
### `null_op` {#max.nn.kv_cache.manager.PaddedKVCacheInputs.null_op} > null\_op: [Tensor](../../driver.md#max.driver.Tensor)
### `start_pos` {#max.nn.kv_cache.manager.PaddedKVCacheInputs.start_pos} > start\_pos: [Tensor](../../driver.md#max.driver.Tensor)
### `v_cache` {#max.nn.kv_cache.manager.PaddedKVCacheInputs.v_cache} > v\_cache: [Tensor](../../driver.md#max.driver.Tensor)
## `RaggedKVCacheInputs` {#max.nn.kv_cache.manager.RaggedKVCacheInputs} > class max.nn.kv\_cache.manager.RaggedKVCacheInputs(blocks, cache\_lengths, lookup\_table, max\_lengths) `RaggedKVCacheInputs` is a class that holds the inputs for the KV cache when used together with ragged tensors. **Parameters:** * **blocks** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **cache\_lengths** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **lookup\_table** ([`Tensor`](../../driver.md#max.driver.Tensor) ) * **max\_lengths** ([`Tensor`](../../driver.md#max.driver.Tensor) )
### `blocks` {#max.nn.kv_cache.manager.RaggedKVCacheInputs.blocks} > blocks: [Tensor](../../driver.md#max.driver.Tensor)
### `cache_lengths` {#max.nn.kv_cache.manager.RaggedKVCacheInputs.cache_lengths} > cache\_lengths: [Tensor](../../driver.md#max.driver.Tensor)
### `lookup_table` {#max.nn.kv_cache.manager.RaggedKVCacheInputs.lookup_table} > lookup\_table: [Tensor](../../driver.md#max.driver.Tensor)
### `max_lengths` {#max.nn.kv_cache.manager.RaggedKVCacheInputs.max_lengths} > max\_lengths: [Tensor](../../driver.md#max.driver.Tensor) --- ## layer
## `Layer` {#max.nn.layer.Layer} > class max.nn.layer.Layer #### Deprecated Deprecated since version 25.2. Base class for neural network components. Use [`Module`](#max.nn.layer.Module) instead. Provides functionality for adding hooks to the call function of each layer to support testing, debugging, or profiling.
## `LayerList` {#max.nn.layer.LayerList} > class max.nn.layer.LayerList(layers) Stores a list of layers. Can be used as a regular Python list.
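For example, a `LayerList` supports the usual list operations. In this sketch, `make_layer()` is a hypothetical stand-in for any `Layer` or `Module` constructor:

```python
from max.nn.layer import LayerList

# make_layer() is a hypothetical stand-in for a real Layer/Module constructor.
layers = LayerList([make_layer() for _ in range(3)])
layers.append(make_layer())     # add a layer at the end
layers.insert(0, make_layer())  # add a layer at the front
print(len(layers))              # 5
```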
**Parameters:** **layers** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Layer`](#max.nn.layer.Layer) `]` )
### `append()` {#max.nn.layer.LayerList.append} > append(layer) **Parameters:** **layer** ([`Layer`](#max.nn.layer.Layer) )
### `extend()` {#max.nn.layer.LayerList.extend} > extend(layer) **Parameters:** **layer** ([`Layer`](#max.nn.layer.Layer) )
### `insert()` {#max.nn.layer.LayerList.insert} > insert(i, layer) **Parameters:** **layer** ([`Layer`](#max.nn.layer.Layer) )
### `sublayers` {#max.nn.layer.LayerList.sublayers} > property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.layer.Module)]
## `Module` {#max.nn.layer.Module} > class max.nn.layer.Module Base class for model components with weight management. Provides functionality to create custom layers and construct networks with automatic weight tracking. The following example uses the [`Module`](#max.nn.layer.Module) class to create custom layers and build a neural network:

```python
from max import nn
from max.dtype import DType
from max.graph import Weight, ops, DeviceRef

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = Weight("weight", DType.float32, (in_dim, out_dim), DeviceRef.CPU())

    def __call__(self, x):
        return x @ self.weight.T

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = Linear(5, 10)
        self.gate = Linear(5, 10)
        self.down = Linear(10, 5)

    def __call__(self, x):
        return self.down(ops.silu(self.gate(x)) + self.up(x))

model = MLP()
print(model.state_dict())  # {"up.weight": Tensor([5, 10]), ...}
```

Constructing a graph without [`Module`](#max.nn.layer.Module) can result in name collisions with the weights (in this example, there would be three weights named `weight`). With [`Module`](#max.nn.layer.Module), you can use [`state_dict()`](#max.nn.layer.Module.state_dict) or [`load_state_dict()`](#max.nn.layer.Module.load_state_dict) to initialize or set the weight values, and finalize the weight names to be unique within the model.
### `build_subgraph()` {#max.nn.layer.Module.build_subgraph} > build\_subgraph(name, input\_types, weight\_prefix='') Builds a subgraph for this module. This method creates a subgraph that encapsulates the module’s logic, handling input types, weights, and creating a graph with the module’s computation. Once the subgraph is built, it can be called using the `ops.call` op. **Parameters:** * **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The name of the subgraph to create. * **input\_types** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Type`](../graph/type.md#max.graph.type.Type) `|` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Type`](../graph/type.md#max.graph.type.Type) `]` `]` ) – A list of input types for the subgraph. Each element can be either a single `Type` or a list of `Type` objects. * **weight\_prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Optional prefix for weight names in the subgraph. If provided, weights with names starting with this prefix will have their names modified by removing the prefix and will be marked as placeholders. **Returns:** The created subgraph containing the module’s computation. **Return type:** `Graph` #### NOTE * Placeholder weights will require the `prefix` attribute of `ops.call` to be set.
### `layer_weights` {#max.nn.layer.Module.layer_weights} > property layer\_weights: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Weight](../graph/Weight.md#max.graph.Weight)] ### `load_state_dict()` {#max.nn.layer.Module.load_state_dict} > load\_state\_dict(state\_dict, \*, override\_quantization\_encoding=False, weight\_alignment=None, strict=True) Sets the values of all weights in this model. **Parameters:** * **state\_dict** ([`Mapping`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Mapping) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`DLPackArray`](../driver.md#max.driver.DLPackArray) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` [`WeightData`](../graph/weights.md#max.graph.weights.WeightData) `]` ) – A map from weight name to a numpy array or [`max.driver.Tensor`](../driver.md#max.driver.Tensor). * **override\_quantization\_encoding** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether to override the weight quantization based on the loaded value. * **weight\_alignment** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – If specified, overrides the alignment for each weight in the Module. If left as None, each value in state\_dict must be aligned by the default dtype alignment. * **strict** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If True, raises an error if any keys in state\_dict were not used by the Module. **Raises:** * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If any weight in the model is not present in the state dict. * [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If strict is True and state\_dict contains keys not used by the Module. **Return type:** None ### `raw_state_dict()` {#max.nn.layer.Module.raw_state_dict} > raw\_state\_dict() Returns all weights objects in the model. Unlike [`state_dict`](#max.nn.layer.Module.state_dict), this returns [`max.graph.Weight`](../graph/Weight.md#max.graph.Weight) objects instead of the assigned values. Some parameters inside the `Weight` can be configured before a graph is built. Do not change these attributes after building a graph: * [`align`](../graph/Weight.md#max.graph.Weight.align) * [`dtype`](../graph/Weight.md#max.graph.Weight.dtype) * [`quantization_encoding`](../graph/Weight.md#max.graph.Weight.quantization_encoding) * [`shape`](../graph/Weight.md#max.graph.Weight.shape) **Returns:** Map from weight name to the [`max.graph.Weight`](../graph/Weight.md#max.graph.Weight) object. **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*Weight*](../graph/Weight.md#max.graph.Weight)] ### `set_shared_weight()` {#max.nn.layer.Module.set_shared_weight} > set\_shared\_weight(name, weight) **Parameters:** * **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **weight** ([`Weight`](../graph/Weight.md#max.graph.Weight) ) ### `state_dict()` {#max.nn.layer.Module.state_dict} > state\_dict(auto\_initialize=True) Returns values of all weights in the model. The values returned are the same as the values set in [`load_state_dict`](#max.nn.layer.Module.load_state_dict). 
If [`load_state_dict`](#max.nn.layer.Module.load_state_dict) has not been called and none of the weights have values, then they are initialized to zero.

**Parameters:** **auto\_initialize** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Determines whether to initialize weights to zero if the weight value has not been loaded. If this is False, a ValueError is raised if an uninitialized weight is found.

**Returns:** Map from weight name to the weight value (a numpy array or [`max.driver.Tensor`](../driver.md#max.driver.Tensor)).

**Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*DLPackArray*](../driver.md#max.driver.DLPackArray) | [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)]

### `sublayers` {#max.nn.layer.Module.sublayers}

> property sublayers: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Module](#max.nn.layer.Module)]

## `Shardable` {#max.nn.layer.Shardable}

> class max.nn.layer.Shardable(\*args, \*\*kwargs)

Protocol for objects that support sharding across multiple devices.

This protocol defines the interface that all shardable components (like Linear layers and Weight objects) must implement to participate in distributed computation.

### `shard()` {#max.nn.layer.Shardable.shard}

> shard(shard\_idx, device)

Creates a sharded view of this object for a specific device.

**Parameters:**

* **shard\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index of the shard (0 to num\_devices-1).
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The device where this shard should reside.

**Returns:** A sharded instance of this object.

**Return type:** [*Shardable*](#max.nn.layer.Shardable)

### `sharding_strategy` {#max.nn.layer.Shardable.sharding_strategy}

> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)

Gets the weight sharding strategy.

## `add_layer_hook()` {#max.nn.layer.add_layer_hook}

> max.nn.layer.add\_layer\_hook(fn)

Adds a hook to call a function after each layer’s `__call__`.

The function will be passed four inputs:

* layer
* input\_args
* input\_kwargs
* outputs

The function can either return None or new outputs that will replace the layer’s returned outputs.

Note that inputs and outputs contain graph Values, which show limited information (like [`shape`](../graph/TensorValue.md#max.graph.TensorValue.shape) and [`dtype`](../graph/TensorValue.md#max.graph.TensorValue.dtype)). You can still see the computed values if you include the Value in the `graph.ops.output` op, or call `graph.ops.print`.
Example of printing debug inputs:

```python
from max.nn.layer import add_layer_hook

def print_info(layer, args, kwargs, outputs):
    print("Layer:", type(layer).__name__)
    print("Input args:", args)
    print("Input kwargs:", kwargs)
    print("Outputs:", outputs)
    return outputs

add_layer_hook(print_info)
```

**Parameters:** **fn** ([`Callable`](https://docs.python.org/3/library/typing.html#typing.Callable) `[` `[` [`Layer`](#max.nn.layer.Layer) `,` [`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `,` `...` `]` `,` [`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `]` `,` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `]` `,` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `]` )

**Return type:** None

## `clear_hooks()` {#max.nn.layer.clear_hooks}

> max.nn.layer.clear\_hooks()

Remove all hooks.

## `recursive_named_layers()` {#max.nn.layer.recursive_named_layers}

> max.nn.layer.recursive\_named\_layers(parent, prefix='')

Recursively walks through the layers and generates names.

**Parameters:**

* **parent** ([`Module`](#max.nn.layer.Module) )
* **prefix** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) )

**Return type:** [*Iterable*](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*Module*](#max.nn.layer.Module)]]

---

## linear

Multi-layer Perceptron.

## `ColumnParallelLinear` {#max.nn.linear.ColumnParallelLinear}

> class max.nn.linear.ColumnParallelLinear(in\_dim, out\_dim, dtype, devices, tied\_weight=None, \*\*kwargs)

A Linear layer where the weight and bias are sharded onto multiple devices.

This layer first computes $y = xW_i^T + b_i$ for each device i in \[0, …, num\_devices - 1]:

```default
+-----+     +-----+ T     +-----+     +-----+
|     |     | W_0 |       | b_0 |     | y_0 | GPU0
|     |     +-----+       +-----+     +-----+
|     |     | W_1 |       | b_1 |     | y_1 | GPU1
|  x  |  @  +-----+   +   +-----+  =  +-----+
|     |     | W_2 |       | b_2 |     | y_2 | GPU2
|     |     +-----+       +-----+     +-----+
|     |     | W_3 |       | b_3 |     | y_3 | GPU3
+-----+     +-----+       +-----+     +-----+
```

The values are then collected using an Allgather op, producing the same output tensor $y = xW^T + b$ on each device:

```default
 GPU0  GPU1  GPU2  GPU3                      GPU0  GPU1  GPU2  GPU3
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
| y_0 |  -  |  -  |  -  |                   | y_0 | y_0 | y_0 | y_0 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  | y_1 |  -  |  -  |                   | y_1 | y_1 | y_1 | y_1 |
+-----+-----+-----+-----+  -- Allgather --> +-----+-----+-----+-----+
|  -  |  -  | y_2 |  -  |                   | y_2 | y_2 | y_2 | y_2 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
|  -  |  -  |  -  | y_3 |                   | y_3 | y_3 | y_3 | y_3 |
+-----+-----+-----+-----+                   +-----+-----+-----+-----+
```

Example usage:

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import ColumnParallelLinear

num_devices = 4
in_dim, out_dim = 256, 1024  # example dimensions
distributed_linear = ColumnParallelLinear(
    in_dim,
    out_dim,
    DType.float32,
    devices=[DeviceRef.GPU(i) for i in range(num_devices)],
)
```

**Parameters:**

* **in\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the input space.
* **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the output space.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type for both weights and bias.
* **devices** (`Sequence` `[` [`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `]` ) – The target devices for computation. Weights remain on CPU until sharded and moved to device during computation.
* **tied\_weight** ([`Weight`](../graph/Weight.md#max.graph.Weight) `|` `None` )

## `DistributedMLP` {#max.nn.linear.DistributedMLP}

> class max.nn.linear.DistributedMLP(\*args, \*\*kwargs)

A distributed multi-layer perceptron.

This class has the same state keys as the non-distributed MLP Layer.

**Parameters:**

* **dtype** – DType to use for the layer weights, which should match the input dtype.
* **quantization\_encoding** – Quantization encoding of the layer weights.
* **hidden\_dim** – The last dimension of the layer input.
* **feed\_forward\_length** – Size of dimension used to project the inputs.
* **linear\_cls** – Linear class to use to create the projection layers.
* **devices** – Devices to run the MLP layer on. All provided devices are used.
* **activation\_function** – Activation function to use. Options are:
  * “silu”
  * “gelu”
  * “gelu\_tanh”
  * “relu”
  * “tanh”
  * “sigmoid”

## `Float8Config` {#max.nn.linear.Float8Config}

> class max.nn.linear.Float8Config(input\_scale, weight\_scale, mlp\_in\_float8, attn\_qkv\_in\_float8, embedding\_output\_dtype=None, quant\_method=None)

Configures float8 quantization settings for a layer or model section.

**Parameters:**

* **input\_scale** ([`Float8InputScaleSpec`](#max.nn.linear.Float8InputScaleSpec) )
* **weight\_scale** ([`Float8WeightScaleSpec`](#max.nn.linear.Float8WeightScaleSpec) )
* **mlp\_in\_float8** ([`set`](https://docs.python.org/3/library/stdtypes.html#set) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **attn\_qkv\_in\_float8** ([`set`](https://docs.python.org/3/library/stdtypes.html#set) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **embedding\_output\_dtype** ([`DType`](../dtype.md#max.dtype.DType) `|` `None` )
* **quant\_method** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` )

### `attn_qkv_in_float8` {#max.nn.linear.Float8Config.attn_qkv_in_float8}

> attn\_qkv\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]

Set of layer indices with attention QKV projections in float8.

QKV projections are considered to be either “all quantized” or all not quantized per layer. So either all of {q,k,v,o}\_proj are float8, or all bfloat16.

### `embedding_output_dtype` {#max.nn.linear.Float8Config.embedding_output_dtype}

> embedding\_output\_dtype: [DType](../dtype.md#max.dtype.DType) | [None](https://docs.python.org/3/library/constants.html#None) = None

The data type of the output from the embedding layer.

### `input_scale` {#max.nn.linear.Float8Config.input_scale}

> input\_scale: [Float8InputScaleSpec](#max.nn.linear.Float8InputScaleSpec)

Specification for input activation scaling.

### `is_dynamic` {#max.nn.linear.Float8Config.is_dynamic}

> property is\_dynamic: [bool](https://docs.python.org/3/library/functions.html#bool)

Returns true if this config’s input scale is dynamic.

### `is_static` {#max.nn.linear.Float8Config.is_static}

> property is\_static: [bool](https://docs.python.org/3/library/functions.html#bool)

Returns true if this config’s input scale is static.
### `mlp_in_float8` {#max.nn.linear.Float8Config.mlp_in_float8}

> mlp\_in\_float8: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)]

Set of layer indices with MLPs in float8.

MLPs are considered to be either “all quantized” or all not quantized per layer. So either all of `gate_proj`, `down_proj`, and `up_proj` are float8, or all bfloat16.

### `quant_method` {#max.nn.linear.Float8Config.quant_method}

> quant\_method: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None

The quantization method used (e.g., “fbgemm\_fp8”).

### `weight_scale` {#max.nn.linear.Float8Config.weight_scale}

> weight\_scale: [Float8WeightScaleSpec](#max.nn.linear.Float8WeightScaleSpec)

Specification for weight scaling.

## `Float8InputScaleSpec` {#max.nn.linear.Float8InputScaleSpec}

> class max.nn.linear.Float8InputScaleSpec(granularity, origin, dtype, activation\_scale\_ub=None)

Specifies how input activations are scaled for float8 quantization.

**Parameters:**

* **granularity** ([`Float8ScaleGranularity`](#max.nn.linear.Float8ScaleGranularity) )
* **origin** ([`Float8ScaleOrigin`](#max.nn.linear.Float8ScaleOrigin) )
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) )
* **activation\_scale\_ub** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None` )

### `activation_scale_ub` {#max.nn.linear.Float8InputScaleSpec.activation_scale_ub}

> activation\_scale\_ub: [float](https://docs.python.org/3/library/functions.html#float) | [None](https://docs.python.org/3/library/constants.html#None) = None

An optional upper bound for dynamic activation scaling.

### `dtype` {#max.nn.linear.Float8InputScaleSpec.dtype}

> dtype: [DType](../dtype.md#max.dtype.DType)

The data type of the input scale factor(s).

### `granularity` {#max.nn.linear.Float8InputScaleSpec.granularity}

> granularity: [Float8ScaleGranularity](#max.nn.linear.Float8ScaleGranularity)

The granularity of the input scale factor application.

### `origin` {#max.nn.linear.Float8InputScaleSpec.origin}

> origin: [Float8ScaleOrigin](#max.nn.linear.Float8ScaleOrigin)

The origin (static or dynamic) of the input scale factor.

## `Float8ScaleGranularity` {#max.nn.linear.Float8ScaleGranularity}

> class max.nn.linear.Float8ScaleGranularity(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies the granularity of the quantization scale factor.

Determines whether a scale factor applies per-tensor, per-row (often for weights), per-column, or per-block within a tensor.

### `BLOCK` {#max.nn.linear.Float8ScaleGranularity.BLOCK}

> BLOCK = 'block'

### `COLWISE` {#max.nn.linear.Float8ScaleGranularity.COLWISE}

> COLWISE = 'colwise'

### `ROWWISE` {#max.nn.linear.Float8ScaleGranularity.ROWWISE}

> ROWWISE = 'rowwise'

### `TENSOR` {#max.nn.linear.Float8ScaleGranularity.TENSOR}

> TENSOR = 'tensor'

## `Float8ScaleOrigin` {#max.nn.linear.Float8ScaleOrigin}

> class max.nn.linear.Float8ScaleOrigin(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

Specifies whether the quantization scale is determined statically or dynamically.

STATIC scales are pre-computed and loaded with the model weights. DYNAMIC scales are computed at runtime based on the input data.
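To see how the granularity and origin values fit together, here is a hedged sketch that builds a `Float8Config` with dynamic per-tensor input scaling and row-wise weight scaling (the layer indices and dtypes are illustrative only):

```python
from max.dtype import DType
from max.nn.linear import (
    Float8Config,
    Float8InputScaleSpec,
    Float8ScaleGranularity,
    Float8ScaleOrigin,
    Float8WeightScaleSpec,
)

config = Float8Config(
    input_scale=Float8InputScaleSpec(
        granularity=Float8ScaleGranularity.TENSOR,
        origin=Float8ScaleOrigin.DYNAMIC,
        dtype=DType.float32,
    ),
    weight_scale=Float8WeightScaleSpec(
        granularity=Float8ScaleGranularity.ROWWISE,
        dtype=DType.float32,
    ),
    mlp_in_float8={0, 1},      # quantize the MLPs of layers 0 and 1
    attn_qkv_in_float8=set(),  # keep all attention projections unquantized
)
assert config.is_dynamic  # the input scale origin is DYNAMIC
```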
### `DYNAMIC` {#max.nn.linear.Float8ScaleOrigin.DYNAMIC} > DYNAMIC = 'dynamic' ### `STATIC` {#max.nn.linear.Float8ScaleOrigin.STATIC} > STATIC = 'static' ## `Float8WeightScaleSpec` {#max.nn.linear.Float8WeightScaleSpec} > class max.nn.linear.Float8WeightScaleSpec(granularity, dtype) Specifies how weights are scaled for float8 quantization. **Parameters:** * **granularity** ([`Float8ScaleGranularity`](#max.nn.linear.Float8ScaleGranularity) ) * **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) ### `dtype` {#max.nn.linear.Float8WeightScaleSpec.dtype} > dtype: [DType](../dtype.md#max.dtype.DType) The data type of the weight scale factor(s). ### `granularity` {#max.nn.linear.Float8WeightScaleSpec.granularity} > granularity: [Float8ScaleGranularity](#max.nn.linear.Float8ScaleGranularity) The granularity of the weight scale factor application. ### `is_block` {#max.nn.linear.Float8WeightScaleSpec.is_block} > property is\_block: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is block-wise. ### `is_colwise` {#max.nn.linear.Float8WeightScaleSpec.is_colwise} > property is\_colwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is column-wise. ### `is_rowwise` {#max.nn.linear.Float8WeightScaleSpec.is_rowwise} > property is\_rowwise: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is row-wise. ### `is_tensor` {#max.nn.linear.Float8WeightScaleSpec.is_tensor} > property is\_tensor: [bool](https://docs.python.org/3/library/functions.html#bool) Whether the weight scale granularity is per-tensor. ## `GPTQLinear` {#max.nn.linear.GPTQLinear} > class max.nn.linear.GPTQLinear(in\_dim, out\_dim, dtype, device, has\_bias=False, quantization\_encoding=None, quantization\_config=None, float8\_config=None) A Linear layer for GPTQ encoding Initializes the linear layer with weights and optional bias with GPTQ quantization. **Parameters:** * **in\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the input space. * **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the output space. * **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type for both weights and bias. * **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The target device for computation. Weights remain on CPU until moved during computation. * **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), adds a bias vector to the layer. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False). * **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` ) – The quantization encoding of the weights. * **quantization\_config** ([`QuantizationConfig`](../graph/quantization.md#max.graph.quantization.QuantizationConfig) `|` `None` ) – Extra config for the weight quantization. 
* **float8\_config** ([`Float8Config`](#max.nn.linear.Float8Config) `|` `None` ) ## `GPTQLinearV1` {#max.nn.linear.GPTQLinearV1} > class max.nn.linear.GPTQLinearV1(weight, bias=None, quantization\_encoding=None, quantization\_config=None, perm\_idx=None) A Linear layer for GPTQ encoding **Parameters:** * **weight** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) * **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` ) * **quantization\_config** ([`QuantizationConfig`](../graph/quantization.md#max.graph.quantization.QuantizationConfig) `|` `None` ) * **perm\_idx** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) ### `perm_idx` {#max.nn.linear.GPTQLinearV1.perm_idx} > perm\_idx: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `quantization_config` {#max.nn.linear.GPTQLinearV1.quantization_config} > quantization\_config: [QuantizationConfig](../graph/quantization.md#max.graph.quantization.QuantizationConfig) | [None](https://docs.python.org/3/library/constants.html#None) = None ## `Linear` {#max.nn.linear.Linear} > class max.nn.linear.Linear(in\_dim, out\_dim, dtype, 
device, has\_bias=False, quantization\_encoding=None, float8\_config=None, name=None, clip\_weight=None)

Applies a linear transformation to incoming data: $y = xW^T + b$.

This layer implements a fully connected layer where inputs are multiplied by a weight matrix and optionally added with a bias vector. Both weights and bias initially reside on CPU, and the model init phase moves them to [`device`](#max.nn.linear.Linear.device).

Example:

```python
from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn import Linear

linear_layer = Linear(
    in_dim=256,
    out_dim=128,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    name="linear",
    has_bias=True
)

input_tensor: TensorValue
output = linear_layer(input_tensor)
```

Initializes the linear layer with weights and optional bias.

**Parameters:**

* **in\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the input space.
* **out\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The dimensionality of the output space.
* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – The data type for both weights and bias.
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The target device for computation. Weights remain on CPU until moved during computation.
* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) – Base name for weights (appended with `.weight` and `.bias` if applicable).
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – When [`True`](https://docs.python.org/3/library/constants.html#True), adds a bias vector to the layer. Defaults to [`False`](https://docs.python.org/3/library/constants.html#False).
* **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` )
* **float8\_config** ([`Float8Config`](#max.nn.linear.Float8Config) `|` `None` )
* **clip\_weight** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None` )

### `bias` {#max.nn.linear.Linear.bias}

> bias: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None

The optional bias vector stored on CPU with shape (out\_dim,). Model init moves the bias to [`device`](#max.nn.linear.Linear.device) if present.

### `device` {#max.nn.linear.Linear.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef)

The device where matrix operations are performed.

### `input_scale` {#max.nn.linear.Linear.input_scale}

> input\_scale: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None

The optional input scale stored on CPU with shape (). Model init moves the input\_scale to [`device`](#max.nn.linear.Linear.device) if present.

### `shard()` {#max.nn.linear.Linear.shard}

> shard(shard\_idx, device)

Creates a sharded view of this Linear layer for a specific device.

**Parameters:**

* **shard\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index of the shard (0 to num\_devices-1).
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) ) – The device where this shard should reside.

**Returns:** A sharded Linear instance.

**Return type:** [*Linear*](#max.nn.linear.Linear)

### `sharding_strategy` {#max.nn.linear.Linear.sharding_strategy}

> property sharding\_strategy: ShardingStrategy | [None](https://docs.python.org/3/library/constants.html#None)

Gets the weight sharding strategy.
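As an illustration, a Linear layer can be sharded once a sharding strategy has been assigned. This sketch deliberately omits the `ShardingStrategy` construction, which is assumed to have happened beforehand:

```python
from max.dtype import DType
from max.graph import DeviceRef
from max.nn import Linear

layer = Linear(in_dim=256, out_dim=128, dtype=DType.float32, device=DeviceRef.CPU())
# Assumes `layer.sharding_strategy` has already been set appropriately.
shards = [layer.shard(i, DeviceRef.GPU(i)) for i in range(2)]
```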
### `weight` {#max.nn.linear.Linear.weight} > weight: [Weight](../graph/Weight.md#max.graph.Weight) The weight matrix stored on CPU with shape (out\_dim, in\_dim). Model init transposes the weight and moves it to [`device`](#max.nn.linear.Linear.device). ### `weight_scale` {#max.nn.linear.Linear.weight_scale} > weight\_scale: [Weight](../graph/Weight.md#max.graph.Weight) | [None](https://docs.python.org/3/library/constants.html#None) = None The optional weight scale stored on CPU with shape () or (N,). Model init moves the weight\_scale to [`device`](#max.nn.linear.Linear.device) if present. ## `LinearV1` {#max.nn.linear.LinearV1} > class max.nn.linear.LinearV1(weight, bias=None) A unified linear layer that delegates to either regular or quantized implementation. Deprecated: Use Linear instead. **Parameters:** * **weight** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) ### `bias` {#max.nn.linear.LinearV1.bias} > bias: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `create()` {#max.nn.linear.LinearV1.create} > classmethod create(dtype, quantization\_encoding, in\_features, out\_features, weights, bias=None, quantization\_config=None) Factory method to create a Linear layer with appropriate implementation. 
**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) )
* **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` )
* **in\_features** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **out\_features** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **weights** ([`Weights`](../graph/weights.md#max.graph.weights.Weights) `|` [`Weight`](../graph/Weight.md#max.graph.Weight) )
* **bias** ([`Weights`](../graph/weights.md#max.graph.weights.Weights) `|` [`Weight`](../graph/Weight.md#max.graph.Weight) `|` `None` )
* **quantization\_config** ([`QuantizationConfig`](../graph/quantization.md#max.graph.quantization.QuantizationConfig) `|` `None` )

**Return type:** [*LinearV1*](#max.nn.linear.LinearV1)

### `weight` {#max.nn.linear.LinearV1.weight}

> weight: Value\[TensorType] | [TensorValue](../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../graph/type.md#max.graph.type.Shape) | [Dim](../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

## `MLP` {#max.nn.linear.MLP}

> class max.nn.linear.MLP(dtype, quantization\_encoding, hidden\_dim, feed\_forward\_length, devices, linear\_cls=\<class 'max.nn.linear.Linear'\>, has\_bias=False, activation\_function='silu', float8\_config=None)

Simple multi-layer perceptron composed of three linear layers. Defaults to SiLU activation function.

**Parameters:**

* **dtype** ([`DType`](../dtype.md#max.dtype.DType) ) – DType to use for the layer weights, which should match the input dtype.
* **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` ) – Quantization encoding of the layer weights.
* **hidden\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The last dimension of the layer input.
* **feed\_forward\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Size of dimension used to project the inputs.
* **linear\_cls** (`Callable` `[` `...` `,` [`Linear`](#max.nn.linear.Linear) `]` ) – Linear class to use to create the projection layers.
* **devices** (`Sequence` `[` [`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) `]` ) – Devices to run the MLP layer. If multiple are provided, the first device is used instead. Use DistributedMLP to use all devices.
* **activation\_function** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Activation function to use. Options are:
  * “silu”
  * “gelu”
  * “gelu\_tanh”
  * “relu”
  * “tanh”
  * “sigmoid”
* **has\_bias** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
* **float8\_config** ([`Float8Config`](#max.nn.linear.Float8Config) `|` `None` )

## `MLPV1` {#max.nn.linear.MLPV1}

> class max.nn.linear.MLPV1(gate\_proj, down\_proj, up\_proj)

Simple multi-layer perceptron composed of three linear layers. Uses SiLU activation function.
**Parameters:** * **gate\_proj** ([`LinearV1`](#max.nn.linear.LinearV1) ) * **down\_proj** ([`LinearV1`](#max.nn.linear.LinearV1) ) * **up\_proj** ([`LinearV1`](#max.nn.linear.LinearV1) ) ### `down_proj` {#max.nn.linear.MLPV1.down_proj} > down\_proj: [LinearV1](#max.nn.linear.LinearV1) ### `gate_proj` {#max.nn.linear.MLPV1.gate_proj} > gate\_proj: [LinearV1](#max.nn.linear.LinearV1) ### `up_proj` {#max.nn.linear.MLPV1.up_proj} > up\_proj: [LinearV1](#max.nn.linear.LinearV1) ## `QLinearV1` {#max.nn.linear.QLinearV1} > class max.nn.linear.QLinearV1(weight, bias=None, quantization\_encoding=None) A quantized fully connected layer. **Parameters:** * **weight** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **bias** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) * **quantization\_encoding** ([`QuantizationEncoding`](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) `|` `None` ) ### `quantization_encoding` {#max.nn.linear.QLinearV1.quantization_encoding} > quantization\_encoding: [QuantizationEncoding](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) = None --- ## group_norm Group Normalization implementation using the graph API. ## `GroupNorm` {#max.nn.norm.group_norm.GroupNorm} > class max.nn.norm.group\_norm.GroupNorm(num\_groups, num\_channels, eps=1e-05, affine=True, device=gpu:0) Group normalization block. Divides channels into groups and computes normalization stats per group. Follows the implementation pattern from PyTorch’s group\_norm. 
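For example, a block that splits 256 channels into 32 groups (sizes here are illustrative; as in PyTorch, the channel count should be divisible by the group count):

```python
from max.graph import DeviceRef
from max.nn.norm.group_norm import GroupNorm

# 32 groups over 256 channels, computing stats per group on the first GPU.
norm = GroupNorm(num_groups=32, num_channels=256, device=DeviceRef.GPU(0))
```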
**Parameters:** * **num\_groups** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of groups to separate the channels into * **num\_channels** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of input channels * **eps** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – Small constant added to denominator for numerical stability * **affine** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – If True, apply learnable affine transform parameters * **device** ([`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) ) --- ## norm ## Modules * [`group_norm`](/max/api/python/nn/norm/group_norm) * [`layer_norm`](/max/api/python/nn/norm/layer_norm) * [`rms_norm`](/max/api/python/nn/norm/rms_norm) --- ## layer_norm Layer Normalization layer. ## `LayerNorm` {#max.nn.norm.layer_norm.LayerNorm} > class max.nn.norm.layer\_norm.LayerNorm(dims, device, dtype, eps=1e-05, use\_bias=True) Layer normalization block. **Parameters:** * **dims** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **device** ([`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) ) * **dtype** ([`DType`](../../dtype.md#max.dtype.DType) ) * **eps** ([`float`](https://docs.python.org/3/library/functions.html#float) ) ## `LayerNormV1` {#max.nn.norm.layer_norm.LayerNormV1} > class max.nn.norm.layer\_norm.LayerNormV1(weight, bias=None, eps=1e-06) Layer normalization block. Deprecated: Use LayerNorm instead. **Parameters:** * **weight** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) ) * **bias** ([`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` `None` ) * **eps** ([`float`](https://docs.python.org/3/library/functions.html#float) ) ### `bias` {#max.nn.norm.layer_norm.LayerNormV1.bias} > bias: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `eps` {#max.nn.norm.layer_norm.LayerNormV1.eps} > eps: [float](https://docs.python.org/3/library/functions.html#float) = 1e-06 ### `weight` {#max.nn.norm.layer_norm.LayerNormV1.weight} > weight: [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) --- ## rms_norm Normalization layer. ## `DistributedRMSNorm` {#max.nn.norm.rms_norm.DistributedRMSNorm} > class max.nn.norm.rms\_norm.DistributedRMSNorm(\*args, devices, \*\*kwargs) **Parameters:** **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` ) ## `RMSNorm` {#max.nn.norm.rms_norm.RMSNorm} > class max.nn.norm.rms\_norm.RMSNorm(dim, dtype, eps=1e-06, weight\_offset=0.0, multiply\_before\_cast=True) Computes the Root Mean Square normalization on inputs. **Parameters:** * **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Size of last dimension of the expected input. * **eps** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – Value added to denominator for numerical stability. * **weight\_offset** ([`float`](https://docs.python.org/3/library/functions.html#float) ) – Constant offset added to the learned weights at runtime. For Gemma-style RMSNorm, this should be set to 1.0. * **multiply\_before\_cast** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – True if we multiply the inputs by the learned weights before casting to the input type (Gemma3-style). 
False if we cast the inputs to the input type first, then multiply by the learned weights (Llama-style). * **dtype** ([`DType`](../../dtype.md#max.dtype.DType) ) ## `RMSNormV1` {#max.nn.norm.rms_norm.RMSNormV1} > class max.nn.norm.rms\_norm.RMSNormV1(weight, eps=1e-06, weight\_offset=0.0, multiply\_before\_cast=True) Computes the Root Mean Square normalization on inputs. Deprecated: Use RMSNorm instead. **Parameters:** * **weight** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **eps** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **weight\_offset** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **multiply\_before\_cast** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) ### `eps` {#max.nn.norm.rms_norm.RMSNormV1.eps} > eps: [float](https://docs.python.org/3/library/functions.html#float) = 1e-06 ### `multiply_before_cast` {#max.nn.norm.rms_norm.RMSNormV1.multiply_before_cast} > multiply\_before\_cast: [bool](https://docs.python.org/3/library/functions.html#bool) = True ### `weight` {#max.nn.norm.rms_norm.RMSNormV1.weight} > weight: Value\[TensorType] | [TensorValue](../../graph/TensorValue.md#max.graph.TensorValue) | [Shape](../../graph/type.md#max.graph.type.Shape) | [Dim](../../graph/type.md#max.graph.type.Dim) | [int](https://docs.python.org/3/library/functions.html#int) | [float](https://docs.python.org/3/library/functions.html#float) | [integer](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) | [floating](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) | [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `weight_offset` {#max.nn.norm.rms_norm.RMSNormV1.weight_offset} > weight\_offset: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 --- ## rotary_embedding The rope embedding used within the model. 
## `DeepseekYarnRopeScalingParams` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams}

> class max.nn.rotary\_embedding.DeepseekYarnRopeScalingParams(scaling\_factor: [float](https://docs.python.org/3/library/functions.html#float), original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int), beta\_fast: [int](https://docs.python.org/3/library/functions.html#int), beta\_slow: [int](https://docs.python.org/3/library/functions.html#int), mscale: [float](https://docs.python.org/3/library/functions.html#float), mscale\_all\_dim: [float](https://docs.python.org/3/library/functions.html#float))

**Parameters:**

* **scaling\_factor** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **original\_max\_position\_embeddings** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **beta\_fast** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **beta\_slow** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **mscale** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **mscale\_all\_dim** ([`float`](https://docs.python.org/3/library/functions.html#float) )

### `beta_fast` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.beta_fast}

> beta\_fast: [int](https://docs.python.org/3/library/functions.html#int)

Fast interpolation rate.

### `beta_slow` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.beta_slow}

> beta\_slow: [int](https://docs.python.org/3/library/functions.html#int)

Slow interpolation rate.

### `mscale` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.mscale}

> mscale: [float](https://docs.python.org/3/library/functions.html#float)

Scaling factor for middle frequencies.

### `mscale_all_dim` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.mscale_all_dim}

> mscale\_all\_dim: [float](https://docs.python.org/3/library/functions.html#float)

Scaling factor applied to all dimensions.

### `original_max_position_embeddings` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.original_max_position_embeddings}

> original\_max\_position\_embeddings: [int](https://docs.python.org/3/library/functions.html#int)

Original maximum sequence length during training.

### `scaling_factor` {#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams.scaling_factor}

> scaling\_factor: [float](https://docs.python.org/3/library/functions.html#float)

Scaling factor for frequency interpolation.

## `DeepseekYarnRotaryEmbedding` {#max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding}

> class max.nn.rotary\_embedding.DeepseekYarnRotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, device, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None)

Deepseek’s YaRN (Yet another RoPE extensioN) Rotary Position Embedding layer.

Unlike Llama3RotaryEmbedding, the dim argument here is the rope dimension of the model, not the hidden dimension.
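A construction sketch with illustrative values (not tuned for any particular checkpoint):

```python
from max.graph import DeviceRef
from max.nn.rotary_embedding import (
    DeepseekYarnRopeScalingParams,
    DeepseekYarnRotaryEmbedding,
)

scaling = DeepseekYarnRopeScalingParams(
    scaling_factor=40.0,
    original_max_position_embeddings=4096,
    beta_fast=32,
    beta_slow=1,
    mscale=1.0,
    mscale_all_dim=1.0,
)
rope = DeepseekYarnRotaryEmbedding(
    dim=64,  # the rope dimension, not the hidden dimension
    n_heads=16,
    theta=10000.0,
    max_seq_len=163840,
    device=DeviceRef.GPU(0),
    scaling_params=scaling,
)
```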
**Parameters:**

* **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **theta** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )
* **head\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **\_freqs\_cis** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **interleaved** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
* **scaling\_params** ([`DeepseekYarnRopeScalingParams`](#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams) `|` `None` )

### `compute_scale()` {#max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding.compute_scale}

> compute\_scale(user\_scale=None)

**Parameters:** **user\_scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None` )

**Return type:** [float](https://docs.python.org/3/library/functions.html#float)

### `freqs_cis_base()` {#max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding.freqs_cis_base}

> freqs\_cis\_base()

Computes the frequency tensor for complex exponentials (cis) for a given seq\_len. The tensor is scaled with the theta parameter. Required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).

**Returns:** The frequency tensor for complex exponentials with shape (max\_seq\_len, rope\_dim // 2, 2)

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

### `scaling_params` {#max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding.scaling_params}

> scaling\_params: [DeepseekYarnRopeScalingParams](#max.nn.rotary_embedding.DeepseekYarnRopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None

## `LinearScalingParams` {#max.nn.rotary_embedding.LinearScalingParams}

> class max.nn.rotary\_embedding.LinearScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float))

**Parameters:** **factor** ([`float`](https://docs.python.org/3/library/functions.html#float) )

### `factor` {#max.nn.rotary_embedding.LinearScalingParams.factor}

> factor: [float](https://docs.python.org/3/library/functions.html#float)

Main scaling factor for the frequency components of the rope.
## `Llama3RopeScalingParams` {#max.nn.rotary_embedding.Llama3RopeScalingParams} > class max.nn.rotary\_embedding.Llama3RopeScalingParams(factor: [float](https://docs.python.org/3/library/functions.html#float), low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float), orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int)) **Parameters:** * **factor** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **low\_freq\_factor** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **high\_freq\_factor** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **orig\_max\_position** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `factor` {#max.nn.rotary_embedding.Llama3RopeScalingParams.factor} > factor: [float](https://docs.python.org/3/library/functions.html#float) Main scaling factor for the frequency components of the rope. ### `high_freq_factor` {#max.nn.rotary_embedding.Llama3RopeScalingParams.high_freq_factor} > high\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float) Factor to scale the high frequency components of the rope. ### `low_freq_factor` {#max.nn.rotary_embedding.Llama3RopeScalingParams.low_freq_factor} > low\_freq\_factor: [float](https://docs.python.org/3/library/functions.html#float) Factor to scale the low frequency components of the rope. ### `orig_max_position` {#max.nn.rotary_embedding.Llama3RopeScalingParams.orig_max_position} > orig\_max\_position: [int](https://docs.python.org/3/library/functions.html#int) The original maximum position length supported by the model. ## `Llama3RotaryEmbedding` {#max.nn.rotary_embedding.Llama3RotaryEmbedding} > class max.nn.rotary\_embedding.Llama3RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, device, head\_dim=None, \_freqs\_cis=None, interleaved=True, scaling\_params=None) RotaryEmbedding for Llama3 that takes rope scaling into account. 
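A construction sketch with illustrative Llama 3.1-style values:

```python
from max.graph import DeviceRef
from max.nn.rotary_embedding import (
    Llama3RopeScalingParams,
    Llama3RotaryEmbedding,
)

scaling = Llama3RopeScalingParams(
    factor=8.0,
    low_freq_factor=1.0,
    high_freq_factor=4.0,
    orig_max_position=8192,
)
rope = Llama3RotaryEmbedding(
    dim=4096,
    n_heads=32,
    theta=500000.0,
    max_seq_len=131072,
    device=DeviceRef.GPU(0),
    scaling_params=scaling,
)
```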
**Parameters:**

* **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **theta** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )
* **head\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **\_freqs\_cis** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **interleaved** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
* **scaling\_params** ([`Llama3RopeScalingParams`](#max.nn.rotary_embedding.Llama3RopeScalingParams) `|` `None` )

### `scaling_params` {#max.nn.rotary_embedding.Llama3RotaryEmbedding.scaling_params}

> scaling\_params: [Llama3RopeScalingParams](#max.nn.rotary_embedding.Llama3RopeScalingParams) | [None](https://docs.python.org/3/library/constants.html#None) = None

Scaling parameters to enable Llama to function with a longer context length.

## `RotaryEmbedding` {#max.nn.rotary_embedding.RotaryEmbedding}

> class max.nn.rotary\_embedding.RotaryEmbedding(dim, n\_heads, theta, max\_seq\_len, device, head\_dim=None, \_freqs\_cis=None, interleaved=True)

RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
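A minimal construction sketch (sizes are illustrative):

```python
from max.graph import DeviceRef
from max.nn.rotary_embedding import RotaryEmbedding

rope = RotaryEmbedding(
    dim=2048,    # hidden dimension
    n_heads=16,  # head_dim defaults to dim // n_heads
    theta=10000.0,
    max_seq_len=4096,
    device=DeviceRef.GPU(0),
)
```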
**Parameters:**

* **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **theta** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )
* **head\_dim** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **\_freqs\_cis** (`Value` `[` `TensorType` `]` `|` [`TensorValue`](../graph/TensorValue.md#max.graph.TensorValue) `|` [`Shape`](../graph/type.md#max.graph.type.Shape) `|` [`Dim`](../graph/type.md#max.graph.type.Dim) `|` [`int`](https://docs.python.org/3/library/functions.html#int) `|` [`float`](https://docs.python.org/3/library/functions.html#float) `|` [`integer`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.integer) `|` [`floating`](https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.floating) `|` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` )
* **interleaved** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

### `compute_scale()` {#max.nn.rotary_embedding.RotaryEmbedding.compute_scale}

> compute\_scale(user\_scale=None)

**Parameters:** **user\_scale** ([`float`](https://docs.python.org/3/library/functions.html#float) `|` `None` )

**Return type:** [float](https://docs.python.org/3/library/functions.html#float)

### `device` {#max.nn.rotary_embedding.RotaryEmbedding.device}

> device: [DeviceRef](../graph/type.md#max.graph.type.DeviceRef)

### `dim` {#max.nn.rotary_embedding.RotaryEmbedding.dim}

> dim: [int](https://docs.python.org/3/library/functions.html#int)

### `freqs_cis` {#max.nn.rotary_embedding.RotaryEmbedding.freqs_cis}

> property freqs\_cis: [TensorValue](../graph/TensorValue.md#max.graph.TensorValue)

### `freqs_cis_base()` {#max.nn.rotary_embedding.RotaryEmbedding.freqs_cis_base}

> freqs\_cis\_base()

Computes the frequency tensor for complex exponentials (cis) for a given seq\_len. The tensor is scaled with the theta parameter. Required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).

**Returns:** The frequency tensor for complex exponentials with shape (max\_seq\_len \* 2, head\_dim / 2, 2)

**Return type:** [*TensorValue*](../graph/TensorValue.md#max.graph.TensorValue)

### `head_dim` {#max.nn.rotary_embedding.RotaryEmbedding.head_dim}

> head\_dim: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None

head\_dim = dim // n\_heads if not specified in the config.

### `interleaved` {#max.nn.rotary_embedding.RotaryEmbedding.interleaved}

> interleaved: [bool](https://docs.python.org/3/library/functions.html#bool) = True

### `max_seq_len` {#max.nn.rotary_embedding.RotaryEmbedding.max_seq_len}

> max\_seq\_len: [int](https://docs.python.org/3/library/functions.html#int)

The maximum sequence length for the model’s input.

### `n_heads` {#max.nn.rotary_embedding.RotaryEmbedding.n_heads}

> n\_heads: [int](https://docs.python.org/3/library/functions.html#int)

### `theta` {#max.nn.rotary_embedding.RotaryEmbedding.theta}

> theta: [float](https://docs.python.org/3/library/functions.html#float)

Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
---

## sequential

A general sequential layer; each layer is executed with the outputs of the previous.

## `Sequential` {#max.nn.sequential.Sequential}

> class max.nn.sequential.Sequential(layers)

A sequential stack of layers where each layer is called with the outputs of the previous layer.

**Parameters:** **layers** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`Layer`](layer.md#max.nn.layer.Layer) `]` )

---

## distributed_transformer

## `DistributedTransformer` {#max.nn.transformer.distributed_transformer.DistributedTransformer}

> class max.nn.transformer.distributed\_transformer.DistributedTransformer(dim, n\_heads, layers, norm, output, embedding, kv\_params, kv\_collection\_constructor, devices, return\_logits=ReturnLogits.LAST\_TOKEN, use\_subgraphs=False)

Transformer model consisting of TransformerBlock layers.

**Parameters:**

* **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **layers** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DistributedTransformerBlock`](#max.nn.transformer.distributed_transformer.DistributedTransformerBlock) `]` )
* **norm** ([`DistributedRMSNorm`](../norm/rms_norm.md#max.nn.norm.rms_norm.DistributedRMSNorm) )
* **output** ([`ColumnParallelLinear`](../linear.md#max.nn.linear.ColumnParallelLinear) )
* **embedding** ([`VocabParallelEmbedding`](../embedding.md#max.nn.embedding.VocabParallelEmbedding) )
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **kv\_collection\_constructor** ([`FetchContinuousBatchingKVCacheCollection`](../kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.FetchContinuousBatchingKVCacheCollection) `|` `FetchPagedKVCacheCollection` )
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` )
* **return\_logits** ([`ReturnLogits`](transformer.md#max.nn.transformer.transformer.ReturnLogits) )
* **use\_subgraphs** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

## `DistributedTransformerBlock` {#max.nn.transformer.distributed_transformer.DistributedTransformerBlock}

> class max.nn.transformer.distributed\_transformer.DistributedTransformerBlock(attention, mlp, attention\_norm, mlp\_norm, devices)

Stack of Attention, FeedForward, and RMSNorm layers.

**Parameters:**

* **attention** ([`Module`](../layer.md#max.nn.layer.Module) )
* **mlp** ([`Module`](../layer.md#max.nn.layer.Module) )
* **attention\_norm** ([`DistributedRMSNorm`](../norm/rms_norm.md#max.nn.norm.rms_norm.DistributedRMSNorm) )
* **mlp\_norm** ([`DistributedRMSNorm`](../norm/rms_norm.md#max.nn.norm.rms_norm.DistributedRMSNorm) )
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` )

## `distribute_value()` {#max.nn.transformer.distributed_transformer.distribute_value}

> max.nn.transformer.distributed\_transformer.distribute\_value(v, devices)

**Parameters:** **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`DeviceRef`](../../graph/type.md#max.graph.type.DeviceRef) `]` )

## `take()` {#max.nn.transformer.distributed_transformer.take}

> max.nn.transformer.distributed\_transformer.take(it, n)

Return the next *n* items from *it* as a list.
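The behavior matches the classic itertools recipe. A self-contained sketch of the equivalent logic, using plain integers in place of graph `Value`s:

```python
from itertools import islice

def take(it, n):
    # Equivalent-behavior sketch: consume and return the next n items.
    return list(islice(it, n))

it = iter(range(10))
assert take(it, 3) == [0, 1, 2]  # consumes the first three items
assert take(it, 3) == [3, 4, 5]  # subsequent calls continue from there
```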
**Parameters:**

* **it** ([`Iterable`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Iterable) `[` [`Value`](../../graph/Value.md#max.graph.Value) `]` )
* **n** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Value*](../../graph/Value.md#max.graph.Value)]

---

## transformer

## Modules

* [`distributed_transformer`](/max/api/python/nn/transformer/distributed_transformer)
* [`transformer`](/max/api/python/nn/transformer/transformer)

---

## transformer

## `ReturnLogits` {#max.nn.transformer.transformer.ReturnLogits}

> class max.nn.transformer.transformer.ReturnLogits(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None)

### `ALL` {#max.nn.transformer.transformer.ReturnLogits.ALL}

> ALL = 'all'

### `LAST_TOKEN` {#max.nn.transformer.transformer.ReturnLogits.LAST_TOKEN}

> LAST\_TOKEN = 'last\_token'

### `VARIABLE` {#max.nn.transformer.transformer.ReturnLogits.VARIABLE}

> VARIABLE = 'variable'

## `Transformer` {#max.nn.transformer.transformer.Transformer}

> class max.nn.transformer.transformer.Transformer(dim, n\_heads, layers, norm, output, embedding, kv\_params, kv\_collection\_constructor, return\_logits=ReturnLogits.LAST\_TOKEN, embedding\_multiplier=1.0, logits\_postprocessor=None)

Transformer model consisting of TransformerBlock layers.

**Parameters:**

* **dim** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **n\_heads** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **layers** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `Block` `]` )
* **norm** ([`Layer`](../layer.md#max.nn.layer.Layer) )
* **output** ([`LinearV1`](../linear.md#max.nn.linear.LinearV1) `|` [`Linear`](../linear.md#max.nn.linear.Linear) )
* **embedding** ([`EmbeddingV1`](../embedding.md#max.nn.embedding.EmbeddingV1) `|` [`Embedding`](../embedding.md#max.nn.embedding.Embedding) )
* **kv\_params** ([`KVCacheParams`](../kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) )
* **kv\_collection\_constructor** ([`FetchContinuousBatchingKVCacheCollection`](../kv_cache/continuous_batching_cache.md#max.nn.kv_cache.continuous_batching_cache.FetchContinuousBatchingKVCacheCollection) `|` `FetchPagedKVCacheCollection` )
* **return\_logits** ([`ReturnLogits`](#max.nn.transformer.transformer.ReturnLogits) )
* **embedding\_multiplier** ([`float`](https://docs.python.org/3/library/functions.html#float) )
* **logits\_postprocessor** (`Callable` `[` `[` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `]` `,` [`TensorValue`](../../graph/TensorValue.md#max.graph.TensorValue) `]` `|` `None` )

## `TransformerBlock` {#max.nn.transformer.transformer.TransformerBlock}

> class max.nn.transformer.transformer.TransformerBlock(attention, mlp, attention\_norm, mlp\_norm, residual\_multiplier=1.0)

Stack of Attention, FeedForward, and RMSNorm layers.
**Parameters:** * **attention** ([`AttentionImpl`](../attention/interfaces.md#max.nn.attention.interfaces.AttentionImpl) `|` [`AttentionImplQKV`](../attention/interfaces.md#max.nn.attention.interfaces.AttentionImplQKV) `|` [`Module`](../layer.md#max.nn.layer.Module) ) * **mlp** ([`Layer`](../layer.md#max.nn.layer.Layer) ) * **attention\_norm** ([`Layer`](../layer.md#max.nn.layer.Layer) ) * **mlp\_norm** ([`Layer`](../layer.md#max.nn.layer.Layer) ) * **residual\_multiplier** ([`float`](https://docs.python.org/3/library/functions.html#float) ) --- ## architectures ## `register_all_models()` {#max.pipelines.architectures.register_all_models} > max.pipelines.architectures.register\_all\_models() Imports model architectures, thus registering each architecture in the shared `PipelineRegistry`. --- ## config Standardized configuration for Pipeline Inference. ## `AudioGenerationConfig` {#max.pipelines.lib.config.AudioGenerationConfig} > class max.pipelines.lib.config.AudioGenerationConfig(audio\_decoder: 'str', audio\_decoder\_weights: 'str' = '', block\_sizes: 'list\[int] | None' = None, buffer: 'int' = 0, block\_causal: 'bool' = False, prepend\_prompt\_speech\_tokens: 'PrependPromptSpeechTokens' = PrependPromptSpeechTokens.ONCE, prepend\_prompt\_speech\_tokens\_causal: 'bool' = False, run\_model\_test\_mode: 'bool' = False, \*\*kwargs: 'Any') **Parameters:** * **audio\_decoder** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **audio\_decoder\_weights** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **block\_sizes** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None` ) * **buffer** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **block\_causal** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **prepend\_prompt\_speech\_tokens** ([`PrependPromptSpeechTokens`](#max.pipelines.lib.config.PrependPromptSpeechTokens) ) * **prepend\_prompt\_speech\_tokens\_causal** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **run\_model\_test\_mode** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **kwargs** (`Any` ) ### `audio_decoder` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder} > audio\_decoder: [str](https://docs.python.org/3/library/stdtypes.html#str) = '' The name of the audio decoder model architecture. ### `audio_decoder_config` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder_config} > audio\_decoder\_config: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] Parameters to pass to the audio decoder model. ### `audio_decoder_weights` {#max.pipelines.lib.config.AudioGenerationConfig.audio_decoder_weights} > audio\_decoder\_weights: [str](https://docs.python.org/3/library/stdtypes.html#str) = '' The path to the audio decoder weights file. ### `block_causal` {#max.pipelines.lib.config.AudioGenerationConfig.block_causal} > block\_causal: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether prior buffered tokens should attend to tokens in the current block. Has no effect if buffer is not set.
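For illustration, here is a minimal sketch of constructing an `AudioGenerationConfig`; the decoder name and weights path are placeholders, and the streaming fields are described in the entries that follow:

```python
from max.pipelines.lib.config import AudioGenerationConfig

# A sketch with a placeholder decoder name and weights path.
audio_config = AudioGenerationConfig(
    audio_decoder="my_audio_decoder",
    audio_decoder_weights="/path/to/decoder.weights",
    block_sizes=[64, 128, 256],  # variable-size streaming blocks
    buffer=32,                   # pass 32 previous speech tokens each step
    block_causal=False,          # buffered tokens do not attend to the current block
)
```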
### `block_sizes` {#max.pipelines.lib.config.AudioGenerationConfig.block_sizes} > block\_sizes: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None The block sizes to use for streaming. If this is an int, then fixed-size blocks of the given size are used. If this is a list, then variable block sizes are used. ### `buffer` {#max.pipelines.lib.config.AudioGenerationConfig.buffer} > buffer: [int](https://docs.python.org/3/library/functions.html#int) = 0 The number of previous speech tokens to pass to the audio decoder on each generation step. ### `from_flags()` {#max.pipelines.lib.config.AudioGenerationConfig.from_flags} > classmethod from\_flags(audio\_flags, \*\*config\_flags) **Parameters:** * **audio\_flags** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) * **config\_flags** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) **Return type:** [*AudioGenerationConfig*](#max.pipelines.lib.config.AudioGenerationConfig) ### `prepend_prompt_speech_tokens` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens} > prepend\_prompt\_speech\_tokens: [PrependPromptSpeechTokens](#max.pipelines.lib.config.PrependPromptSpeechTokens) = 'once' Whether the prompt speech tokens should be forwarded to the audio decoder. If “never”, the prompt tokens are not forwarded. If “once”, the prompt tokens are only forwarded on the first block. If “always”, the prompt tokens are forwarded on all blocks. ### `prepend_prompt_speech_tokens_causal` {#max.pipelines.lib.config.AudioGenerationConfig.prepend_prompt_speech_tokens_causal} > prepend\_prompt\_speech\_tokens\_causal: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the prompt speech tokens should attend to tokens in the currently generated audio block. Has no effect if prepend\_prompt\_speech\_tokens is “never”. If False (default), the prompt tokens do not attend to the current block. If True, the prompt tokens attend to the current block. ## `PipelineConfig` {#max.pipelines.lib.config.PipelineConfig} > class max.pipelines.lib.config.PipelineConfig(\*\*kwargs) Configuration for a pipeline. WIP - Once a PipelineConfig is fully initialized, it should be as immutable as possible (frozen=True). All underlying dataclass fields should have been initialized to their default values, whether specified by the user via a CLI flag, config file, or environment variable, or set internally to a reasonable default. **Parameters:** **kwargs** (`Any` ) ### `ce_delay_ms` {#max.pipelines.lib.config.PipelineConfig.ce_delay_ms} > ce\_delay\_ms: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 Duration of scheduler sleep prior to starting a prefill batch. This is an experimental flag solely for the TTS scheduler. Do not use unless you know what you are doing. ### `custom_architectures` {#max.pipelines.lib.config.PipelineConfig.custom_architectures} > custom\_architectures: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] A list of custom architecture implementations to register. Each input can either be a raw module name or an import path followed by a colon and the module name.
Ex: * my\_module * folder/path/to/import:my\_module Each module must expose an ARCHITECTURES list of architectures to register. ### `draft_model_config` {#max.pipelines.lib.config.PipelineConfig.draft_model_config} > property draft\_model\_config: MAXModelConfig | [None](https://docs.python.org/3/library/constants.html#None) ### `enable_chunked_prefill` {#max.pipelines.lib.config.PipelineConfig.enable_chunked_prefill} > enable\_chunked\_prefill: [bool](https://docs.python.org/3/library/functions.html#bool) = True Enable chunked prefill to split context encoding requests into multiple chunks based on ‘target\_num\_new\_tokens’. ### `enable_echo` {#max.pipelines.lib.config.PipelineConfig.enable_echo} > enable\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether the model should be built with echo capabilities. ### `enable_in_flight_batching` {#max.pipelines.lib.config.PipelineConfig.enable_in_flight_batching} > enable\_in\_flight\_batching: [bool](https://docs.python.org/3/library/functions.html#bool) = False When enabled, prioritizes token generation by batching it with context encoding requests. ### `enable_prioritize_first_decode` {#max.pipelines.lib.config.PipelineConfig.enable_prioritize_first_decode} > enable\_prioritize\_first\_decode: [bool](https://docs.python.org/3/library/functions.html#bool) = False When enabled, the scheduler will always run a TG batch immediately after a CE batch, with the same requests. This may be useful for decreasing time-to-first-chunk latency. This is an experimental flag solely for the TTS scheduler. Do not use unless you know what you are doing. ### `engine` {#max.pipelines.lib.config.PipelineConfig.engine} > engine: PipelineEngine | [None](https://docs.python.org/3/library/constants.html#None) = None Engine backend to use for serving, ‘max’ for the max engine, or ‘huggingface’ as a fallback option for improved model coverage. ### `graph_quantization_encoding` {#max.pipelines.lib.config.PipelineConfig.graph_quantization_encoding} > property graph\_quantization\_encoding: [QuantizationEncoding](../graph/quantization.md#max.graph.quantization.QuantizationEncoding) | [None](https://docs.python.org/3/library/constants.html#None) Converts the CLI encoding to a MAX graph quantization encoding. **Returns:** The graph quantization encoding corresponding to the CLI encoding. ### `help()` {#max.pipelines.lib.config.PipelineConfig.help} > static help() Documentation for this config class. Return a dictionary of config options and their descriptions. **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [str](https://docs.python.org/3/library/stdtypes.html#str)] ### `ignore_eos` {#max.pipelines.lib.config.PipelineConfig.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False Ignore EOS and continue generating tokens, even when an EOS token is hit.
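As a sketch, the `help()` method documented above can be used to inspect every available config option:

```python
from max.pipelines.lib.config import PipelineConfig

# Print each documented config option with its description.
for option, description in PipelineConfig.help().items():
    print(f"{option}: {description}")
```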
### `lora_config` {#max.pipelines.lib.config.PipelineConfig.lora_config} > property lora\_config: LoRAConfig | [None](https://docs.python.org/3/library/constants.html#None) ### `lora_manager` {#max.pipelines.lib.config.PipelineConfig.lora_manager} > property lora\_manager: LoRAManager | [None](https://docs.python.org/3/library/constants.html#None) ### `max_batch_size` {#max.pipelines.lib.config.PipelineConfig.max_batch_size} > max\_batch\_size: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Maximum batch size to execute with the model. This is set to one to minimize memory consumption for the base case, in which a person is running a local server to test out MAX. For users launching in a server scenario, this value should be set higher based on server capacity. ### `max_ce_batch_size` {#max.pipelines.lib.config.PipelineConfig.max_ce_batch_size} > max\_ce\_batch\_size: [int](https://docs.python.org/3/library/functions.html#int) = 192 Maximum cache size to reserve for a single context encoding batch. The actual limit is the lesser of this and max\_batch\_size. ### `max_length` {#max.pipelines.lib.config.PipelineConfig.max_length} > max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Maximum sequence length of the model. ### `max_new_tokens` {#max.pipelines.lib.config.PipelineConfig.max_new_tokens} > max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = -1 Maximum number of new tokens to generate during a single inference pass of the model. ### `max_num_steps` {#max.pipelines.lib.config.PipelineConfig.max_num_steps} > max\_num\_steps: [int](https://docs.python.org/3/library/functions.html#int) = -1 The number of steps to run for multi-step scheduling. -1 specifies a default value based on configuration and platform. Ignored for models which are not auto-regressive (e.g. embedding models). ### `max_queue_size_tg` {#max.pipelines.lib.config.PipelineConfig.max_queue_size_tg} > max\_queue\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Maximum number of requests in decode queue. By default, this is max-batch-size. ### `min_batch_size_tg` {#max.pipelines.lib.config.PipelineConfig.min_batch_size_tg} > min\_batch\_size\_tg: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None Specifies a soft floor on the decode batch size. If the TG batch size is larger than this value, the scheduler will continue to run TG batches. If it falls below, the scheduler will prioritize CE. Note that this is NOT a strict minimum! By default, this is max-queue-size-tg. This is an experimental flag solely for the TTS scheduler. Do not use unless you know what you are doing. ### `model_config` {#max.pipelines.lib.config.PipelineConfig.model_config} > property model\_config: MAXModelConfig ### `pad_to_multiple_of` {#max.pipelines.lib.config.PipelineConfig.pad_to_multiple_of} > pad\_to\_multiple\_of: [int](https://docs.python.org/3/library/functions.html#int) = 2 Pad input tensors to be a multiple of value provided. ### `pdl_level` {#max.pipelines.lib.config.PipelineConfig.pdl_level} > pdl\_level: [str](https://docs.python.org/3/library/stdtypes.html#str) = '0' Level of overlap of kernel launch via programmatic dependent grid control.
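The batching and length fields above are plain keyword arguments to the constructor. A minimal sketch with illustrative values (tune them to your hardware), validated via `resolve()` as documented below:

```python
from max.pipelines.lib.config import PipelineConfig

# Illustrative values for a server scenario.
config = PipelineConfig(
    max_batch_size=32,
    max_ce_batch_size=16,
    max_length=8192,
    enable_chunked_prefill=True,
)
config.resolve()  # validate that all fields are in a valid state
```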
### `pipeline_role` {#max.pipelines.lib.config.PipelineConfig.pipeline_role} > pipeline\_role: PipelineRole = 'prefill\_and\_decode' Whether the pipeline should serve a prefill role, a decode role, or both. ### `pool_embeddings` {#max.pipelines.lib.config.PipelineConfig.pool_embeddings} > pool\_embeddings: [bool](https://docs.python.org/3/library/functions.html#bool) = True Whether to pool embedding outputs. ### `profiling_config` {#max.pipelines.lib.config.PipelineConfig.profiling_config} > property profiling\_config: ProfilingConfig ### `resolve()` {#max.pipelines.lib.config.PipelineConfig.resolve} > resolve() Validates and resolves the config. This method is called after the config is initialized, to ensure that all config fields have been initialized to a valid state. **Return type:** None ### `sampling_config` {#max.pipelines.lib.config.PipelineConfig.sampling_config} > property sampling\_config: SamplingConfig ### `target_num_new_tokens` {#max.pipelines.lib.config.PipelineConfig.target_num_new_tokens} > target\_num\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None The target number of un-encoded tokens to include in each batch. If not set, this will be set to a best-guess optimal value based on model, hardware, and available memory. ### `use_experimental_kernels` {#max.pipelines.lib.config.PipelineConfig.use_experimental_kernels} > use\_experimental\_kernels: [str](https://docs.python.org/3/library/stdtypes.html#str) = 'false' ## `PrependPromptSpeechTokens` {#max.pipelines.lib.config.PrependPromptSpeechTokens} > class max.pipelines.lib.config.PrependPromptSpeechTokens(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `ALWAYS` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ALWAYS} > ALWAYS = 'always' Prepend the prompt speech tokens to all blocks of speech tokens sent to the audio decoder. ### `NEVER` {#max.pipelines.lib.config.PrependPromptSpeechTokens.NEVER} > NEVER = 'never' Never prepend the prompt speech tokens sent to the audio decoder. ### `ONCE` {#max.pipelines.lib.config.PrependPromptSpeechTokens.ONCE} > ONCE = 'once' Prepend the prompt speech tokens to the first block of the audio decoder.
--- ## core ## `AudioGenerationRequest` {#max.pipelines.core.AudioGenerationRequest} > class max.pipelines.core.AudioGenerationRequest(id: 'str', index: 'int', model: 'str', input: 'Optional\[str]' = None, audio\_prompt\_tokens: 'list\[int]' = \<factory>, audio\_prompt\_transcription: 'str' = '', sampling\_params: 'SamplingParams' = SamplingParams(top\_k=1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=0), \_assistant\_message\_override: 'str | None' = None, prompt: 'Optional\[list\[int] | str]' = None) **Parameters:** * **id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **index** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **model** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **input** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) * **audio\_prompt\_tokens** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) * **audio\_prompt\_transcription** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **sampling\_params** ([`SamplingParams`](#max.pipelines.core.SamplingParams) ) * **\_assistant\_message\_override** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) * **prompt** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) ### `audio_prompt_tokens` {#max.pipelines.core.AudioGenerationRequest.audio_prompt_tokens} > audio\_prompt\_tokens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] The prompt speech IDs to use for audio generation. ### `audio_prompt_transcription` {#max.pipelines.core.AudioGenerationRequest.audio_prompt_transcription} > audio\_prompt\_transcription: [str](https://docs.python.org/3/library/stdtypes.html#str) = '' The audio prompt transcription to use for audio generation. ### `id` {#max.pipelines.core.AudioGenerationRequest.id} > id: [str](https://docs.python.org/3/library/stdtypes.html#str) A unique identifier for the request. This ID can be used to trace and log the request throughout its lifecycle, facilitating debugging and tracking. ### `index` {#max.pipelines.core.AudioGenerationRequest.index} > index: [int](https://docs.python.org/3/library/functions.html#int) The sequence order of this request within a batch. This is useful for maintaining the order of requests when processing multiple requests simultaneously, ensuring that responses can be matched back to their corresponding requests accurately. ### `input` {#max.pipelines.core.AudioGenerationRequest.input} > input: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None The text to generate audio for. The maximum length is 4096 characters. ### `model` {#max.pipelines.core.AudioGenerationRequest.model} > model: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the model to be used for generating audio chunks. This should match the available models on the server and determines the behavior and capabilities of the response generation.
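Putting the fields above together, a minimal sketch of constructing a request; the model name is a placeholder:

```python
from max.pipelines.core import AudioGenerationRequest

request = AudioGenerationRequest(
    id="req-0",             # unique identifier for tracing
    index=0,                # position of this request within its batch
    model="my-tts-model",   # placeholder model name
    input="Hello, world!",  # text to generate audio for
)
```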
### `prompt` {#max.pipelines.core.AudioGenerationRequest.prompt} > prompt: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None Optionally provide a preprocessed list of token ids or a prompt string to pass as input directly into the model. This replaces automatically generating TokenGeneratorRequestMessages given the input, audio prompt tokens, and audio prompt transcription fields. ### `sampling_params` {#max.pipelines.core.AudioGenerationRequest.sampling_params} > sampling\_params: [SamplingParams](#max.pipelines.core.SamplingParams) = SamplingParams(top\_k=1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=0) Request sampling configuration options. ## `AudioGenerationResponse` {#max.pipelines.core.AudioGenerationResponse} > class max.pipelines.core.AudioGenerationResponse(final\_status, audio=None) **Parameters:** * **final\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) * **audio** (`np.ndarray` `|` `None` ) ### `audio_data` {#max.pipelines.core.AudioGenerationResponse.audio_data} > property audio\_data: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `final_status` {#max.pipelines.core.AudioGenerationResponse.final_status} > property final\_status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `has_audio_data` {#max.pipelines.core.AudioGenerationResponse.has_audio_data} > property has\_audio\_data: [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_done` {#max.pipelines.core.AudioGenerationResponse.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ## `AudioGenerator` {#max.pipelines.core.AudioGenerator} > class max.pipelines.core.AudioGenerator(\*args, \*\*kwargs) Interface for audio generation models. ### `decoder_sample_rate` {#max.pipelines.core.AudioGenerator.decoder_sample_rate} > property decoder\_sample\_rate: [int](https://docs.python.org/3/library/functions.html#int) The sample rate of the decoder. ### `next_chunk()` {#max.pipelines.core.AudioGenerator.next_chunk} > next\_chunk(batch, num\_tokens) Computes the next audio chunk for a single batch. The new speech tokens are saved to the context. The most recently generated audio is returned through the AudioGenerationResponse. **Parameters:** * **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `AudioGeneratorContext` `]` ) – Batch of contexts. * **num\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of speech tokens to generate. **Returns:** Dictionary mapping request IDs to audio generation responses. **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [AudioGenerationResponse](#max.pipelines.core.AudioGenerationResponse)] ### `release()` {#max.pipelines.core.AudioGenerator.release} > release(context) Releases resources associated with this context. **Parameters:** **context** (`AudioGeneratorContext` ) – Finished context.
**Return type:** None ## `AudioGeneratorOutput` {#max.pipelines.core.AudioGeneratorOutput} > class max.pipelines.core.AudioGeneratorOutput(audio\_data: 'torch.Tensor', metadata: 'dict\[str, Any]', is\_done: 'bool') **Parameters:** * **audio\_data** (`torch.Tensor` ) * **metadata** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `Any` `]` ) * **is\_done** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) ### `audio_data` {#max.pipelines.core.AudioGeneratorOutput.audio_data} > audio\_data: torch.Tensor ### `is_done` {#max.pipelines.core.AudioGeneratorOutput.is_done} > is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `metadata` {#max.pipelines.core.AudioGeneratorOutput.metadata} > metadata: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), Any] ## `EmbeddingsGenerator` {#max.pipelines.core.EmbeddingsGenerator} > class max.pipelines.core.EmbeddingsGenerator(\*args, \*\*kwargs) Interface for LLM embeddings-generator models. ### `encode()` {#max.pipelines.core.EmbeddingsGenerator.encode} > encode(batch) Computes embeddings for a batch of inputs. **Parameters:** **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `EmbeddingsGeneratorContext` `]` ) – Batch of contexts to generate embeddings for. **Returns:** Dictionary mapping request IDs to their corresponding embeddings. Each embedding is typically a numpy array or tensor of floating point values. **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), Any] ## `EmbeddingsResponse` {#max.pipelines.core.EmbeddingsResponse} > class max.pipelines.core.EmbeddingsResponse(embeddings) Container for the response from embeddings pipeline. **Parameters:** **embeddings** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) ### `embeddings` {#max.pipelines.core.EmbeddingsResponse.embeddings} > embeddings: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ## `InputContext` {#max.pipelines.core.InputContext} > class max.pipelines.core.InputContext(\*args, \*\*kwargs) A base class for model contexts, representing model inputs for TokenGenerators. Token array layout: ```default
.   +------------- full prompt -------------+            CHUNK_SIZE*N v
.  +--------------------+---------------+-----------------+----------------+
.  |     completed      |  next_tokens  |                 |  preallocated  |
.  +--------------------+---------------+-----------------+----------------+
.             start_idx ^    active_idx ^         end_idx ^
``` * completed: The tokens that have already been processed and encoded. * next\_tokens: The tokens that will be processed in the next iteration. This may be a subset of the full prompt due to chunked prefill. * preallocated: The token slots that have been preallocated. The token array resizes to multiples of CHUNK\_SIZE to accommodate the new tokens. ### `active_idx` {#max.pipelines.core.InputContext.active_idx} > property active\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `active_length` {#max.pipelines.core.InputContext.active_length} > property active\_length: [int](https://docs.python.org/3/library/functions.html#int) Current sequence length: the number of tokens input this iteration.
This will be the prompt size for context encoding, and simply 1 for token generation. ### `all_tokens` {#max.pipelines.core.InputContext.all_tokens} > property all\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) All prompt and generated tokens in the context. ### `assign_to_cache()` {#max.pipelines.core.InputContext.assign_to_cache} > assign\_to\_cache(cache\_seq\_id) Assigns the context to a cache slot. **Parameters:** **cache\_seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `bump_token_indices()` {#max.pipelines.core.InputContext.bump_token_indices} > bump\_token\_indices(start\_idx=0, active\_idx=0, end\_idx=0, committed\_idx=0) Update the start\_idx, active\_idx and end\_idx without manipulating the token array. **Parameters:** * **start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `cache_seq_id` {#max.pipelines.core.InputContext.cache_seq_id} > property cache\_seq\_id: [int](https://docs.python.org/3/library/functions.html#int) Returns the cache slot assigned to the context, raising an error if not assigned. ### `committed_idx` {#max.pipelines.core.InputContext.committed_idx} > property committed\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `compute_num_available_steps()` {#max.pipelines.core.InputContext.compute_num_available_steps} > compute\_num\_available\_steps(max\_seq\_len) Compute the max number of steps we can execute for a given context without exceeding the max\_seq\_len. **Parameters:** **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `current_length` {#max.pipelines.core.InputContext.current_length} > property current\_length: [int](https://docs.python.org/3/library/functions.html#int) The current length of the sequence, including completed and active tokens. ### `end_idx` {#max.pipelines.core.InputContext.end_idx} > property end\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `eos_token_ids` {#max.pipelines.core.InputContext.eos_token_ids} > property eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] ### `generated_tokens` {#max.pipelines.core.InputContext.generated_tokens} > property generated\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) All generated tokens in the context. ### `get_min_token_logit_mask()` {#max.pipelines.core.InputContext.get_min_token_logit_mask} > get\_min\_token\_logit\_mask(num\_steps) Returns a set of indices for the tokens in the output that should be masked. This is primarily used for the min\_tokens setting, where we mask EOS tokens in the logits to avoid generating them before we reach min\_tokens.
**Parameters:** **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[*int32*]]] ### `is_assigned_to_cache` {#max.pipelines.core.InputContext.is_assigned_to_cache} > property is\_assigned\_to\_cache: [bool](https://docs.python.org/3/library/functions.html#bool) Returns True if input is assigned to a cache slot, False otherwise. ### `is_ce` {#max.pipelines.core.InputContext.is_ce} > property is\_ce: [bool](https://docs.python.org/3/library/functions.html#bool) Returns True if the context is a context encoding context, False otherwise. ### `is_done` {#max.pipelines.core.InputContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_initial_prompt` {#max.pipelines.core.InputContext.is_initial_prompt} > property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool) Returns True if the context has not been updated with tokens. ### `json_schema` {#max.pipelines.core.InputContext.json_schema} > property json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) A JSON schema to use during constrained decoding. ### `jump_ahead()` {#max.pipelines.core.InputContext.jump_ahead} > jump\_ahead(new\_token) Updates the token array, while ensuring the new token is returned to the user. **Parameters:** **new\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `log_probabilities` {#max.pipelines.core.InputContext.log_probabilities} > property log\_probabilities: [int](https://docs.python.org/3/library/functions.html#int) When > 0, returns the log probabilities for the top N tokens for each token in the sequence. ### `log_probabilities_echo` {#max.pipelines.core.InputContext.log_probabilities_echo} > property log\_probabilities\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) When True, the input tokens are added to the returned logprobs. ### `matcher` {#max.pipelines.core.InputContext.matcher} > property matcher: xgr.GrammarMatcher | [None](https://docs.python.org/3/library/constants.html#None) An optional xgr Grammar Matcher provided when using structured output. ### `max_length` {#max.pipelines.core.InputContext.max_length} > property max\_length: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) The maximum length of this sequence. ### `min_tokens` {#max.pipelines.core.InputContext.min_tokens} > property min\_tokens: [int](https://docs.python.org/3/library/functions.html#int) The minimum number of new tokens to generate. ### `next_tokens` {#max.pipelines.core.InputContext.next_tokens} > property next\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) The next prompt tokens to be input during this iteration. This should be a 1D array of tokens of length active\_length.
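As a sketch of how these properties relate during scheduling, assuming `ctx` is an `InputContext` supplied by the pipeline (the `4096` bound below is illustrative):

```python
# A sketch, assuming `ctx` is an InputContext provided by the pipeline.
if ctx.is_ce:
    # Context encoding: next_tokens holds (a chunk of) the prompt.
    assert len(ctx.next_tokens) == ctx.active_length
else:
    # Token generation: a single token is active per step.
    assert ctx.active_length == 1

# Bound multi-step execution by the model's maximum sequence length.
steps = ctx.compute_num_available_steps(max_seq_len=4096)
```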
### `outstanding_completion_tokens()` {#max.pipelines.core.InputContext.outstanding_completion_tokens} > outstanding\_completion\_tokens() Return the list of outstanding completion tokens and log probabilities that must be returned to the user. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [*LogProbabilities*](#max.pipelines.core.LogProbabilities) | None]] ### `prompt_tokens` {#max.pipelines.core.InputContext.prompt_tokens} > property prompt\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) Prompt tokens in the context. ### `reset()` {#max.pipelines.core.InputContext.reset} > reset() Resets the context’s state by combining all tokens into a new prompt. This method is used when a request is evicted, meaning that the context needs to be re-encoded in the following CE iteration. **Return type:** None ### `rollback()` {#max.pipelines.core.InputContext.rollback} > rollback(idx) Roll back and remove the last idx tokens. **Parameters:** **idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `sampling_params` {#max.pipelines.core.InputContext.sampling_params} > property sampling\_params: [SamplingParams](#max.pipelines.core.SamplingParams) Returns the per-request sampling configuration. ### `set_draft_offset()` {#max.pipelines.core.InputContext.set_draft_offset} > set\_draft\_offset(idx) **Parameters:** **idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `set_matcher()` {#max.pipelines.core.InputContext.set_matcher} > set\_matcher(matcher) Set a grammar matcher for use during constrained decoding. **Parameters:** **matcher** (`xgr.GrammarMatcher` ) **Return type:** None ### `set_token_indices()` {#max.pipelines.core.InputContext.set_token_indices} > set\_token\_indices(start\_idx=None, active\_idx=None, end\_idx=None, committed\_idx=None) Set the token indices without manipulating the token array. **Parameters:** * **start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) **Return type:** None ### `start_idx` {#max.pipelines.core.InputContext.start_idx} > property start\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `status` {#max.pipelines.core.InputContext.status} > property status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `tokens` {#max.pipelines.core.InputContext.tokens} > property tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) All tokens (including padded tokens) in the context. In most scenarios, use all\_tokens to get the active full token array. ### `unassign_from_cache()` {#max.pipelines.core.InputContext.unassign_from_cache} > unassign\_from\_cache() Unassigns the context from a cache slot. **Return type:** None ### `update()` {#max.pipelines.core.InputContext.update} > update(new\_token, log\_probabilities=None) Updates the next\_tokens and extends existing tokens to include all generated tokens.
**Parameters:** * **new\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **log\_probabilities** ([`LogProbabilities`](#max.pipelines.core.LogProbabilities) `|` `None` ) **Return type:** None ### `update_status()` {#max.pipelines.core.InputContext.update_status} > update\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ## `LogProbabilities` {#max.pipelines.core.LogProbabilities} > class max.pipelines.core.LogProbabilities(token\_log\_probabilities, top\_log\_probabilities) Log probabilities for an individual output token. **Parameters:** * **token\_log\_probabilities** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`float`](https://docs.python.org/3/library/functions.html#float) `]` ) * **top\_log\_probabilities** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`float`](https://docs.python.org/3/library/functions.html#float) `]` `]` ) ### `token_log_probabilities` {#max.pipelines.core.LogProbabilities.token_log_probabilities} > token\_log\_probabilities Log probabilities of each token. **Type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[float](https://docs.python.org/3/library/functions.html#float)] ### `top_log_probabilities` {#max.pipelines.core.LogProbabilities.top_log_probabilities} > top\_log\_probabilities Top tokens and their corresponding log probabilities. **Type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[int](https://docs.python.org/3/library/functions.html#int), [float](https://docs.python.org/3/library/functions.html#float)]] ## `PipelineTask` {#max.pipelines.core.PipelineTask} > class max.pipelines.core.PipelineTask(value, names=\<not given>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `AUDIO_GENERATION` {#max.pipelines.core.PipelineTask.AUDIO_GENERATION} > AUDIO\_GENERATION = 'audio\_generation' ### `EMBEDDINGS_GENERATION` {#max.pipelines.core.PipelineTask.EMBEDDINGS_GENERATION} > EMBEDDINGS\_GENERATION = 'embeddings\_generation' ### `SPEECH_TOKEN_GENERATION` {#max.pipelines.core.PipelineTask.SPEECH_TOKEN_GENERATION} > SPEECH\_TOKEN\_GENERATION = 'speech\_token\_generation' ### `TEXT_GENERATION` {#max.pipelines.core.PipelineTask.TEXT_GENERATION} > TEXT\_GENERATION = 'text\_generation' ## `PipelineTokenizer` {#max.pipelines.core.PipelineTokenizer} > class max.pipelines.core.PipelineTokenizer(\*args, \*\*kwargs) Interface for LLM tokenizers. ### `decode()` {#max.pipelines.core.PipelineTokenizer.decode} > async decode(context, encoded, \*\*kwargs) Decodes response tokens to text. **Parameters:** * **context** (`TokenGeneratorContext` ) – Current generation context. * **encoded** (`TokenizerEncoded` ) – Encoded response tokens. **Returns:** Un-encoded response text. **Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str) ### `encode()` {#max.pipelines.core.PipelineTokenizer.encode} > async encode(prompt, add\_special\_tokens) Encodes text prompts as tokens. **Parameters:** * **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Un-encoded prompt text.
* **add\_special\_tokens** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the prompt exceeds the configured maximum length. **Return type:** *TokenizerEncoded* ### `eos` {#max.pipelines.core.PipelineTokenizer.eos} > property eos: [int](https://docs.python.org/3/library/functions.html#int) The end of sequence token for this tokenizer. ### `expects_content_wrapping` {#max.pipelines.core.PipelineTokenizer.expects_content_wrapping} > property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool) If True, this tokenizer expects messages to have a content property. Text messages are formatted as: ```json { "type": "text", "content": "text content" } ``` instead of the OpenAI spec: ```json { "type": "text", "text": "text content" } ``` NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to: ```json { "type": "image" } ``` Their content is provided as byte arrays through the top-level property on the request object, i.e., `PipelineTokenizerRequest.images`. ### `new_context()` {#max.pipelines.core.PipelineTokenizer.new_context} > async new\_context(request) Creates a new context from a request object. This is sent to the worker process once and then cached locally. **Parameters:** **request** (`PipelineTokenizerRequest` ) – Incoming request. **Returns:** Initialized context. **Return type:** TokenGeneratorContext ## `SamplingParams` {#max.pipelines.core.SamplingParams} > class max.pipelines.core.SamplingParams(top\_k=1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=0) Request-specific sampling parameters that are only known at run time. **Parameters:** * **top\_k** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **top\_p** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **min\_p** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **temperature** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **frequency\_penalty** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **presence\_penalty** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **repetition\_penalty** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **max\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **min\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **ignore\_eos** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **stop** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` `|` `None` ) * **stop\_token\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None` ) * **detokenize** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **seed** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `detokenize` {#max.pipelines.core.SamplingParams.detokenize} > detokenize: [bool](https://docs.python.org/3/library/functions.html#bool) = True Whether to detokenize the output tokens into text.
### `frequency_penalty` {#max.pipelines.core.SamplingParams.frequency_penalty} > frequency\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 The frequency penalty to apply to the model’s output. A positive value will penalize new tokens based on their frequency in the generated text: tokens will receive a penalty proportional to the count of appearances. ### `ignore_eos` {#max.pipelines.core.SamplingParams.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) = False If True, the response will ignore the EOS token, and continue to generate until the max tokens or a stop string is hit. ### `max_new_tokens` {#max.pipelines.core.SamplingParams.max_new_tokens} > max\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) = None The maximum number of new tokens to generate in the response. If not set, the model may generate tokens until it reaches its internal limits or based on other stopping criteria. ### `min_new_tokens` {#max.pipelines.core.SamplingParams.min_new_tokens} > min\_new\_tokens: [int](https://docs.python.org/3/library/functions.html#int) = 0 The minimum number of tokens to generate in the response. ### `min_p` {#max.pipelines.core.SamplingParams.min_p} > min\_p: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in \[0, 1]. Set to 0 to disable this. ### `presence_penalty` {#max.pipelines.core.SamplingParams.presence_penalty} > presence\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 0.0 The presence penalty to apply to the model’s output. A positive value will penalize new tokens that have already appeared in the generated text at least once by applying a constant penalty. ### `repetition_penalty` {#max.pipelines.core.SamplingParams.repetition_penalty} > repetition\_penalty: [float](https://docs.python.org/3/library/functions.html#float) = 1.0 The repetition penalty to apply to the model’s output. Values > 1 will penalize new tokens that have already appeared in the generated text at least once by dividing the logits by the repetition penalty. ### `seed` {#max.pipelines.core.SamplingParams.seed} > seed: [int](https://docs.python.org/3/library/functions.html#int) = 0 The seed to use for the random number generator. ### `stop` {#max.pipelines.core.SamplingParams.stop} > stop: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of detokenized sequences that can be used as stop criteria when generating a new sequence. ### `stop_token_ids` {#max.pipelines.core.SamplingParams.stop_token_ids} > stop\_token\_ids: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of token ids that are used as stopping criteria when generating a new sequence. ### `temperature` {#max.pipelines.core.SamplingParams.temperature} > temperature: [float](https://docs.python.org/3/library/functions.html#float) = 1 Controls the randomness of the model’s output; higher values produce more diverse responses. 
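A minimal sketch of per-request sampling settings using the fields documented in this section (the remaining fields, `top_k` and `top_p`, are described next):

```python
from max.pipelines.core import SamplingParams

params = SamplingParams(
    temperature=0.7,     # lower than 1 for less random output
    top_k=40,            # sample from the 40 most probable tokens
    top_p=0.95,          # keep the 0.95 cumulative-probability mass
    max_new_tokens=256,  # stop after 256 generated tokens
    stop=["\n\n"],       # or stop at a blank line
)
```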
### `top_k` {#max.pipelines.core.SamplingParams.top_k} > top\_k: [int](https://docs.python.org/3/library/functions.html#int) = 1 Limits the sampling to the K most probable tokens. This defaults to 1, which enables greedy sampling. ### `top_p` {#max.pipelines.core.SamplingParams.top_p} > top\_p: [float](https://docs.python.org/3/library/functions.html#float) = 1 Only use the tokens whose cumulative probability is within the top\_p threshold. This applies to the top\_k tokens. ## `TTSContext` {#max.pipelines.core.TTSContext} > class max.pipelines.core.TTSContext(audio\_prompt\_tokens=\<factory>, prev\_samples\_beyond\_offset=0, \_speech\_token\_size=128, \_speech\_token\_end\_idx=0, \_speech\_tokens=\<factory>, \_decoded\_index=0, \_block\_counter=0, \_arrival\_time=\<factory>, \_audio\_generation\_status=TextGenerationStatus.ACTIVE, \*, prompt, max\_length, tokens, eos\_token\_ids=\<factory>, eos\_sequences=\<factory>, log\_probabilities=None, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory>, \_matcher=None, \_status=TextGenerationStatus.ACTIVE, \_cache\_seq\_id=None, \_size=-1, \_start\_idx=0, \_active\_idx=-1, \_end\_idx=-1, \_completion\_start\_idx=-1, \_completion\_end\_idx=-1, \_prompt\_len=-1, \_committed\_idx=0, \_log\_probabilities\_data=\<factory>, \_is\_initial\_prompt=True, \_draft\_offset=0) A context for Text-to-Speech (TTS) model inference. This class extends TextContext to handle speech token generation and management. It maintains buffers for audio prompt tokens and generated speech tokens, along with tracking indices for decoding progress. **Parameters:** * **audio\_prompt\_tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Array of input audio prompt tokens used for voice cloning * **prev\_samples\_beyond\_offset** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_speech\_token\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Size of the speech token buffer, defaults to SPEECH\_TOKEN\_audio\_chunk\_size * **\_speech\_token\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Index marking the end of valid speech tokens * **\_speech\_tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Buffer containing the generated speech tokens * **\_decoded\_index** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Index tracking how many tokens have been decoded to audio * **\_block\_counter** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Counter tracking number of speech token blocks generated * **\_arrival\_time** ([`float`](https://docs.python.org/3/library/functions.html#float) ) * **\_audio\_generation\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) * **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) * **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **eos\_token\_ids** ([`set`](https://docs.python.org/3/library/stdtypes.html#set) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) * **eos\_sequences** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[`
[`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `]` ) * **log\_probabilities** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **log\_probabilities\_echo** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **ignore\_eos** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **json\_schema** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) * **sampling\_params** ([`SamplingParams`](#max.pipelines.core.SamplingParams) ) * **\_matcher** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `|` `None` ) * **\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) * **\_cache\_seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_completion\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_completion\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_prompt\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_log\_probabilities\_data** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`LogProbabilities`](#max.pipelines.core.LogProbabilities) `]` ) * **\_is\_initial\_prompt** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **\_draft\_offset** ([`int`](https://docs.python.org/3/library/functions.html#int) ) ### `audio_generation_status` {#max.pipelines.core.TTSContext.audio_generation_status} > property audio\_generation\_status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `audio_prompt_tokens` {#max.pipelines.core.TTSContext.audio_prompt_tokens} > audio\_prompt\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `block_counter` {#max.pipelines.core.TTSContext.block_counter} > property block\_counter: [int](https://docs.python.org/3/library/functions.html#int) ### `decoded_index` {#max.pipelines.core.TTSContext.decoded_index} > property decoded\_index: [int](https://docs.python.org/3/library/functions.html#int) ### `has_undecoded_speech_tokens()` {#max.pipelines.core.TTSContext.has_undecoded_speech_tokens} > has\_undecoded\_speech\_tokens(exclude\_last\_n=0) Checks whether there are undecoded speech tokens. **Parameters:** **exclude\_last\_n** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of tokens to exclude from the end when checking for undecoded tokens. For example, if set to 1, the last token will not be considered when checking for undecoded tokens. **Returns:** True if there are undecoded speech tokens (excluding the last n tokens), False otherwise. 
**Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_done` {#max.pipelines.core.TTSContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `next_speech_tokens()` {#max.pipelines.core.TTSContext.next_speech_tokens} > next\_speech\_tokens(audio\_chunk\_size=None, buffer=None) Returns a chunk of the next unseen speech tokens. Calling this function will update the index of the last seen token. **Parameters:** * **audio\_chunk\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The number of speech tokens to return. * **buffer** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The number of previous speech tokens to pass to the audio decoder on each generation step. **Returns:** A tuple of (chunk of speech tokens, buffer). **Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), [int](https://docs.python.org/3/library/functions.html#int)] ### `prev_samples_beyond_offset` {#max.pipelines.core.TTSContext.prev_samples_beyond_offset} > prev\_samples\_beyond\_offset: [int](https://docs.python.org/3/library/functions.html#int) ### `set_decoded_index()` {#max.pipelines.core.TTSContext.set_decoded_index} > set\_decoded\_index(idx) **Parameters:** **idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `speech_token_status` {#max.pipelines.core.TTSContext.speech_token_status} > property speech\_token\_status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) Returns the status of the speech token generation. ### `speech_tokens` {#max.pipelines.core.TTSContext.speech_tokens} > property speech\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `status` {#max.pipelines.core.TTSContext.status} > property status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `update_audio_generation_status()` {#max.pipelines.core.TTSContext.update_audio_generation_status} > update\_audio\_generation\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ### `update_speech_token_status()` {#max.pipelines.core.TTSContext.update_speech_token_status} > update\_speech\_token\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ### `update_speech_tokens()` {#max.pipelines.core.TTSContext.update_speech_tokens} > update\_speech\_tokens(new\_tokens) Updates the next\_tokens. **Parameters:** **new\_tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Return type:** None ### `update_status()` {#max.pipelines.core.TTSContext.update_status} > update\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ## `TextAndVisionContext` {#max.pipelines.core.TextAndVisionContext} > class max.pipelines.core.TextAndVisionContext(\*, prompt, max\_length, tokens, eos\_token\_ids=\<factory>, eos\_sequences=\<factory>, log\_probabilities=None, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory>, \_matcher=None, \_status=TextGenerationStatus.ACTIVE, \_cache\_seq\_id=None, \_size=-1, \_start\_idx=0, \_active\_idx=-1, \_end\_idx=-1,
\_completion\_start\_idx=-1, \_completion\_end\_idx=-1, \_prompt\_len=-1, \_committed\_idx=0, \_log\_probabilities\_data=\<factory>, \_is\_initial\_prompt=True, \_draft\_offset=0, pixel\_values=\<factory>, extra\_model\_args=\<factory>) A base class for model contexts, specifically for Vision model variants. **Parameters:** * **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) * **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) * **eos\_token\_ids** ([`set`](https://docs.python.org/3/library/stdtypes.html#set) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) * **eos\_sequences** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `]` ) * **log\_probabilities** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **log\_probabilities\_echo** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **ignore\_eos** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **json\_schema** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) * **sampling\_params** ([`SamplingParams`](#max.pipelines.core.SamplingParams) ) * **\_matcher** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `|` `None` ) * **\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) * **\_cache\_seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_completion\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_completion\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_prompt\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **\_log\_probabilities\_data** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`LogProbabilities`](#max.pipelines.core.LogProbabilities) `]` ) * **\_is\_initial\_prompt** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **\_draft\_offset** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **pixel\_values** ([`tuple`](https://docs.python.org/3/library/stdtypes.html#tuple) `[` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `,` `...` `]` ) * **extra\_model\_args** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `]` ) ### `extra_model_args` {#max.pipelines.core.TextAndVisionContext.extra_model_args} > extra\_model\_args:
[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)] ### `pixel_values` {#max.pipelines.core.TextAndVisionContext.pixel_values} > pixel\_values: [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray), ...] ### `update()` {#max.pipelines.core.TextAndVisionContext.update} > update(new\_token, log\_probabilities=None) Updates the next\_tokens and extends existing tokens to include all generated tokens. **Parameters:** * **new\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **log\_probabilities** ([`LogProbabilities`](#max.pipelines.core.LogProbabilities) `|` `None` ) **Return type:** None ## `TextContext` {#max.pipelines.core.TextContext} > class max.pipelines.core.TextContext(\*, prompt, max\_length, tokens, eos\_token\_ids=\<factory\>, eos\_sequences=\<factory\>, log\_probabilities=None, log\_probabilities\_echo=False, ignore\_eos=False, json\_schema=None, sampling\_params=\<factory\>, \_matcher=None, \_status=TextGenerationStatus.ACTIVE, \_cache\_seq\_id=None, \_size=-1, \_start\_idx=0, \_active\_idx=-1, \_end\_idx=-1, \_completion\_start\_idx=-1, \_completion\_end\_idx=-1, \_prompt\_len=-1, \_committed\_idx=0, \_log\_probabilities\_data=\<factory\>, \_is\_initial\_prompt=True, \_draft\_offset=0) A base class for model context, specifically for Text model variants. This class manages the state and processing of text generation, including token management, caching, and generation parameters. **Parameters:** * **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – The input prompt as either a string or sequence of token IDs * **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Maximum allowed length of the generated sequence * **tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – NumPy array containing the token IDs * **eos\_token\_ids** ([`set`](https://docs.python.org/3/library/stdtypes.html#set) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Set of token IDs that indicate end of sequence * **eos\_sequences** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `]` ) * **log\_probabilities** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – Whether to return token log probabilities (None or int) * **log\_probabilities\_echo** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether to return log probabilities for prompt tokens * **ignore\_eos** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether to ignore end of sequence tokens and continue generating * **json\_schema** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) – Optional JSON schema for structured output * **sampling\_params** ([`SamplingParams`](#max.pipelines.core.SamplingParams) ) – Parameters controlling the token sampling strategy * **\_matcher** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `|` `None` ) *
**\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) – Current generation status (active, finished, etc.) * **\_cache\_seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – ID of KV cache slot assigned to this context * **\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Current allocated size of token array * **\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Start index of current generation window * **\_active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Current position in token sequence * **\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – End index of valid tokens * **\_completion\_start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Start index of completion tokens * **\_completion\_end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – End index of completion tokens * **\_prompt\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Length of original prompt * **\_committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Index up to which tokens are committed * **\_log\_probabilities\_data** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `,` [`LogProbabilities`](#max.pipelines.core.LogProbabilities) `]` ) – Token log probabilities data * **\_is\_initial\_prompt** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether this is the initial prompt encoding * **\_draft\_offset** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Offset for draft decoding ### `active_idx` {#max.pipelines.core.TextContext.active_idx} > property active\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `active_length` {#max.pipelines.core.TextContext.active_length} > property active\_length: [int](https://docs.python.org/3/library/functions.html#int) Current sequence length: the number of tokens input this iteration. This will be the prompt size for context encoding, and simply 1 (or more) for token generation. ### `all_tokens` {#max.pipelines.core.TextContext.all_tokens} > property all\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `assign_to_cache()` {#max.pipelines.core.TextContext.assign_to_cache} > assign\_to\_cache(cache\_seq\_id) Assigns this context to a cache slot. The cache slot is used to store and retrieve KV-cache entries for this context during token generation. **Parameters:** **cache\_seq\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The ID of the cache slot to assign this context to. **Raises:** [**RuntimeError**](https://docs.python.org/3/library/exceptions.html#RuntimeError) – If this context is already assigned to a cache slot. **Return type:** None ### `bump_token_indices()` {#max.pipelines.core.TextContext.bump_token_indices} > bump\_token\_indices(start\_idx=0, active\_idx=0, end\_idx=0, committed\_idx=0) Update the start\_idx, active\_idx, end\_idx and committed\_idx without manipulating the token array.
**Parameters:** * **start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `cache_seq_id` {#max.pipelines.core.TextContext.cache_seq_id} > property cache\_seq\_id: [int](https://docs.python.org/3/library/functions.html#int) Gets the ID of the cache slot this context is assigned to. The cache\_seq\_id is used to look up KV-cache entries for this context during token generation. **Returns:** The cache slot ID. **Return type:** [int](https://docs.python.org/3/library/functions.html#int) **Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If this context is not currently assigned to a cache slot. ### `committed_idx` {#max.pipelines.core.TextContext.committed_idx} > property committed\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `compute_num_available_steps()` {#max.pipelines.core.TextContext.compute_num_available_steps} > compute\_num\_available\_steps(max\_seq\_len) Compute the max number of steps we can execute for a given context without exceeding the max\_seq\_len. **Parameters:** **max\_seq\_len** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `current_length` {#max.pipelines.core.TextContext.current_length} > property current\_length: [int](https://docs.python.org/3/library/functions.html#int) The current length of the sequence, including completed and active tokens. ### `end_idx` {#max.pipelines.core.TextContext.end_idx} > property end\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `eos_sequences` {#max.pipelines.core.TextContext.eos_sequences} > eos\_sequences: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[list](https://docs.python.org/3/library/stdtypes.html#list)\[[int](https://docs.python.org/3/library/functions.html#int)]] ### `eos_token_ids` {#max.pipelines.core.TextContext.eos_token_ids} > eos\_token\_ids: [set](https://docs.python.org/3/library/stdtypes.html#set)\[[int](https://docs.python.org/3/library/functions.html#int)] ### `generated_tokens` {#max.pipelines.core.TextContext.generated_tokens} > property generated\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) Returns all tokens that have been generated after the prompt. **Returns:** Array of generated tokens from prompt\_len to end\_idx. **Return type:** np.ndarray ### `get_min_token_logit_mask()` {#max.pipelines.core.TextContext.get_min_token_logit_mask} > get\_min\_token\_logit\_mask(num\_steps) Returns a set of indices for the tokens in the output that should be masked. This is primarily used for the min\_tokens setting, where we mask eos tokens in the logits to avoid generating them before we reach min\_tokens. **Returns:** A set of indices for the tokens in the output that should be masked. 
**Parameters:** **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any), [*dtype*](https://numpy.org/doc/stable/reference/generated/numpy.dtype.html#numpy.dtype)\[*int32*]]] ### `ignore_eos` {#max.pipelines.core.TextContext.ignore_eos} > ignore\_eos: [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_assigned_to_cache` {#max.pipelines.core.TextContext.is_assigned_to_cache} > property is\_assigned\_to\_cache: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this context is currently assigned to a cache slot. The cache assignment status indicates whether this context can currently access KV-cache entries for token generation. **Returns:** True if assigned to a cache slot, False otherwise. **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_ce` {#max.pipelines.core.TextContext.is_ce} > property is\_ce: [bool](https://docs.python.org/3/library/functions.html#bool) Returns whether this context is in context encoding (CE) mode. CE mode indicates that the context has more than one active token to process, typically during the initial encoding of a prompt or after a rollback. **Returns:** True if in CE mode (active\_length > 1), False otherwise. **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_done` {#max.pipelines.core.TextContext.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `is_initial_prompt` {#max.pipelines.core.TextContext.is_initial_prompt} > property is\_initial\_prompt: [bool](https://docs.python.org/3/library/functions.html#bool) Returns true if the context has not been updated with tokens. ### `json_schema` {#max.pipelines.core.TextContext.json_schema} > json\_schema: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) ### `jump_ahead()` {#max.pipelines.core.TextContext.jump_ahead} > jump\_ahead(new\_token) Updates the token array, while ensuring the new token is returned to the user. **Parameters:** **new\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `log_probabilities` {#max.pipelines.core.TextContext.log_probabilities} > log\_probabilities: [int](https://docs.python.org/3/library/functions.html#int) | [None](https://docs.python.org/3/library/constants.html#None) ### `log_probabilities_echo` {#max.pipelines.core.TextContext.log_probabilities_echo} > log\_probabilities\_echo: [bool](https://docs.python.org/3/library/functions.html#bool) ### `matcher` {#max.pipelines.core.TextContext.matcher} > property matcher: xgr.GrammarMatcher | [None](https://docs.python.org/3/library/constants.html#None) ### `max_length` {#max.pipelines.core.TextContext.max_length} > max\_length: [int](https://docs.python.org/3/library/functions.html#int) ### `min_tokens` {#max.pipelines.core.TextContext.min_tokens} > property min\_tokens: [int](https://docs.python.org/3/library/functions.html#int) The minimum number of new tokens to generate. 
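The properties above distinguish context encoding from token generation. The following is a minimal sketch, assuming `ctx` is an already-constructed `TextContext` instance (construction arguments elided):

```python
# Sketch only: assumes `ctx` is an existing TextContext instance.

# During context encoding (CE), the whole prompt is active at once,
# so the active window spans more than one token:
if ctx.is_ce:
    assert ctx.active_length > 1  # prompt-sized window

# During token generation, the window is typically a single token:
else:
    assert ctx.active_length >= 1

# is_done reflects whether generation has reached a terminal status:
if ctx.is_done:
    print(f"Generation finished with status: {ctx.status}")
```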
### `next_tokens` {#max.pipelines.core.TextContext.next_tokens} > property next\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) Returns the tokens between start\_idx and active\_idx. **Returns:** Array of tokens that have been generated but not yet processed. **Return type:** np.ndarray ### `outstanding_completion_tokens()` {#max.pipelines.core.TextContext.outstanding_completion_tokens} > outstanding\_completion\_tokens() Return the list of outstanding completion tokens and log probabilities that must be returned to the user. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[int](https://docs.python.org/3/library/functions.html#int), [*LogProbabilities*](#max.pipelines.core.LogProbabilities) | None]] ### `prompt` {#max.pipelines.core.TextContext.prompt} > prompt: [str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] ### `prompt_tokens` {#max.pipelines.core.TextContext.prompt_tokens} > property prompt\_tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) Returns the original prompt tokens. **Returns:** Array of tokens from the initial prompt. **Return type:** np.ndarray ### `reset()` {#max.pipelines.core.TextContext.reset} > reset() Resets the context’s state by combining all tokens into a new prompt. **Return type:** None ### `rollback()` {#max.pipelines.core.TextContext.rollback} > rollback(idx) **Parameters:** **idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** None ### `sampling_params` {#max.pipelines.core.TextContext.sampling_params} > sampling\_params: [SamplingParams](#max.pipelines.core.SamplingParams) ### `set_draft_offset()` {#max.pipelines.core.TextContext.set_draft_offset} > set\_draft\_offset(idx) Sets the draft offset index used for speculative decoding. **Parameters:** **idx** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The index to set as the draft offset. **Return type:** None ### `set_matcher()` {#max.pipelines.core.TextContext.set_matcher} > set\_matcher(matcher) **Parameters:** **matcher** (`xgr.GrammarMatcher` ) **Return type:** None ### `set_token_indices()` {#max.pipelines.core.TextContext.set_token_indices} > set\_token\_indices(start\_idx=None, active\_idx=None, end\_idx=None, committed\_idx=None) Set the token indices without manipulating the token array. 
**Parameters:** * **start\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **active\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **end\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) * **committed\_idx** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) **Return type:** None ### `start_idx` {#max.pipelines.core.TextContext.start_idx} > property start\_idx: [int](https://docs.python.org/3/library/functions.html#int) ### `status` {#max.pipelines.core.TextContext.status} > property status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `tokens` {#max.pipelines.core.TextContext.tokens} > tokens: [ndarray](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ### `unassign_from_cache()` {#max.pipelines.core.TextContext.unassign_from_cache} > unassign\_from\_cache() Unassigns this context from its current cache slot. This clears the cache\_seq\_id, allowing the cache slot to be reused by other contexts. Should be called when the context is no longer actively generating tokens. **Return type:** None ### `update()` {#max.pipelines.core.TextContext.update} > update(new\_token, log\_probabilities=None) Updates the next\_tokens and extends existing tokens to include all generated tokens. **Parameters:** * **new\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **log\_probabilities** ([`LogProbabilities`](#max.pipelines.core.LogProbabilities) `|` `None` ) **Return type:** None ### `update_status()` {#max.pipelines.core.TextContext.update_status} > update\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ## `TextGenerationResponse` {#max.pipelines.core.TextGenerationResponse} > class max.pipelines.core.TextGenerationResponse(tokens, final\_status) **Parameters:** * **tokens** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TextResponse`](#max.pipelines.core.TextResponse) `]` ) * **final\_status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) ### `append_token()` {#max.pipelines.core.TextGenerationResponse.append_token} > append\_token(token) **Parameters:** **token** ([`TextResponse`](#max.pipelines.core.TextResponse) ) **Return type:** None ### `final_status` {#max.pipelines.core.TextGenerationResponse.final_status} > property final\_status: [TextGenerationStatus](#max.pipelines.core.TextGenerationStatus) ### `is_done` {#max.pipelines.core.TextGenerationResponse.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ### `tokens` {#max.pipelines.core.TextGenerationResponse.tokens} > property tokens: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TextResponse](#max.pipelines.core.TextResponse)] ### `update_status()` {#max.pipelines.core.TextGenerationResponse.update_status} > update\_status(status) **Parameters:** **status** ([`TextGenerationStatus`](#max.pipelines.core.TextGenerationStatus) ) **Return type:** None ## `TextGenerationStatus` {#max.pipelines.core.TextGenerationStatus} > class max.pipelines.core.TextGenerationStatus(value, names=\<not given\>, \*values, module=None, qualname=None, type=None, start=1, boundary=None) ### `ACTIVE` {#max.pipelines.core.TextGenerationStatus.ACTIVE} > ACTIVE = 'active' ### `END_OF_SEQUENCE` {#max.pipelines.core.TextGenerationStatus.END_OF_SEQUENCE} > END\_OF\_SEQUENCE =
'end\_of\_sequence' ### `MAXIMUM_LENGTH` {#max.pipelines.core.TextGenerationStatus.MAXIMUM_LENGTH} > MAXIMUM\_LENGTH = 'maximum\_length' ### `is_done` {#max.pipelines.core.TextGenerationStatus.is_done} > property is\_done: [bool](https://docs.python.org/3/library/functions.html#bool) ## `TextResponse` {#max.pipelines.core.TextResponse} > class max.pipelines.core.TextResponse(next\_token, log\_probabilities=None) A base class for model response, specifically for Text model variants. **Parameters:** * **next\_token** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` [`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **log\_probabilities** ([`LogProbabilities`](#max.pipelines.core.LogProbabilities) `|` `None` ) ### `next_token` {#max.pipelines.core.TextResponse.next_token} > next\_token Encoded predicted next token. **Type:** [int](https://docs.python.org/3/library/functions.html#int) | [str](https://docs.python.org/3/library/stdtypes.html#str) ### `log_probabilities` {#max.pipelines.core.TextResponse.log_probabilities} > log\_probabilities Log probabilities of each output token. **Type:** [LogProbabilities](#max.pipelines.core.LogProbabilities) | None ## `TokenGenerator` {#max.pipelines.core.TokenGenerator} > class max.pipelines.core.TokenGenerator(\*args, \*\*kwargs) Interface for LLM token-generator models. ### `next_token()` {#max.pipelines.core.TokenGenerator.next_token} > next\_token(batch, num\_steps) Computes the next token response for a single batch. **Parameters:** * **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `TokenGeneratorContext` `]` ) – Batch of contexts. * **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Number of tokens to generate. **Returns:** List of encoded responses (indexed by request ID) **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [TextResponse](#max.pipelines.core.TextResponse)]] ### `release()` {#max.pipelines.core.TokenGenerator.release} > release(context) Releases resources associated with this context. **Parameters:** **context** (`TokenGeneratorContext` ) – Finished context.
**Return type:** None ## `TokenGeneratorRequest` {#max.pipelines.core.TokenGeneratorRequest} > class max.pipelines.core.TokenGeneratorRequest(id: 'str', index: 'int', model\_name: 'str', prompt: 'Union\[str, Sequence\[int], None]' = None, messages: 'Optional\[list\[TokenGeneratorRequestMessage]]' = None, images: 'Optional\[list\[bytes]]' = None, tools: 'Optional\[list\[TokenGeneratorRequestTool]]' = None, response\_format: 'Optional\[TokenGeneratorResponseFormat]' = None, timestamp\_ns: 'int' = 0, request\_path: 'str' = '/', logprobs: 'int' = 0, echo: 'bool' = False, stop: 'Optional\[Union\[str, list\[str]]]' = None, chat\_template\_options: 'Optional\[dict\[str, Any]]' = None, sampling\_params: 'SamplingParams' = SamplingParams(top\_k=1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=0)) **Parameters:** * **id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **index** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **model\_name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` `|` `None` ) * **messages** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestMessage`](#max.pipelines.core.TokenGeneratorRequestMessage) `]` `|` `None` ) * **images** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`bytes`](https://docs.python.org/3/library/stdtypes.html#bytes) `]` `|` `None` ) * **tools** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestTool`](#max.pipelines.core.TokenGeneratorRequestTool) `]` `|` `None` ) * **response\_format** ([`TokenGeneratorResponseFormat`](#max.pipelines.core.TokenGeneratorResponseFormat) `|` `None` ) * **timestamp\_ns** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **request\_path** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **logprobs** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **echo** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **stop** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` `|` `None` ) * **chat\_template\_options** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `]` `|` `None` ) * **sampling\_params** ([`SamplingParams`](#max.pipelines.core.SamplingParams) ) ### `chat_template_options` {#max.pipelines.core.TokenGeneratorRequest.chat_template_options} > chat\_template\_options: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional dictionary of options to pass when applying the chat template. 
### `echo` {#max.pipelines.core.TokenGeneratorRequest.echo} > echo: [bool](https://docs.python.org/3/library/functions.html#bool) = False If set to True, the response will include the original prompt along with the generated output. This can be useful for debugging or when you want to see how the input relates to the output. ### `id` {#max.pipelines.core.TokenGeneratorRequest.id} > id: [str](https://docs.python.org/3/library/stdtypes.html#str) A unique identifier for the request. This ID can be used to trace and log the request throughout its lifecycle, facilitating debugging and tracking. ### `images` {#max.pipelines.core.TokenGeneratorRequest.images} > images: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[bytes](https://docs.python.org/3/library/stdtypes.html#bytes)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of image byte arrays that can be included as part of the request. This field is optional and may be used for multimodal inputs where images are relevant to the prompt or task. ### `index` {#max.pipelines.core.TokenGeneratorRequest.index} > index: [int](https://docs.python.org/3/library/functions.html#int) The sequence order of this request within a batch. This is useful for maintaining the order of requests when processing multiple requests simultaneously, ensuring that responses can be matched back to their corresponding requests accurately. ### `logprobs` {#max.pipelines.core.TokenGeneratorRequest.logprobs} > logprobs: [int](https://docs.python.org/3/library/functions.html#int) = 0 The number of top log probabilities to return for each generated token. A value of 0 means that log probabilities will not be returned. Useful for analyzing model confidence in its predictions. ### `messages` {#max.pipelines.core.TokenGeneratorRequest.messages} > messages: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TokenGeneratorRequestMessage](#max.pipelines.core.TokenGeneratorRequestMessage)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of messages for chat-based interactions. This is used in chat completion APIs, where each message represents a turn in the conversation. If provided, the model will generate responses based on these messages. ### `model_name` {#max.pipelines.core.TokenGeneratorRequest.model_name} > model\_name: [str](https://docs.python.org/3/library/stdtypes.html#str) The name of the model to be used for generating tokens. This should match the available models on the server and determines the behavior and capabilities of the response generation. ### `prompt` {#max.pipelines.core.TokenGeneratorRequest.prompt} > prompt: [str](https://docs.python.org/3/library/stdtypes.html#str) | [Sequence](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence)\[[int](https://docs.python.org/3/library/functions.html#int)] | [None](https://docs.python.org/3/library/constants.html#None) = None The prompt to be processed by the model. This field supports legacy completion APIs and can accept either a string or a sequence of integers representing token IDs. If not provided, the model may generate output based on the messages field. ### `request_path` {#max.pipelines.core.TokenGeneratorRequest.request_path} > request\_path: [str](https://docs.python.org/3/library/stdtypes.html#str) = '/' The endpoint path for the request. This is typically used for routing and logging requests within the server infrastructure. 
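Putting the fields above together, a request can be constructed directly. The following is a hedged sketch: `id`, `index`, and `model_name` are the required fields, and all values shown are placeholders.

```python
from max.pipelines.core import TokenGeneratorRequest

# Placeholder values; the remaining fields fall back to their defaults.
request = TokenGeneratorRequest(
    id="request-0",
    index=0,
    model_name="modularai/Llama-3.1-8B-Instruct-GGUF",
    prompt="Write a haiku about GPUs.",
    logprobs=0,  # no log probabilities requested
)
print(request.request_path)  # defaults to '/'
```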
### `response_format` {#max.pipelines.core.TokenGeneratorRequest.response_format} > response\_format: [TokenGeneratorResponseFormat](#max.pipelines.core.TokenGeneratorResponseFormat) | [None](https://docs.python.org/3/library/constants.html#None) = None Specifies the desired format for the model’s output. When set, it enables structured generation, which adheres to the json\_schema provided. ### `sampling_params` {#max.pipelines.core.TokenGeneratorRequest.sampling_params} > sampling\_params: [SamplingParams](#max.pipelines.core.SamplingParams) = SamplingParams(top\_k=1, top\_p=1, min\_p=0.0, temperature=1, frequency\_penalty=0.0, presence\_penalty=0.0, repetition\_penalty=1.0, max\_new\_tokens=None, min\_new\_tokens=0, ignore\_eos=False, stop=None, stop\_token\_ids=None, detokenize=True, seed=0) Token sampling configuration parameters for the request. ### `stop` {#max.pipelines.core.TokenGeneratorRequest.stop} > stop: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)] | [None](https://docs.python.org/3/library/constants.html#None) = None Optional list of stop expressions (see https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). ### `timestamp_ns` {#max.pipelines.core.TokenGeneratorRequest.timestamp_ns} > timestamp\_ns: [int](https://docs.python.org/3/library/functions.html#int) = 0 The time (in nanoseconds) when the request was received by the server. This can be useful for performance monitoring and logging purposes. ### `tools` {#max.pipelines.core.TokenGeneratorRequest.tools} > tools: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[TokenGeneratorRequestTool](#max.pipelines.core.TokenGeneratorRequestTool)] | [None](https://docs.python.org/3/library/constants.html#None) = None A list of tools that can be invoked during the generation process. This allows the model to utilize external functionalities or APIs to enhance its responses. ## `TokenGeneratorRequestFunction` {#max.pipelines.core.TokenGeneratorRequestFunction} > class max.pipelines.core.TokenGeneratorRequestFunction ### `description` {#max.pipelines.core.TokenGeneratorRequestFunction.description} > description: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `name` {#max.pipelines.core.TokenGeneratorRequestFunction.name} > name: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `parameters` {#max.pipelines.core.TokenGeneratorRequestFunction.parameters} > parameters: [dict](https://docs.python.org/3/library/stdtypes.html#dict) ## `TokenGeneratorRequestMessage` {#max.pipelines.core.TokenGeneratorRequestMessage} > class max.pipelines.core.TokenGeneratorRequestMessage ### `content` {#max.pipelines.core.TokenGeneratorRequestMessage.content} > content: [str](https://docs.python.org/3/library/stdtypes.html#str) | [list](https://docs.python.org/3/library/stdtypes.html#list)\[[dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [Any](https://docs.python.org/3/library/typing.html#typing.Any)]] Content can be a simple string or a list of message parts of different modalities. For example:

```json
{
  "role": "user",
  "content": "What's the weather like in Boston today?"
}
```

Or:

```json
{
  "role": "user",
  "content": [
    {
      "type": "text",
      "text": "What's in this image?"
    },
    {
      "type": "image_url",
      "image_url": {
        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
      }
    }
  ]
}
```

### `role` {#max.pipelines.core.TokenGeneratorRequestMessage.role} > role: [Literal](https://docs.python.org/3/library/typing.html#typing.Literal)\['system', 'user', 'assistant'] ## `TokenGeneratorRequestTool` {#max.pipelines.core.TokenGeneratorRequestTool} > class max.pipelines.core.TokenGeneratorRequestTool ### `function` {#max.pipelines.core.TokenGeneratorRequestTool.function} > function: [TokenGeneratorRequestFunction](#max.pipelines.core.TokenGeneratorRequestFunction) ### `type` {#max.pipelines.core.TokenGeneratorRequestTool.type} > type: [str](https://docs.python.org/3/library/stdtypes.html#str) ## `TokenGeneratorResponseFormat` {#max.pipelines.core.TokenGeneratorResponseFormat} > class max.pipelines.core.TokenGeneratorResponseFormat ### `json_schema` {#max.pipelines.core.TokenGeneratorResponseFormat.json_schema} > json\_schema: [dict](https://docs.python.org/3/library/stdtypes.html#dict) ### `type` {#max.pipelines.core.TokenGeneratorResponseFormat.type} > type: [str](https://docs.python.org/3/library/stdtypes.html#str) ## `msgpack_numpy_decoder()` {#max.pipelines.core.msgpack_numpy_decoder} > max.pipelines.core.msgpack\_numpy\_decoder(type\_, copy=True) Create a decoder function for the specified type. **Parameters:** * **type\_** ([`Any`](https://docs.python.org/3/library/typing.html#typing.Any) ) – The type to decode into * **copy** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Copy numpy arrays if true **Returns:** A function that decodes bytes into the specified type **Return type:** [*Callable*](https://docs.python.org/3/library/typing.html#typing.Callable)\[\[[bytes](https://docs.python.org/3/library/stdtypes.html#bytes)], [*Any*](https://docs.python.org/3/library/typing.html#typing.Any)] ## `msgpack_numpy_encoder()` {#max.pipelines.core.msgpack_numpy_encoder} > max.pipelines.core.msgpack\_numpy\_encoder() Create an encoder function that handles numpy arrays. **Returns:** A function that encodes objects into bytes **Return type:** [*Callable*](https://docs.python.org/3/library/typing.html#typing.Callable)\[\[[*Any*](https://docs.python.org/3/library/typing.html#typing.Any)], [bytes](https://docs.python.org/3/library/stdtypes.html#bytes)] --- ## hf_pipeline Generalized Token Generation Pipeline ## `HFEmbeddingsPipeline` {#max.pipelines.lib.hf_pipeline.HFEmbeddingsPipeline} > class max.pipelines.lib.hf\_pipeline.HFEmbeddingsPipeline(pipeline\_config, torch\_device\_type) Generalized embeddings pipeline. **Parameters:** * **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) ) * **torch\_device\_type** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) ### `encode()` {#max.pipelines.lib.hf_pipeline.HFEmbeddingsPipeline.encode} > encode(batch) Encodes a batch of text inputs.
**Parameters:** **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`TextContext`](core.md#max.pipelines.core.TextContext) `]` ) **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*EmbeddingsResponse*](core.md#max.pipelines.core.EmbeddingsResponse)] ### `prepare_initial_token_inputs()` {#max.pipelines.lib.hf_pipeline.HFEmbeddingsPipeline.prepare_initial_token_inputs} > prepare\_initial\_token\_inputs(context\_batch) **Parameters:** **context\_batch** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TextContext`](core.md#max.pipelines.core.TextContext) `]` ) **Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[*Tensor*, *Tensor*] ## `HFTextGenerationPipeline` {#max.pipelines.lib.hf_pipeline.HFTextGenerationPipeline} > class max.pipelines.lib.hf\_pipeline.HFTextGenerationPipeline(pipeline\_config, torch\_device\_type) HuggingFace text token generator pipeline. **Parameters:** * **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) ) * **torch\_device\_type** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) ### `next_token()` {#max.pipelines.lib.hf_pipeline.HFTextGenerationPipeline.next_token} > next\_token(batch, num\_steps) Provided a batch, process batch inputs, execute the graph for num\_steps in a multi-step scenario, then decode the tokens holistically and return the list of decoded tokens. **Parameters:** * **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`TextContext`](core.md#max.pipelines.core.TextContext) `]` ) * **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) ) **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*TextGenerationResponse*](core.md#max.pipelines.core.TextGenerationResponse)] ### `release()` {#max.pipelines.lib.hf_pipeline.HFTextGenerationPipeline.release} > release(context) Releases resources associated with this context. **Parameters:** **context** (`TokenGeneratorContext` ) – Finished context. **Return type:** None --- ## hf_utils Utilities for interacting with HuggingFace Files/Repos. ## `HuggingFaceFile` {#max.pipelines.lib.hf_utils.HuggingFaceFile} > class max.pipelines.lib.hf\_utils.HuggingFaceFile(repo\_id, filename, revision=None) A simple object for tracking Hugging Face model metadata. The repo\_id will frequently be used to load a tokenizer, whereas the filename is used to download model weights. **Parameters:** * **repo\_id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **filename** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) ### `download()` {#max.pipelines.lib.hf_utils.HuggingFaceFile.download} > download(force\_download=False) Download the file and return the file path where the data is saved locally. 
**Parameters:** **force\_download** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) **Return type:** [*Path*](https://docs.python.org/3/library/pathlib.html#pathlib.Path) ### `exists()` {#max.pipelines.lib.hf_utils.HuggingFaceFile.exists} > exists() **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `filename` {#max.pipelines.lib.hf_utils.HuggingFaceFile.filename} > filename: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `repo_id` {#max.pipelines.lib.hf_utils.HuggingFaceFile.repo_id} > repo\_id: [str](https://docs.python.org/3/library/stdtypes.html#str) ### `revision` {#max.pipelines.lib.hf_utils.HuggingFaceFile.revision} > revision: [str](https://docs.python.org/3/library/stdtypes.html#str) | [None](https://docs.python.org/3/library/constants.html#None) = None ### `size()` {#max.pipelines.lib.hf_utils.HuggingFaceFile.size} > size() **Return type:** [int](https://docs.python.org/3/library/functions.html#int) | None ## `HuggingFaceRepo` {#max.pipelines.lib.hf_utils.HuggingFaceRepo} > class max.pipelines.lib.hf\_utils.HuggingFaceRepo(repo\_id, revision='main', trust\_remote\_code=False, repo\_type=None) A class for interacting with HuggingFace Repos. **Parameters:** * **repo\_id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **trust\_remote\_code** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) * **repo\_type** (`RepoType` `|` `None` ) ### `download()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.download} > download(filename, force\_download=False) **Parameters:** * **filename** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **force\_download** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) **Return type:** [*Path*](https://docs.python.org/3/library/pathlib.html#pathlib.Path) ### `encoding_for_file()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.encoding_for_file} > encoding\_for\_file(file) **Parameters:** **file** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) ) **Return type:** *SupportedEncoding* ### `file_exists()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.file_exists} > file\_exists(filename) **Parameters:** **filename** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) ### `files_for_encoding()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.files_for_encoding} > files\_for\_encoding(encoding, weights\_format=None) **Parameters:** * **encoding** (`SupportedEncoding` ) * **weights\_format** ([`WeightsFormat`](../graph/weights.md#max.graph.weights.WeightsFormat) `|` `None` ) **Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[*WeightsFormat*](../graph/weights.md#max.graph.weights.WeightsFormat), [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Path*](https://docs.python.org/3/library/pathlib.html#pathlib.Path)]] ### `formats_available` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.formats_available} > property formats\_available: [list](https://docs.python.org/3/library/stdtypes.html#list)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat)] ### `info` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.info} > property info: ModelInfo ### `repo_id` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_id} > repo\_id: 
[str](https://docs.python.org/3/library/stdtypes.html#str) The HuggingFace repo id. While it’s called repo\_id, it can be an HF remote repo id or a local path. ### `repo_type` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.repo_type} > repo\_type: RepoType | [None](https://docs.python.org/3/library/constants.html#None) = None The type of repo. This is inferred from the repo\_id. ### `revision` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.revision} > revision: [str](https://docs.python.org/3/library/stdtypes.html#str) = 'main' The revision to use for the repo. ### `size_of()` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.size_of} > size\_of(filename) **Parameters:** **filename** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) | None ### `supported_encodings` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.supported_encodings} > property supported\_encodings: [list](https://docs.python.org/3/library/stdtypes.html#list)\[SupportedEncoding] ### `trust_remote_code` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.trust_remote_code} > trust\_remote\_code: [bool](https://docs.python.org/3/library/functions.html#bool) = False Whether to trust remote code. ### `weight_files` {#max.pipelines.lib.hf_utils.HuggingFaceRepo.weight_files} > property weight\_files: [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[WeightsFormat](../graph/weights.md#max.graph.weights.WeightsFormat), [list](https://docs.python.org/3/library/stdtypes.html#list)\[[str](https://docs.python.org/3/library/stdtypes.html#str)]] ## `download_weight_files()` {#max.pipelines.lib.hf_utils.download_weight_files} > max.pipelines.lib.hf\_utils.download\_weight\_files(huggingface\_model\_id, filenames, revision=None, force\_download=False, max\_workers=8) Provided a HuggingFace model id and filenames, downloads the weight files and returns the list of local paths. **Parameters:** * **huggingface\_model\_id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The huggingface model identifier, i.e. modularai/Llama-3.1-8B-Instruct-GGUF * **filenames** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) – A list of file paths relative to the root of the HuggingFace repo. If the provided files are available locally, the download is skipped, and the local files are used. * **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` ) – The HuggingFace revision to use. If provided, we check the local cache directly, saving a network call to HuggingFace. * **force\_download** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether the files should be redownloaded, even if they are already available in the local cache or at a provided path. * **max\_workers** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The number of worker threads to concurrently download files. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*Path*](https://docs.python.org/3/library/pathlib.html#pathlib.Path)] ## `generate_local_model_path()` {#max.pipelines.lib.hf_utils.generate_local_model_path} > max.pipelines.lib.hf\_utils.generate\_local\_model\_path(repo\_id, revision) Generate the local filesystem path where a HuggingFace model repo is cached.
This function takes a HuggingFace repository ID and revision hash and returns the full local filesystem path where the model files are cached by the huggingface\_hub library. The path follows the standard HuggingFace caching convention of: \~/.cache/huggingface/hub/models--{org}--{model}/snapshots/{revision} **Parameters:** * **repo\_id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The HuggingFace repository ID in the format “org/model” (e.g. “HuggingFaceTB/SmolLM2-135M”) * **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The specific model revision hash to use, typically from a repo lock file **Returns:** The absolute path to the cached model files for the specified revision. For example: “\~/.cache/huggingface/hub/models--HuggingFaceTB--SmolLM2-135M/snapshots/abc123” **Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str) **Raises:** [**FileNotFoundError**](https://docs.python.org/3/library/exceptions.html#FileNotFoundError) – If the model path does not exist locally ## `repo_exists_with_retry()` {#max.pipelines.lib.hf_utils.repo_exists_with_retry} > max.pipelines.lib.hf\_utils.repo\_exists\_with\_retry(repo\_id, revision) Wrapper around huggingface\_hub.revision\_exists with retry logic. Uses exponential backoff with 25% jitter, starting at 1s and doubling each retry. We use revision\_exists here instead of repo\_exists because repo\_exists does not take in a revision parameter. See huggingface\_hub.revision\_exists for details. **Parameters:** * **repo\_id** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) * **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) **Return type:** [bool](https://docs.python.org/3/library/functions.html#bool) --- ## pipelines NOTE: These APIs are under heavy development and subject to change. ## Modules * [`architectures`](/max/api/python/pipelines/architectures) * [`config`](/max/api/python/pipelines/config) * [`core`](/max/api/python/pipelines/core) * [`hf_pipeline`](/max/api/python/pipelines/hf_pipeline) * [`hf_utils`](/max/api/python/pipelines/hf_utils) * [`pipeline`](/max/api/python/pipelines/pipeline) * [`registry`](/max/api/python/pipelines/registry) * [`sampling`](/max/api/python/pipelines/sampling) * [`tokenizer`](/max/api/python/pipelines/tokenizer) --- ## log_probabilities ## `compute_log_probabilities_ragged()` {#max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged} > max.pipelines.lib.log\_probabilities.compute\_log\_probabilities\_ragged(\*, input\_row\_offsets, logits, next\_token\_logits, tokens, sampled\_tokens, batch\_top\_n, batch\_echo) Computes the log probabilities for ragged model outputs. **Parameters:** * **input\_row\_offsets** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Token offsets into token-indexed buffers, by batch index. Should have 1 more element than there are batches (batch n is token indices \[input\_row\_offsets\[n], input\_row\_offsets\[n+1])). * **logits** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) `|` `None` ) – (tokens, vocab\_dim) tensor of logits. Token dimension mapped to batches using input\_row\_offsets. * **next\_token\_logits** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – (batch, vocab\_dim) tensor of logits for next tokens per batch.
* **sampled\_tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – (batch\_dim,) tensor of the sampled token per batch * **batch\_top\_n** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Number of top log probabilities to return per input in the batch. For any element where top\_n == 0, the LogProbabilities is skipped. * **batch\_echo** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`bool`](https://docs.python.org/3/library/functions.html#bool) `]` ) – Whether to include input tokens in the returned log probabilities. * **tokens** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) **Returns:** Computed log probabilities for each item in the batch. **Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*LogProbabilities*](core.md#max.pipelines.core.LogProbabilities) | None] ## `log_softmax()` {#max.pipelines.lib.log_probabilities.log_softmax} > max.pipelines.lib.log\_probabilities.log\_softmax(x, axis=-1) Compute the logarithm of the softmax function. This implementation uses the identity log(softmax(x)) = x - log(sum(exp(x))) with numerical stability improvements to prevent overflow/underflow. **Parameters:** * **x** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) ) – Input array * **axis** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – Axis to compute values along **Returns:** Array with same shape as x, representing log(softmax(x)) **Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) --- ## pipeline Hugging Face Token Generation Pipeline. ## `FrequencyData` {#max.pipelines.lib.pipeline.FrequencyData} > class max.pipelines.lib.pipeline.FrequencyData(data, offsets) Container for token frequency data in CSR format. **Parameters:** * **data** ([`Tensor`](../driver.md#max.driver.Tensor) ) * **offsets** ([`Tensor`](../driver.md#max.driver.Tensor) ) ### `data` {#max.pipelines.lib.pipeline.FrequencyData.data} > data: [Tensor](../driver.md#max.driver.Tensor) data\[:, 0]: 1D array of the column indices of the non-zero elements in the matrix. data\[:, 1]: 1D array of the non-zero elements in the matrix. ### `offsets` {#max.pipelines.lib.pipeline.FrequencyData.offsets} > offsets: [Tensor](../driver.md#max.driver.Tensor) Row offsets: shape \[batch\_size + 1] indicating the start of each sequence’s data. ## `KVCacheMixin` {#max.pipelines.lib.pipeline.KVCacheMixin} > class max.pipelines.lib.pipeline.KVCacheMixin(\*args, \*\*kwargs) ### `estimate_kv_cache_size()` {#max.pipelines.lib.pipeline.KVCacheMixin.estimate_kv_cache_size} > abstract classmethod estimate\_kv\_cache\_size(pipeline\_config, available\_cache\_memory, devices, huggingface\_config, kv\_cache\_config, cache\_dtype) Estimates the size of the kv cache in bytes.
**Parameters:** * **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) ) * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Device`](../driver.md#max.driver.Device) `]` ) * **huggingface\_config** (`AutoConfig` ) * **kv\_cache\_config** (`KVCacheConfig` ) * **cache\_dtype** ([`DType`](../dtype.md#max.dtype.DType) ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `get_kv_params()` {#max.pipelines.lib.pipeline.KVCacheMixin.get_kv_params} > abstract classmethod get\_kv\_params(huggingface\_config, n\_devices, kv\_cache\_config, cache\_dtype) Returns the KV cache params for the pipeline model. **Parameters:** * **huggingface\_config** (`AutoConfig` ) * **n\_devices** ([`int`](https://docs.python.org/3/library/functions.html#int) ) * **kv\_cache\_config** (`KVCacheConfig` ) * **cache\_dtype** ([`DType`](../dtype.md#max.dtype.DType) ) **Return type:** [*KVCacheParams*](../nn/kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheParams) ### `get_num_layers()` {#max.pipelines.lib.pipeline.KVCacheMixin.get_num_layers} > abstract classmethod get\_num\_layers(huggingface\_config) Returns the number of layers for the pipeline model. **Parameters:** **huggingface\_config** (`AutoConfig` ) **Return type:** [int](https://docs.python.org/3/library/functions.html#int) ### `load_kv_manager()` {#max.pipelines.lib.pipeline.KVCacheMixin.load_kv_manager} > load\_kv\_manager(session, available\_cache\_memory) Provided a PipelineConfig and InferenceSession, loads the KV manager. **Parameters:** * **session** ([`InferenceSession`](../engine.md#max.engine.InferenceSession) ) – Inference session to compile and init the KV cache. * **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – Amount of memory available to the KV cache, in bytes. **Returns:** Either a single KV cache manager or a tuple of KV cache managers, one per input modality. ## `ModelInputs` {#max.pipelines.lib.pipeline.ModelInputs} > class max.pipelines.lib.pipeline.ModelInputs Base class for model inputs. Use this class to encapsulate inputs for your model.
You may store any number of dataclass fields. The following example demonstrates how to create a custom inputs class for a model:

```python
from max.driver import Tensor
from max.dtype import DType
from max.pipelines.lib.pipeline import ModelInputs


class ReplitInputs(ModelInputs):
    tokens: Tensor
    input_row_offsets: Tensor

    def __init__(self, tokens: Tensor, input_row_offsets: Tensor):
        self.tokens = tokens
        self.input_row_offsets = input_row_offsets


tokens = Tensor.zeros((1, 2, 3), DType.int64)
input_row_offsets = Tensor.zeros((1, 1, 1), DType.int64)

# Initialize inputs
inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)

# Access tensors
list(inputs) == [tokens, input_row_offsets]  # Output: True
```

### `kv_cache_inputs` {#max.pipelines.lib.pipeline.ModelInputs.kv_cache_inputs} > kv\_cache\_inputs: [KVCacheInputs](../nn/kv_cache/manager.md#max.nn.kv_cache.manager.KVCacheInputs) | [None](https://docs.python.org/3/library/constants.html#None) = None ## `ModelOutputs` {#max.pipelines.lib.pipeline.ModelOutputs} > class max.pipelines.lib.pipeline.ModelOutputs(logits: 'Tensor', next\_token\_logits: 'Tensor | None' = None, logit\_offsets: 'Tensor | None' = None) **Parameters:** * **logits** ([`Tensor`](../driver.md#max.driver.Tensor) ) * **next\_token\_logits** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` ) * **logit\_offsets** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` ) ### `logit_offsets` {#max.pipelines.lib.pipeline.ModelOutputs.logit_offsets} > logit\_offsets: [Tensor](../driver.md#max.driver.Tensor) | [None](https://docs.python.org/3/library/constants.html#None) = None Offsets to access variable length logits for each sequence. ### `logits` {#max.pipelines.lib.pipeline.ModelOutputs.logits} > logits: [Tensor](../driver.md#max.driver.Tensor) Logits for a variable number of tokens per sequence. ### `next_token_logits` {#max.pipelines.lib.pipeline.ModelOutputs.next_token_logits} > next\_token\_logits: [Tensor](../driver.md#max.driver.Tensor) | [None](https://docs.python.org/3/library/constants.html#None) = None Logits for just the next token. ## `PipelineModel` {#max.pipelines.lib.pipeline.PipelineModel} > class max.pipelines.lib.pipeline.PipelineModel(pipeline\_config, session, huggingface\_config, encoding, devices, kv\_cache\_config, weights, adapter, return\_logits) A pipeline model with setup, input preparation and execution methods. **Parameters:** * **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) ) * **session** ([`InferenceSession`](../engine.md#max.engine.InferenceSession) ) * **huggingface\_config** (`AutoConfig` ) * **encoding** (`SupportedEncoding` ) * **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Device`](../driver.md#max.driver.Device) `]` ) * **kv\_cache\_config** (`KVCacheConfig` ) * **weights** ([`Weights`](../graph/weights.md#max.graph.weights.Weights) ) * **adapter** (`Optional` `[` `WeightsAdapter` `]` ) * **return\_logits** ([`ReturnLogits`](../nn/transformer/transformer.md#max.nn.transformer.transformer.ReturnLogits) ) ### `calculate_max_seq_len()` {#max.pipelines.lib.pipeline.PipelineModel.calculate_max_seq_len} > abstract classmethod calculate\_max\_seq\_len(pipeline\_config, huggingface\_config) Calculate the optimal max sequence length for the model. Models are expected to implement this method.
The following example shows how to implement this method for a Mistral model:

```python
from max.pipelines.lib.pipeline import PipelineModel, upper_bounded_default


class MistralModel(PipelineModel):
    @classmethod
    def calculate_max_seq_len(cls, pipeline_config, huggingface_config) -> int:
        try:
            return upper_bounded_default(
                upper_bound=huggingface_config.max_seq_len,
                default=pipeline_config.max_length,
            )
        except ValueError as e:
            msg = (
                "Unable to infer max_length for Mistral, the provided "
                f"max_length ({pipeline_config.max_length}) exceeds the "
                f"model's max_seq_len ({huggingface_config.max_seq_len})."
            )
            raise ValueError(msg) from e
```

**Parameters:**

* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) ) – Configuration for the pipeline.
* **huggingface\_config** (`AutoConfig` ) – Hugging Face model configuration.

**Returns:** The maximum sequence length to use.

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)

### `compute_log_probabilities()` {#max.pipelines.lib.pipeline.PipelineModel.compute_log_probabilities}

> compute\_log\_probabilities(model\_inputs, model\_outputs, next\_tokens, batch\_top\_n, batch\_echo)

Optional method that can be overridden to compute log probabilities.

**Parameters:**

* **model\_inputs** ([`ModelInputs`](#max.pipelines.lib.pipeline.ModelInputs) ) – Inputs to the model returned by prepare\_\*\_token\_inputs().
* **model\_outputs** ([`ModelOutputs`](#max.pipelines.lib.pipeline.ModelOutputs) ) – Outputs returned by execute().
* **next\_tokens** ([`Tensor`](../driver.md#max.driver.Tensor) ) – Sampled tokens. Should have shape=\[batch size].
* **batch\_top\_n** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` ) – Number of top log probabilities to return per input in the batch. For any element where top\_n == 0, the LogProbabilities output is skipped.
* **batch\_echo** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`bool`](https://docs.python.org/3/library/functions.html#bool) `]` ) – Whether to include input tokens in the returned log probabilities.

**Returns:** List of log probabilities.

**Return type:** [list](https://docs.python.org/3/library/stdtypes.html#list)\[[*LogProbabilities*](core.md#max.pipelines.core.LogProbabilities) | None] | None

### `dtype` {#max.pipelines.lib.pipeline.PipelineModel.dtype}

> property dtype: [DType](../dtype.md#max.dtype.DType)

### `estimate_weights_size()` {#max.pipelines.lib.pipeline.PipelineModel.estimate_weights_size}

> classmethod estimate\_weights\_size(pipeline\_config)

Calculates the estimated memory consumption of the model.

**Parameters:**

**pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)

### `execute()` {#max.pipelines.lib.pipeline.PipelineModel.execute}

> abstract execute(model\_inputs)

Executes the graph with the given inputs.

**Parameters:**

**model\_inputs** ([`ModelInputs`](#max.pipelines.lib.pipeline.ModelInputs) ) – The model inputs to execute, containing tensors and any other required data for model execution.

**Returns:** ModelOutputs containing the pipeline's output tensors.

**Return type:** [*ModelOutputs*](#max.pipelines.lib.pipeline.ModelOutputs)

This is an abstract method that must be implemented by concrete PipelineModels to define their specific execution logic.
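A schematic sketch of a concrete override follows. The `self.model` attribute (a compiled `max.engine` model) and the `MyInputs` class are illustrative stand-ins, not attributes the base class guarantees:

```python
from max.pipelines.lib.pipeline import ModelInputs, ModelOutputs, PipelineModel


class MyInputs(ModelInputs):
    def __init__(self, tokens, kv_cache_inputs=None):
        self.tokens = tokens
        self.kv_cache_inputs = kv_cache_inputs


class MyModel(PipelineModel):
    def execute(self, model_inputs: ModelInputs) -> ModelOutputs:
        # Narrow to the inputs class built by this model's prepare_* methods.
        assert isinstance(model_inputs, MyInputs)
        # Run the compiled graph; KV cache inputs ride along as extra tensors.
        outputs = self.model.execute(
            model_inputs.tokens, *(model_inputs.kv_cache_inputs or [])
        )
        return ModelOutputs(logits=outputs[0])
```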
### `infer_optimal_batch_size()` {#max.pipelines.lib.pipeline.PipelineModel.infer_optimal_batch_size}

> classmethod infer\_optimal\_batch\_size(pipeline\_config, available\_cache\_memory, huggingface\_config, devices, kv\_cache\_config, cache\_dtype)

Returns the estimated optimal batch size to run the model given current memory constraints.

**Parameters:**

* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )
* **available\_cache\_memory** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **huggingface\_config** (`AutoConfig` )
* **devices** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`Device`](../driver.md#max.driver.Device) `]` )
* **kv\_cache\_config** (`KVCacheConfig` )
* **cache\_dtype** ([`DType`](../dtype.md#max.dtype.DType) )

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)

### `prepare_initial_token_inputs()` {#max.pipelines.lib.pipeline.PipelineModel.prepare_initial_token_inputs}

> abstract prepare\_initial\_token\_inputs(context\_batch, kv\_cache\_inputs=None, return\_n\_logits=1)

Prepares the initial inputs to be passed to .execute().

The inputs and functionality of this method can vary per model. For example, the model inputs could include:

* Encoded tensors
* A unique ID for each tensor if this model uses a KV cache manager.
* kv\_cache\_inputs: The KV cache inputs required for the model. This should be None if the model does not use KV cache.

This function would batch the encoded tensors, claim a slot in the KV cache if the ID hasn't been seen before, and return the inputs and caches as a list of tensors.

**Parameters:**

* **context\_batch** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` `T` `]` )
* **kv\_cache\_inputs** ([`KVCacheInputs`](../nn/kv_cache/manager.md#max.nn.kv_cache.manager.KVCacheInputs) `|` `None` )
* **return\_n\_logits** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*ModelInputs*](#max.pipelines.lib.pipeline.ModelInputs)

### `prepare_next_token_inputs()` {#max.pipelines.lib.pipeline.PipelineModel.prepare_next_token_inputs}

> abstract prepare\_next\_token\_inputs(next\_tokens, prev\_model\_inputs)

Prepares the secondary inputs to be passed to .execute(). While prepare\_initial\_token\_inputs is responsible for managing the initial inputs, this function is responsible for updating the inputs for each step in a multi-step execution pattern.

**Parameters:**

* **next\_tokens** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **prev\_model\_inputs** ([`ModelInputs`](#max.pipelines.lib.pipeline.ModelInputs) )

**Return type:** [*ModelInputs*](#max.pipelines.lib.pipeline.ModelInputs)
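The following schematic sketch shows how the two methods divide the work in a multi-step loop, reusing the hypothetical `MyInputs` class from the `execute()` sketch above; the batching step is a placeholder, not the library's API:

```python
from max.pipelines.lib.pipeline import PipelineModel


class MyModel(PipelineModel):
    def prepare_initial_token_inputs(
        self, context_batch, kv_cache_inputs=None, return_n_logits=1
    ):
        # Batch and pad the encoded prompts from the contexts (model-specific).
        tokens = ...
        return MyInputs(tokens=tokens, kv_cache_inputs=kv_cache_inputs)

    def prepare_next_token_inputs(self, next_tokens, prev_model_inputs):
        # On each subsequent step, only the freshly sampled tokens are fed back.
        return MyInputs(
            tokens=next_tokens,
            kv_cache_inputs=prev_model_inputs.kv_cache_inputs,
        )
```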
## `TextGenerationPipeline` {#max.pipelines.lib.pipeline.TextGenerationPipeline}

> class max.pipelines.lib.pipeline.TextGenerationPipeline(pipeline\_config, pipeline\_model, eos\_token\_id, weight\_adapters)

Generalized token generator pipeline.

**Parameters:**

* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )
* **pipeline\_model** ([`type`](https://docs.python.org/3/library/functions.html#type) `[` [`PipelineModel`](#max.pipelines.lib.pipeline.PipelineModel) `]` )
* **eos\_token\_id** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **weight\_adapters** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`WeightsFormat`](../graph/weights.md#max.graph.weights.WeightsFormat) `,` `WeightsAdapter` `]` )

### `calculate_num_steps()` {#max.pipelines.lib.pipeline.TextGenerationPipeline.calculate_num_steps}

> calculate\_num\_steps(num\_steps, context)

**Parameters:**

* **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **context** (`T` )

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)

### `next_token()` {#max.pipelines.lib.pipeline.TextGenerationPipeline.next_token}

> next\_token(batch, num\_steps)

Given a batch, processes the batch inputs, executes the graph for num\_steps in a multi-step scenario, then decodes the tokens holistically and returns a response for each request (see the usage sketch after `prepare_batch()` below).

**Parameters:**

* **batch** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` `T` `]` )
* **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [dict](https://docs.python.org/3/library/stdtypes.html#dict)\[[str](https://docs.python.org/3/library/stdtypes.html#str), [*TextGenerationResponse*](core.md#max.pipelines.core.TextGenerationResponse)]

### `prepare_batch()` {#max.pipelines.lib.pipeline.TextGenerationPipeline.prepare_batch}

> prepare\_batch(batch, num\_steps)

**Parameters:**

* **batch** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` `T` `]` )
* **num\_steps** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*ModelInputs*](#max.pipelines.lib.pipeline.ModelInputs), [int](https://docs.python.org/3/library/functions.html#int), *Tensor* | None]
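A hypothetical serving-loop fragment for `next_token()`; here `pipeline` is assumed to be an already-constructed TextGenerationPipeline, and `batch` maps request IDs to context objects created by the pipeline's tokenizer:

```python
# Run up to 8 decode steps per call, then inspect each request's response.
responses = pipeline.next_token(batch, num_steps=8)
for request_id, response in responses.items():
    print(request_id, response)
```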
### `release()` {#max.pipelines.lib.pipeline.TextGenerationPipeline.release}

> release(context)

Mark the context as complete, releasing the cache slot from the KV manager.

**Parameters:**

**context** (`T` )

**Return type:** None

### `sample_logits()` {#max.pipelines.lib.pipeline.TextGenerationPipeline.sample_logits}

> sample\_logits(logits, prev\_tokens, top\_k, max\_k, temperature, top\_p, seed, \*, logit\_offsets=None, bitmask=None, frequency\_data=None, min\_tokens\_mask=None, frequency\_penalty=None, presence\_penalty=None, repetition\_penalty=None)

**Parameters:**

* **logits** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **prev\_tokens** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **top\_k** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **max\_k** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **temperature** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **top\_p** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **seed** ([`Tensor`](../driver.md#max.driver.Tensor) )
* **logit\_offsets** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )
* **bitmask** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )
* **frequency\_data** ([`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`FrequencyData`](#max.pipelines.lib.pipeline.FrequencyData) `]` `|` `None` )
* **min\_tokens\_mask** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )
* **frequency\_penalty** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )
* **presence\_penalty** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )
* **repetition\_penalty** ([`Tensor`](../driver.md#max.driver.Tensor) `|` `None` )

**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[*Tensor*](../driver.md#max.driver.Tensor), [*Tensor*](../driver.md#max.driver.Tensor)]

## `basicConfig()` {#max.pipelines.lib.pipeline.basicConfig}

> max.pipelines.lib.pipeline.basicConfig(\*\*kwargs)

## `get_paged_manager()` {#max.pipelines.lib.pipeline.get_paged_manager}

> max.pipelines.lib.pipeline.get\_paged\_manager(pipeline)

**Parameters:**

**pipeline** ([`TokenGenerator`](core.md#max.pipelines.core.TokenGenerator) )

**Return type:** *PagedKVCacheManager* | None

## `upper_bounded_default()` {#max.pipelines.lib.pipeline.upper_bounded_default}

> max.pipelines.lib.pipeline.upper\_bounded\_default(upper\_bound, default)

Given an upper bound and an optional default value, returns a final value that cannot exceed the upper bound.

**Parameters:**

* **upper\_bound** ([`int`](https://docs.python.org/3/library/functions.html#int) ) – The upper bound to use.
* **default** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` ) – The default value to use, or None to use the upper bound.

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the provided default value exceeds the upper bound.

**Returns:** The final value.

**Return type:** [int](https://docs.python.org/3/library/functions.html#int)
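For example, the following calls illustrate the documented behavior:

```python
from max.pipelines.lib.pipeline import upper_bounded_default

upper_bounded_default(upper_bound=4096, default=None)  # 4096 (falls back to the bound)
upper_bounded_default(upper_bound=4096, default=2048)  # 2048
upper_bounded_default(upper_bound=4096, default=8192)  # raises ValueError
```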
---

## registry

Model registry, for tracking various model variants.

## `PipelineRegistry` {#max.pipelines.lib.registry.PipelineRegistry}

> class max.pipelines.lib.registry.PipelineRegistry(architectures)

**Parameters:**

**architectures** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`SupportedArchitecture`](#max.pipelines.lib.registry.SupportedArchitecture) `]` )

### `get_active_huggingface_config()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_huggingface_config}

> get\_active\_huggingface\_config(huggingface\_repo)

Retrieves or creates a cached HuggingFace AutoConfig for the given model configuration.

This method maintains a cache of HuggingFace configurations to avoid unnecessary reloads, each of which incurs a Hugging Face Hub API call. If a config for the given model hasn't been loaded before, it will create a new one using AutoConfig.from\_pretrained() with the model's settings.

**Parameters:**

**huggingface\_repo** ([`HuggingFaceRepo`](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo) ) – The HuggingFaceRepo containing the model.

**Returns:** The HuggingFace configuration object for the model.

**Return type:** AutoConfig

### `get_active_tokenizer()` {#max.pipelines.lib.registry.PipelineRegistry.get_active_tokenizer}

> get\_active\_tokenizer(huggingface\_repo)

Retrieves or creates a cached HuggingFace AutoTokenizer for the given model configuration.

This method maintains a cache of HuggingFace tokenizers to avoid unnecessary reloads, each of which incurs a Hugging Face Hub API call. If a tokenizer for the given model hasn't been loaded before, it will create a new one using AutoTokenizer.from\_pretrained() with the model's settings.

**Parameters:**

**huggingface\_repo** ([`HuggingFaceRepo`](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo) ) – The HuggingFaceRepo containing the model.

**Returns:** The HuggingFace tokenizer for the model.

**Return type:** PreTrainedTokenizer | PreTrainedTokenizerFast

### `register()` {#max.pipelines.lib.registry.PipelineRegistry.register}

> register(architecture, \*, allow\_override=False)

Add a new architecture to the registry (see the sketch after `retrieve_factory()` below).

**Parameters:**

* **architecture** ([`SupportedArchitecture`](#max.pipelines.lib.registry.SupportedArchitecture) )
* **allow\_override** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Return type:** None

### `reset()` {#max.pipelines.lib.registry.PipelineRegistry.reset}

> reset()

**Return type:** None

### `retrieve()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve}

> retrieve(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None)

**Parameters:**

* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )
* **task** ([`PipelineTask`](core.md#max.pipelines.core.PipelineTask) )
* **override\_architecture** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` )

**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[PipelineTokenizer](core.md#max.pipelines.core.PipelineTokenizer), PipelineTypes]

### `retrieve_architecture()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_architecture}

> retrieve\_architecture(huggingface\_repo)

**Parameters:**

**huggingface\_repo** ([`HuggingFaceRepo`](hf_utils.md#max.pipelines.lib.hf_utils.HuggingFaceRepo) )

**Return type:** [*SupportedArchitecture*](#max.pipelines.lib.registry.SupportedArchitecture) | None

### `retrieve_factory()` {#max.pipelines.lib.registry.PipelineRegistry.retrieve_factory}

> retrieve\_factory(pipeline\_config, task=PipelineTask.TEXT\_GENERATION, override\_architecture=None)

**Parameters:**

* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )
* **task** ([`PipelineTask`](core.md#max.pipelines.core.PipelineTask) )
* **override\_architecture** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` )

**Return type:** [tuple](https://docs.python.org/3/library/stdtypes.html#tuple)\[[PipelineTokenizer](core.md#max.pipelines.core.PipelineTokenizer), Callable\[\[], PipelineTypes]]
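A minimal sketch tying these methods together; `my_architecture` stands in for a SupportedArchitecture instance like the one constructed in the next section:

```python
from max.pipelines.lib.registry import PipelineRegistry

# Build a registry and register an architecture with it.
registry = PipelineRegistry(architectures=[])
registry.register(my_architecture)

# Architectures can later be resolved from a Hugging Face repo:
# arch = registry.retrieve_architecture(huggingface_repo)
```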
## `SupportedArchitecture` {#max.pipelines.lib.registry.SupportedArchitecture}

> class max.pipelines.lib.registry.SupportedArchitecture(name, example\_repo\_ids, default\_encoding, supported\_encodings, pipeline\_model, task, tokenizer, default\_weights\_format, multi\_gpu\_supported=False, rope\_type=RopeType.none, weight\_adapters=None)

Represents a model architecture configuration for MAX pipelines.

This class defines all the necessary components and settings required to support a specific model architecture within the MAX pipeline system. Each SupportedArchitecture instance encapsulates the model implementation, tokenizer, supported encodings, and other architecture-specific configuration.

New architectures should be registered into the [`PipelineRegistry`](#max.pipelines.lib.registry.PipelineRegistry) using the [`register()`](#max.pipelines.lib.registry.PipelineRegistry.register) method.

```python
my_architecture = SupportedArchitecture(
    name="MyModelForCausalLM",  # Must match your Hugging Face model class name
    example_repo_ids=[
        "your-org/your-model-name",  # Add example model repository IDs
    ],
    default_encoding=SupportedEncoding.q4_k,
    supported_encodings={
        SupportedEncoding.q4_k: [KVCacheStrategy.PAGED],
        SupportedEncoding.bfloat16: [KVCacheStrategy.PAGED],
        # Add other encodings your model supports
    },
    pipeline_model=MyModel,
    tokenizer=TextTokenizer,
    default_weights_format=WeightsFormat.safetensors,
    multi_gpu_supported=True,  # Set based on your implementation capabilities
    weight_adapters={
        WeightsFormat.safetensors: weight_adapters.convert_safetensor_state_dict,
        # Add other weight formats if needed
    },
    task=PipelineTask.TEXT_GENERATION,
)
```

**Parameters:**

* **name** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – The name of the model architecture that must match the Hugging Face model class name.
* **example\_repo\_ids** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `]` ) – A list of Hugging Face repository IDs that use this architecture for testing and validation purposes.
* **default\_encoding** (`SupportedEncoding` ) – The default quantization encoding to use when no specific encoding is requested.
* **supported\_encodings** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` `SupportedEncoding` `,` [`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`KVCacheStrategy`](../nn/kv_cache/cache_params.md#max.nn.kv_cache.cache_params.KVCacheStrategy) `]` `]` ) – A dictionary mapping supported quantization encodings to their compatible KV cache strategies.
* **pipeline\_model** ([`type`](https://docs.python.org/3/library/functions.html#type) `[` [`PipelineModel`](pipeline.md#max.pipelines.lib.pipeline.PipelineModel) `]` ) – The PipelineModel class that defines the model graph structure and execution logic.
* **task** ([`PipelineTask`](core.md#max.pipelines.core.PipelineTask) ) – The pipeline task type that this architecture supports.
* **tokenizer** (`Callable` `[` `...` `,` [`PipelineTokenizer`](core.md#max.pipelines.core.PipelineTokenizer) `]` ) – A callable that returns a PipelineTokenizer instance for preprocessing model inputs.
* **default\_weights\_format** ([`WeightsFormat`](../graph/weights.md#max.graph.weights.WeightsFormat) ) – The weights format expected by the pipeline\_model.
* **multi\_gpu\_supported** ([`bool`](https://docs.python.org/3/library/functions.html#bool) ) – Whether the architecture supports multi-GPU execution.
* **rope\_type** (`RopeType` ) – The type of RoPE (Rotary Position Embedding) used by the model.
* **weight\_adapters** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`WeightsFormat`](../graph/weights.md#max.graph.weights.WeightsFormat) `,` `WeightsAdapter` `]` `|` `None` ) – A dictionary of weight format adapters for converting checkpoints from different formats to the default format.

### `tokenizer_cls` {#max.pipelines.lib.registry.SupportedArchitecture.tokenizer_cls}

> property tokenizer\_cls: [type](https://docs.python.org/3/library/functions.html#type)\[[PipelineTokenizer](core.md#max.pipelines.core.PipelineTokenizer)]

## `get_pipeline_for_task()` {#max.pipelines.lib.registry.get_pipeline_for_task}

> max.pipelines.lib.registry.get\_pipeline\_for\_task(task, pipeline\_config)

**Parameters:**

* **task** ([`PipelineTask`](core.md#max.pipelines.core.PipelineTask) )
* **pipeline\_config** ([`PipelineConfig`](config.md#max.pipelines.lib.config.PipelineConfig) )

**Return type:** [type](https://docs.python.org/3/library/functions.html#type)\[[TextGenerationPipeline](pipeline.md#max.pipelines.lib.pipeline.TextGenerationPipeline)] | [type](https://docs.python.org/3/library/functions.html#type)\[EmbeddingsPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[SpeculativeDecodingTextGenerationPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[AudioGeneratorPipeline] | [type](https://docs.python.org/3/library/functions.html#type)\[SpeechTokenGenerationPipeline]

---

## sampling

## `rejection_sampler()` {#max.pipelines.lib.sampling.rejection_sampler}

> max.pipelines.lib.sampling.rejection\_sampler(device, \*, seed=0)

**Parameters:**

* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )
* **seed** ([`int`](https://docs.python.org/3/library/functions.html#int) )

**Return type:** [*Graph*](../graph/Graph.md#max.graph.Graph)

## `token_sampler()` {#max.pipelines.lib.sampling.token_sampler}

> max.pipelines.lib.sampling.token\_sampler(sampling\_config, device, return\_logits=False)

**Parameters:**

* **sampling\_config** (`SamplingConfig` )
* **device** ([`DeviceRef`](../graph/type.md#max.graph.type.DeviceRef) )
* **return\_logits** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Return type:** [*Graph*](../graph/Graph.md#max.graph.Graph)

---

## tokenizer

Implementations of provided tokenizers.

## `IdentityPipelineTokenizer` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer}

> class max.pipelines.lib.tokenizer.IdentityPipelineTokenizer(\*args, \*\*kwargs)

### `decode()` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.decode}

> async decode(context, encoded, \*\*kwargs)

Decodes response tokens to text.

**Parameters:**

* **context** (`TokenGeneratorContext` ) – Current generation context.
* **encoded** (`TokenizerEncoded` ) – Encoded response tokens.

**Returns:** Un-encoded response text.

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `encode()` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.encode}

> async encode(prompt, add\_special\_tokens=False)

Encodes text prompts as tokens.

**Parameters:**

* **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Un-encoded prompt text.
* **add\_special\_tokens** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the prompt exceeds the configured maximum length.
**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `eos` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.eos}

> property eos: [int](https://docs.python.org/3/library/functions.html#int)

The end of sequence token for this tokenizer.

### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.IdentityPipelineTokenizer.expects_content_wrapping}

> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)

If true, this tokenizer expects messages to have a content property. Text messages are formatted as:

```json
{
  "type": "text",
  "content": "text content"
}
```

instead of the OpenAI spec:

```json
{
  "type": "text",
  "text": "text content"
}
```

NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to:

```json
{
  "type": "image"
}
```

Their content is provided as byte arrays through the top-level property on the request object, i.e., `PipelineTokenizerRequest.images`.

## `PreTrainedPipelineTokenizer` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer}

> class max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer(delegate)

**Parameters:**

**delegate** (`Union` `[` `PreTrainedTokenizer` `,` `PreTrainedTokenizerFast` `]` )

### `apply_chat_template()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.apply_chat_template}

> apply\_chat\_template(messages)

**Parameters:**

**messages** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestMessage`](core.md#max.pipelines.core.TokenGeneratorRequestMessage) `]` )

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `decode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.decode}

> async decode(context, encoded, \*\*kwargs)

Decodes response tokens to text.

**Parameters:**

* **context** (`TokenGeneratorContext` ) – Current generation context.
* **encoded** (`TokenizerEncoded` ) – Encoded response tokens.

**Returns:** Un-encoded response text.

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `encode()` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.encode}

> async encode(prompt, add\_special\_tokens=False)

Encodes text prompts as tokens.

**Parameters:**

* **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) ) – Un-encoded prompt text.
* **add\_special\_tokens** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Raises:** [**ValueError**](https://docs.python.org/3/library/exceptions.html#ValueError) – If the prompt exceeds the configured maximum length.

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `eos` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.eos}

> property eos: [int](https://docs.python.org/3/library/functions.html#int)

The end of sequence token for this tokenizer.

### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.PreTrainedPipelineTokenizer.expects_content_wrapping}

> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)

If true, this tokenizer expects messages to have a content property. Text messages are formatted as:

```json
{
  "type": "text",
  "content": "text content"
}
```

instead of the OpenAI spec:

```json
{
  "type": "text",
  "text": "text content"
}
```

NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to:

```json
{
  "type": "image"
}
```

Their content is provided as byte arrays through the top-level property on the request object, i.e., `PipelineTokenizerRequest.images`.
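A minimal usage sketch (the model path is a placeholder): wrap an existing Hugging Face tokenizer as the delegate:

```python
from transformers import AutoTokenizer

from max.pipelines.lib.tokenizer import PreTrainedPipelineTokenizer

# Any PreTrainedTokenizer or PreTrainedTokenizerFast works as the delegate.
delegate = AutoTokenizer.from_pretrained("your-org/your-model-name")
tokenizer = PreTrainedPipelineTokenizer(delegate)
```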
## `TextAndVisionTokenizer` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer}

> class max.pipelines.lib.tokenizer.TextAndVisionTokenizer(model\_path, \*, revision=None, max\_length=None, max\_new\_tokens=None, trust\_remote\_code=False, \*\*unused\_kwargs)

Encapsulates creation of TextAndVisionContext and specific token encode/decode logic.

**Parameters:**

* **model\_path** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) )
* **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` )
* **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **max\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **trust\_remote\_code** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.apply_chat_template}

> apply\_chat\_template(messages)

**Parameters:**

**messages** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestMessage`](core.md#max.pipelines.core.TokenGeneratorRequestMessage) `]` )

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `decode()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.decode}

> async decode(context, encoded, \*\*kwargs)

Transforms a provided encoded token array back into readable text.

**Parameters:**

* **context** ([`TextAndVisionContext`](core.md#max.pipelines.core.TextAndVisionContext) )
* **encoded** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `encode()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.encode}

> async encode(prompt, add\_special\_tokens=True)

Transforms the provided prompt into a token array.

**Parameters:**

* **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **add\_special\_tokens** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `eos` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.eos}

> property eos: [int](https://docs.python.org/3/library/functions.html#int)

The end of sequence token for this tokenizer.

### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.expects_content_wrapping}

> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)

If true, this tokenizer expects messages to have a content property. Text messages are formatted as:

```json
{
  "type": "text",
  "content": "text content"
}
```

instead of the OpenAI spec:

```json
{
  "type": "text",
  "text": "text content"
}
```

NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to:

```json
{
  "type": "image"
}
```

Their content is provided as byte arrays through the top-level property on the request object, i.e., `PipelineTokenizerRequest.images`.

### `new_context()` {#max.pipelines.lib.tokenizer.TextAndVisionTokenizer.new_context}

> async new\_context(request)

Creates a new TextAndVisionContext object, using information such as the cache\_seq\_id and prompt from the TokenGeneratorRequest.

**Parameters:**

**request** ([`TokenGeneratorRequest`](core.md#max.pipelines.core.TokenGeneratorRequest) )

**Return type:** [*TextAndVisionContext*](core.md#max.pipelines.core.TextAndVisionContext)
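A hypothetical usage sketch (the model path is a placeholder): the tokenizer's `encode()` and `decode()` methods are coroutines, so drive them from an event loop or with `asyncio.run()`:

```python
import asyncio

from max.pipelines.lib.tokenizer import TextAndVisionTokenizer

tokenizer = TextAndVisionTokenizer("your-org/your-vision-model", max_length=2048)
# encode() returns a numpy ndarray of token IDs.
token_ids = asyncio.run(tokenizer.encode("Describe the image.", add_special_tokens=True))
```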
## `TextTokenizer` {#max.pipelines.lib.tokenizer.TextTokenizer}

> class max.pipelines.lib.tokenizer.TextTokenizer(model\_path, \*, revision=None, max\_length=None, max\_new\_tokens=None, trust\_remote\_code=False, enable\_llama\_whitespace\_fix=False, \*\*unused\_kwargs)

Encapsulates creation of TextContext and specific token encode/decode logic.

**Parameters:**

* **model\_path** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) )
* **revision** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` `None` )
* **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **max\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **trust\_remote\_code** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )
* **enable\_llama\_whitespace\_fix** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

### `apply_chat_template()` {#max.pipelines.lib.tokenizer.TextTokenizer.apply_chat_template}

> apply\_chat\_template(messages, tools, chat\_template\_options=None)

**Parameters:**

* **messages** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestMessage`](core.md#max.pipelines.core.TokenGeneratorRequestMessage) `]` )
* **tools** ([`list`](https://docs.python.org/3/library/stdtypes.html#list) `[` [`TokenGeneratorRequestTool`](core.md#max.pipelines.core.TokenGeneratorRequestTool) `]` `|` `None` )
* **chat\_template\_options** ([`dict`](https://docs.python.org/3/library/stdtypes.html#dict) `[` [`str`](https://docs.python.org/3/library/stdtypes.html#str) `,` [`Any`](https://docs.python.org/3/library/typing.html#typing.Any) `]` `|` `None` )

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `decode()` {#max.pipelines.lib.tokenizer.TextTokenizer.decode}

> async decode(context, encoded, \*\*kwargs)

Transforms a provided encoded token array back into readable text.

**Parameters:**

* **context** ([`TextContext`](core.md#max.pipelines.core.TextContext) )
* **encoded** ([`ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray) )

**Return type:** [str](https://docs.python.org/3/library/stdtypes.html#str)

### `encode()` {#max.pipelines.lib.tokenizer.TextTokenizer.encode}

> async encode(prompt, add\_special\_tokens=True)

Transforms the provided prompt into a token array.

**Parameters:**

* **prompt** ([`str`](https://docs.python.org/3/library/stdtypes.html#str) `|` [`Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence) `[` [`int`](https://docs.python.org/3/library/functions.html#int) `]` )
* **add\_special\_tokens** ([`bool`](https://docs.python.org/3/library/functions.html#bool) )

**Return type:** [*ndarray*](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray)

### `eos` {#max.pipelines.lib.tokenizer.TextTokenizer.eos}

> property eos: [int](https://docs.python.org/3/library/functions.html#int)

The end of sequence token for this tokenizer.
### `expects_content_wrapping` {#max.pipelines.lib.tokenizer.TextTokenizer.expects_content_wrapping}

> property expects\_content\_wrapping: [bool](https://docs.python.org/3/library/functions.html#bool)

If true, this tokenizer expects messages to have a content property. Text messages are formatted as:

```json
{
  "type": "text",
  "content": "text content"
}
```

instead of the OpenAI spec:

```json
{
  "type": "text",
  "text": "text content"
}
```

NOTE: Multimodal messages omit the content property. Both `image_urls` and `image` content parts are converted to:

```json
{
  "type": "image"
}
```

Their content is provided as byte arrays through the top-level property on the request object, i.e., `PipelineTokenizerRequest.images`.

### `new_context()` {#max.pipelines.lib.tokenizer.TextTokenizer.new_context}

> async new\_context(request)

Creates a new TextContext object, using information such as the cache\_seq\_id and prompt from the TokenGeneratorRequest.

**Parameters:**

**request** ([`TokenGeneratorRequest`](core.md#max.pipelines.core.TokenGeneratorRequest) )

**Return type:** [*TextContext*](core.md#max.pipelines.core.TextContext)

## `max_tokens_to_generate()` {#max.pipelines.lib.tokenizer.max_tokens_to_generate}

> max.pipelines.lib.tokenizer.max\_tokens\_to\_generate(prompt\_size, max\_length, max\_new\_tokens=None)

Returns the max number of new tokens to generate.

**Parameters:**

* **prompt\_size** ([`int`](https://docs.python.org/3/library/functions.html#int) )
* **max\_length** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )
* **max\_new\_tokens** ([`int`](https://docs.python.org/3/library/functions.html#int) `|` `None` )

**Return type:** [int](https://docs.python.org/3/library/functions.html#int) | None
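Illustrative calls, assuming the remaining budget is `max_length - prompt_size`, further capped by `max_new_tokens` when it is provided (this interpretation is an assumption, not documented behavior):

```python
from max.pipelines.lib.tokenizer import max_tokens_to_generate

max_tokens_to_generate(prompt_size=100, max_length=2048)                      # 1948 (assumed)
max_tokens_to_generate(prompt_size=100, max_length=2048, max_new_tokens=512)  # 512 (assumed)
```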
## `run_with_default_executor()` {#max.pipelines.lib.tokenizer.run_with_default_executor}

> async max.pipelines.lib.tokenizer.run\_with\_default\_executor(fn, \*args)

---

## torch

## `CustomOpLibrary` {#max.torch.CustomOpLibrary}

> class max.torch.CustomOpLibrary(kernel\_library)

A PyTorch interface to custom operations implemented in Mojo.

This API makes it easy to pass PyTorch data as `torch.Tensor` values to the corresponding custom op. `CustomOpLibrary` handles the compilation of the Mojo custom ops and marshalling of data between PyTorch and the executable Mojo code.

For example, consider a grayscale operation implemented in Mojo:

```mojo title="my_library/grayscale.mojo"
@register("grayscale")
struct Grayscale:
    @staticmethod
    fn execute[
        # The kind of device this is running on: "cpu" or "gpu"
        target: StaticString,
    ](
        img_out: OutputTensor[dtype = DType.uint8, rank=2],
        img_in: InputTensor[dtype = DType.uint8, rank=3],
        ctx: DeviceContextPtr,
    ) raises:
        ...
```

You can then use `CustomOpLibrary` to invoke the Mojo operation like so:

```python
import torch
from max.torch import CustomOpLibrary

op_library = CustomOpLibrary("my_library")
grayscale_op = op_library.grayscale

def grayscale(pic: torch.Tensor) -> torch.Tensor:
    result = pic.new_empty(pic.shape[:-1])
    grayscale_op(result, pic)
    return result

img = (torch.rand(64, 64, 3) * 255).to(torch.uint8)
result = grayscale(img)
```

The custom operation produced by `op_library.<op_name>` will have the same interface as the backing Mojo operation. Each `InputTensor` or `OutputTensor` argument corresponds to a [`torch.Tensor`](https://docs.pytorch.org/docs/stable/tensors.html#tensor-class-reference) value in Python. Each argument corresponding to an `OutputTensor` in the Mojo operation will be modified in-place.

**Parameters:**

**kernel\_library** (`Path` `|` [`KernelLibrary`](graph/KernelLibrary.md#max.graph.KernelLibrary) ) – The path to a `.mojo` file or a `.mojopkg` with your custom op kernels, or the corresponding library object.