
Quantize your graph weights

Quantization is an optimization technique that reduces the numeric precision of weights in a model. For example, models are usually trained with float32 weights, but you can quantize the values to a lower-precision type such as int8 or int4. That is, instead of storing each scalar value with 32 bits, you can use just 8 or 4 bits. This reduces the computational and memory demands during inference, which makes the model faster and compatible with more systems.

To support quantization with MAX Graph, we’ve built an API for graph engineers who want to quantize specific weights in a model; it does not quantize an entire model for you. Like the rest of the MAX Graph API, it’s a low-level API meant for engineers who want to build high-performance graphs in a systems programming language, specifically Mojo.

If you just want to read some code, check out the Quantize TinyStories pipeline, which quantizes a 15-million parameter version of Llama 2 with Q4_0 (4-bit) encoding.

Overview

When used properly, quantization does not significantly affect model accuracy. There are several quantization encodings, each providing a different level of precision and storage format, and each with trade-offs that may work well for some models or graph operations ("ops") but not others. Some models also work well with a mixture of quantization types, so that only certain ops perform low-precision calculations while others retain full precision.

To support this mixed-precision strategy, the quantization API in MAX Graph is declarative. That means you can quantize the weights in your model explicitly as you see fit, rather than pick one quantization format for the whole model. You can quantize different weights with different encodings, write custom ops that understand your quantizations, and even implement your own quantization encodings.

The primary API is the quantize() function (defined by the QuantizationEncoding trait). It takes a float32 tensor and returns the quantized tensor as a uint8 byte buffer: a type-erased blob of bytes that can hold any quantization encoding. You can call quantize() from one of the existing quantization encodings, such as Q4_0, Q4_K, or Q6_K (from GGML), and then add the quantized tensor as a node in your graph.
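
For example, here’s a minimal sketch of quantizing one weight tensor with the Q4_0 encoding; the shape and fill value are arbitrary placeholders, and it assumes quantize() returns the uint8 byte buffer as a Tensor[DType.uint8]:

from max.tensor import Tensor, TensorShape
from max.graph.quantization import Q4_0Encoding

def quantize_weights() -> Tensor[DType.uint8]:
    # Full-precision weights (placeholder shape and fill value).
    weights = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    # Pack the values into the Q4_0 format: 4-bit values plus scale
    # factors, returned as a type-erased uint8 byte buffer.
    return Q4_0Encoding.quantize(weights)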

Because the quantized data is just a blob of bytes with a special encoding for the values and scaling factors, any op that you pass this data into must know how to dequantize that data in order to perform its calculations with the full-precision float32 values. For example, the Q4_0 encoding stores weights in blocks of 4-bit values with a scale factor per block, so an op must unpack those values and apply the scale to recover approximate float32 values.

Currently, the only op included in MAX Graph that can operate on quantized data is qmatmul(). This takes a float32 tensor and a quantized tensor, and returns a float32 tensor. This op alone allows you to build a variety of quantized transformer models. However, using quantized weights with any other op in max.graph.ops doesn’t work as is, because they all expect float32 inputs.

Now let’s look at some simple code examples.

Quantize some weights

When you build a graph with MAX Graph, each batch of weights begins as a Tensor that you set in the graph as a constant (a node created with Graph.constant()). When you want to quantize those weights, just pass the Tensor to the quantize() method from the encoding type you want to use before you add it to the graph.

For example, the following code quantizes a tensor with Q4_0Encoding (4-bit encoding), performs quantized matmul (using qmatmul()), and prints the results:

from max.tensor import Tensor, TensorShape
from max.engine import InferenceSession
from max.graph import Graph, TensorType
from max.graph.quantization import Q4_0Encoding
from max.graph.ops.quantized_ops import qmatmul

def main():
    graph = Graph(TensorType(DType.float32, 32, 64))

    # Perform matmul with the full-precision constant
    # constant_value = Tensor[DType.float32](TensorShape(64, 32), 0.15)
    # constant = graph.constant(constant_value)
    # matmul = graph[0] @ constant

    # Perform matmul with the quantized constant (transposed)
    constant_value = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    quantized_value = Q4_0Encoding.quantize(constant_value)
    quantized_constant = graph.constant(quantized_value)
    matmul = qmatmul[Q4_0Encoding](graph[0], quantized_constant)

    graph.output(matmul)

    session = InferenceSession()
    model = session.load(graph)

    input = Tensor[DType.float32](TensorShape(32, 64), 0.5)
    results = model.execute("input0", input^)
    output = results.get[DType.float32]("output0")
    print(output)

You probably noticed this code also includes the full-precision matmul as an option, commented out. If you uncomment those lines, comment out the quantized version, and run it again, you can see for yourself how close the results are, even though the quantized constant uses just 1/8th of the memory (4 bits vs 32 bits).

No matter which quantization encoding you choose, the quantize() method works the same way: it takes the full-precision values as a Tensor and returns the quantized values as a Tensor (a uint8 byte buffer).
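
For instance, switching encodings is just a matter of swapping the type you call quantize() on. This sketch assumes the K-quant encodings are exposed as Q4_KEncoding and Q6_KEncoding alongside Q4_0Encoding, and uses a row size that divides evenly into each encoding's block size:

from max.tensor import Tensor, TensorShape
from max.graph.quantization import Q4_0Encoding, Q4_KEncoding, Q6_KEncoding

def compare_encodings():
    # Placeholder weights; 256 columns so each row divides evenly into
    # the block sizes used by these encodings.
    weights = Tensor[DType.float32](TensorShape(32, 256), 0.15)

    # Same call on each encoding type; each returns a type-erased
    # uint8 byte buffer in its own format.
    q4_0_weights = Q4_0Encoding.quantize(weights)
    q4_k_weights = Q4_KEncoding.quantize(weights)  # assumed encoding name
    q6_k_weights = Q6_KEncoding.quantize(weights)  # assumed encoding name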

Alternatively, you can use Graph.quantize() to combine these two lines:

    quantized_value = Q4_0Encoding.quantize(constant_value)
    quantized_constant = graph.constant(quantized_value)

Into one line:

    quantized_constant = graph.quantize[Q4_0Encoding](constant_value)

To see how we quantized a real model with this API, check out the Quantize TinyStories pipeline, which quantizes a 15-million-parameter model with 4-bit encoding down to about 10 MB.

Save and load tensors on disk

To avoid quantizing your weights every time you load a model, you can save them to disk and load them back using the save() and load() functions. For example:

from max.graph.checkpoint import load, save, TensorDict
from max.tensor import Tensor, TensorShape

def write_to_disk():
    tensors = TensorDict()
    tensors.set("x", Tensor[DType.int32](TensorShape(1, 2, 2), 1, 2, 3, 4))
    tensors.set("y", Tensor[DType.float32](TensorShape(10, 5), -1.23))
    save(tensors, "/path/to/checkpoint.max")

def read_from_disk():
    tensors = load("/path/to/checkpoint.max")
    x = tensors.get[DType.int32]("x")

The TensorDict type is just a dictionary type for named tensors.
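
Putting the two together, you might quantize your weights once, write the quantized bytes to a checkpoint, and load them back on later runs instead of re-quantizing. Here’s a rough sketch; the tensor name and file path are placeholders, and it assumes the quantized buffer can be stored and read back as an ordinary uint8 tensor:

from max.tensor import Tensor, TensorShape
from max.graph.checkpoint import load, save, TensorDict
from max.graph.quantization import Q4_0Encoding

def quantize_and_save():
    # Quantize the full-precision weights once (placeholder values)...
    weights = Tensor[DType.float32](TensorShape(32, 64), 0.15)
    quantized = Q4_0Encoding.quantize(weights)

    # ...then store the uint8 byte buffer under a placeholder name.
    tensors = TensorDict()
    tensors.set("layer0.weight", quantized)
    save(tensors, "/path/to/quantized.max")

def load_quantized() -> Tensor[DType.uint8]:
    # On later runs, read the already-quantized bytes back instead of
    # quantizing again, then add them to your graph with graph.constant().
    tensors = load("/path/to/quantized.max")
    return tensors.get[DType.uint8]("layer0.weight")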