Quantization
MAX allows you to load and run pre-quantized models through both its Python API and CLI. This guide explains quantization concepts and how to work with quantized models in your applications.
Understanding quantization
Quantization reduces the numeric precision of model weights to decrease memory usage and increase inference speed. For example, models originally trained with float32 weights can be represented using lower-precision types like int8 or int4, reducing each scalar value from 32 bits to 8 or 4 bits.
When used properly, quantization does not significantly affect model accuracy. There are several quantization encodings that provide different levels of precision and different storage formats, each with trade-offs that may work well for some models or graph operations ("ops") but not others. Some models also work well with a mixture of quantization types, so that only certain ops perform low-precision calculations while others retain high precision.
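To get a sense of the memory savings, here's a back-of-the-envelope calculation (plain Python, not a MAX API) for the weight memory of a hypothetical 8-billion-parameter model at different precisions. Real quantized formats also store per-block scaling factors, so actual files are somewhat larger:
def weight_memory_gib(num_params: int, bits_per_weight: float) -> float:
    # Convert a parameter count and per-weight bit width into GiB of weight storage.
    return num_params * bits_per_weight / 8 / (1024 ** 3)

params = 8_000_000_000
for name, bits in [("float32", 32), ("bfloat16", 16), ("int8", 8), ("int4", 4)]:
    print(f"{name}: ~{weight_memory_gib(params, bits):.1f} GiB")
# float32: ~29.8 GiB, bfloat16: ~14.9 GiB, int8: ~7.5 GiB, int4: ~3.7 GiB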
How to load pre-quantized models with MAX
You can load pre-quantized models using two primary approaches:
- By specifying a path to a quantized weight file
- By specifying the quantization encoding format for compatible models
When you have a quantized weight file, you can load it directly using the --weight-path argument:
max serve --model-path=meta-llama/Llama-3.1-8B-Instruct \
--weight-path=bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
MAX automatically detects the quantization format from the weight file. This approach works for models with standard quantization formats like GGUF and AWQ.
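After the server starts, you can send requests to the quantized model the same way as to any other served model. The following is a minimal sketch that assumes MAX Serve's default http://localhost:8000 address and its OpenAI-compatible /v1/chat/completions endpoint; adjust the URL and model name for your setup:
import requests

# Assumes the server started by the `max serve` command above is listening on
# the default http://localhost:8000 address with an OpenAI-compatible API.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "What is the meaning of life?"}],
    },
)
print(response.json()["choices"][0]["message"]["content"])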
For models that have been quantized using specific techniques but don't use a separate weight file format, you can specify the quantization encoding directly with the --quantization-encoding flag:
max generate --model-path=hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4 \
--quantization-encoding=gptq \
--prompt "What is the meaning of life?"
The --quantization-encoding flag accepts the following values:
- float32: Full-precision 32-bit floating point.
- bfloat16: Brain floating point 16-bit format.
- q4_0: 4-bit quantization format.
- q4_k: 4-bit quantization with K-means clustering.
- q6_k: 6-bit quantization with K-means clustering.
- gptq: Specialized quantization optimized for transformer-based models.
For more information on the max CLI, see the MAX CLI documentation or the MAX Serve API reference.
Quantized layer implementation
If you're building custom models with the MAX Graph API, you can implement custom quantized layers. This is useful when:
- You're building a model from scratch using the MAX Graph API
- You need precise control over how quantization is implemented
- You're implementing specialized model architectures that require custom quantized operations
To implement a quantized layer in Python, you'll need to make a few key changes compared to a standard linear layer. Let's look at the differences.
A standard linear layer in MAX might look like this:
from max import nn
from max.dtype import DType
from max.graph import DeviceRef, Weight

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = Weight(
            name="weight",
            dtype=DType.float32,
            shape=[in_dim, out_dim],
            device=DeviceRef.CPU(),
        )
        self.bias = Weight(name="bias", dtype=DType.float32, shape=[out_dim])

    def __call__(self, x):
        # With an [in_dim, out_dim] weight, no transpose is needed for x @ weight.
        return x @ self.weight.to(x.device) + self.bias.to(x.device)
To enable support for GGUF quantization encodings like Q4_0, Q4_K, or other encodings, you need to:
- Load weights from the quantized model checkpoint as uint8 with the appropriate shape.
- Replace the standard matrix multiplication (@) with the qmatmul operation.
- Specify the quantization encoding to use.
Here's how you might implement a quantized linear layer:
from max import nn
from max.dtype import DType
from max.graph import DeviceRef, Weight, ops
from max.graph.quantization import QuantizationEncoding

class QuantizedLinear(nn.Module):
    def __init__(self, in_dim, out_dim, quantization_encoding):
        super().__init__()
        self.weight = Weight(
            name="weight",
            # The DType must be uint8.
            dtype=DType.uint8,
            # This shape must be updated to match the quantized shape.
            shape=[in_dim, out_dim],
            device=DeviceRef.CPU(),
            quantization_encoding=quantization_encoding,
        )
        self.bias = Weight(name="bias", dtype=DType.float32, shape=[out_dim])

    def __call__(self, x):
        # qmatmul dequantizes the weight as part of the matrix multiplication.
        return ops.qmatmul(
            self.weight.quantization_encoding, None, x, self.weight.to(x.device)
        ) + self.bias.to(x.device)

quantized_linear = QuantizedLinear(in_dim, out_dim, QuantizationEncoding.Q4_0)
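The uint8 shape referenced in the comment above depends on the encoding's block layout and on how the checkpoint stores the tensor. As an illustrative sketch, assuming the GGUF/GGML Q4_0 layout, where weights are packed along the last axis in blocks of 32 values and each block is stored as a 2-byte float16 scale plus 16 bytes of packed 4-bit values (18 bytes per block), you could derive the quantized shape from the original float shape like this:
# Illustrative only: assumes the GGML Q4_0 block layout (32 weights per block,
# 2-byte scale + 16 bytes of packed 4-bit values = 18 bytes per block), packed
# along the last axis of the stored weight tensor.
Q4_0_BLOCK_SIZE = 32
Q4_0_BLOCK_BYTES = 18

def q4_0_shape(float_shape: list[int]) -> list[int]:
    *leading, last = float_shape
    assert last % Q4_0_BLOCK_SIZE == 0, "last axis must be a multiple of 32"
    return [*leading, last // Q4_0_BLOCK_SIZE * Q4_0_BLOCK_BYTES]

# For example, a [4096, 4096] float weight becomes a [4096, 2304] uint8 tensor.
print(q4_0_shape([4096, 4096]))
The exact shape and packing axis depend on the checkpoint you load, so verify the result against the shapes of the loaded weights.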
The QuantizationEncoding class in max.graph.quantization defines the quantization formats supported by MAX. These encodings include:
- Q4_0: 4-bit quantization format
- Q4_K: 4-bit quantization with K-means clustering
- Q5_K: 5-bit quantization with K-means clustering
- Q6_K: 6-bit quantization with K-means clustering
- GPTQ: Specialized quantization optimized for transformer-based models
With this implementation, you can add quantized weights to your MAX models. The qmatmul operation handles the dequantization process during inference, giving you the performance benefits of quantization without having to manage the low-level details.