
Python class

Weight

class max.graph.Weight(*args, **kwargs)

Bases: TensorValue

Represents a value in a Graph that can be loaded at a later time.

Weights can be initialized outside of a Graph and are lazily-added to the parent graph when used. If there is no parent graph when a weight is used, an error will be raised.
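For example, a weight can be declared before any graph exists and is only attached to a graph once an operation consumes it. The following is a minimal sketch; the constructor keywords shown (name, dtype, shape, device) are assumptions based on the attributes documented below, since the signature above only shows *args and **kwargs.

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, Weight, ops

# Declare the weight outside of any graph. The keyword arguments here
# (name, dtype, shape, device) are assumptions for illustration.
weight = Weight(
    name="linear.weight",
    dtype=DType.float32,
    shape=[2, 2],
    device=DeviceRef.CPU(),
)

with Graph("weight_demo") as graph:
    x = ops.constant(
        np.ones((2, 2), dtype=np.float32),
        dtype=DType.float32,
        device=DeviceRef.CPU(),
    )
    # Using the weight in an op lazily adds it to the parent graph.
    y = x @ weight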

Initializes a TensorValue from a tensor-like value.

Parameters:

value – The value to wrap. Can be an MLIR tensor value, another TensorValue, a Dim, or a Shape.

align

align: int | None

device

property device: DeviceRef

Returns the device of the TensorValue.
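The following sketch mirrors the dtype and shape examples below and shows how to read a tensor's device:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("device_demo") as graph:
    # Create a constant tensor placed on the CPU
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access the device the tensor was placed on
    print(f"Device: {tensor.device}")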

dtype

property dtype: DType

Returns the tensor data type.

The following example demonstrates how to access the data type of a tensor:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("dtype_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor data type
    print(f"Data type: {tensor.dtype}")  # Output: DType.float32

original_dtype_and_shape

property original_dtype_and_shape: tuple[DType, Shape]

The original dtype and shape of this weight.

This property should be used to store the original weight’s dtype and shape if the quantization encoding forces the weight to be loaded as uint8.

quantization_encoding

quantization_encoding: QuantizationEncoding | None

shape

property shape: Shape

Returns the shape of the TensorValue.

The following example demonstrates how to access the shape of a tensor:

import numpy as np
from max.dtype import DType
from max.graph import DeviceRef, Graph, ops

# Create a 2x2 matrix
matrix = np.array([[1, 2], [3, 4]], dtype=np.float32)

# Create a Graph context to work with tensors
with Graph("shape_demo") as graph:
    # Create a constant tensor from the matrix
    tensor = ops.constant(matrix, dtype=DType.float32, device=DeviceRef.CPU())

    # Access tensor shape
    print(f"Shape: {tensor.shape}")  # Shape: [Dim(2), Dim(2)]

shard()

shard(devices)

Creates sharded views of this Weight across multiple devices.

This Weight must have sharding_strategy set before calling shard(). The returned shards are also Weight objects, but they cannot be sharded further.

Parameters:

devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded weights, one for each device.

Return type:

list[Weight]
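A possible usage sketch follows. The ShardingStrategy import path, its rowwise constructor, and the Weight constructor keywords are assumptions for illustration and are not confirmed by this page:

from max.dtype import DType
from max.graph import DeviceRef, ShardingStrategy, Weight  # ShardingStrategy import path is an assumption

devices = [DeviceRef.GPU(id=0), DeviceRef.GPU(id=1)]

# Declare a weight; the constructor keywords are assumptions for illustration.
weight = Weight(
    name="linear.weight",
    dtype=DType.float32,
    shape=[1024, 1024],
    device=DeviceRef.GPU(id=0),
)

# A sharding strategy must be set before shard() is called.
# ShardingStrategy.rowwise(num_devices) is an assumed constructor.
weight.sharding_strategy = ShardingStrategy.rowwise(len(devices))

# One sharded Weight view per device; these shards cannot be sharded further.
shards = weight.shard(devices)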

shard_idx

shard_idx: int | None

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Gets the weight sharding strategy.
