Python module
lora
AttentionWithRopeAndLoRA
class max.nn.lora.AttentionWithRopeAndLoRA(*, rope, num_attention_heads, num_key_value_heads, hidden_size, kv_params, devices=None, dtype=float32, linear_cls=<class 'max.nn.lora.linear_lora.LinearLoRA'>, stacked_qkv=False, scale=None, has_bias=False, float8_config=None, clip_qkv=None)
Initializes the LoRA-enabled attention layer.
Parameters:
- rope (RotaryEmbedding) – The rope layer to borrow the freqs_cis value from.
- num_attention_heads (int) – The number of attention heads.
- num_key_value_heads (int) – Number of key/value heads.
- hidden_size (int) – The dimension of the hidden states.
- kv_params (KVCacheParams) – KV cache parameters, including the number of KV heads, the head dimension, and the data type.
- dtype (DType) – DType of the QKV and output projection weights.
- devices (list[DeviceRef] | None) – Devices on which to place the weights and run the computation. If multiple are provided, the first device is used. Use DistributedAttentionWithRope to use all devices during attention computation.
- linear_cls (Callable[..., Linear]) – Linear class to use for the output dense layer.
- stacked_qkv (bool) – Whether the QKV weights are stacked into a single tensor.
- scale (float | None) – Value used to scale the results of the attention output.
- has_bias (bool) – Whether to use an attention bias.
- clip_qkv (float | None) – If provided, the QKV weights are clamped to the range [-clip_qkv, clip_qkv].
- float8_config (Float8Config | None)
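Example (illustrative construction sketch; the rope and kv_params values are assumed to be built elsewhere, and the head counts and hidden size are placeholders rather than values from this API):
from max.dtype import DType
from max.graph import DeviceRef
from max.nn.lora import AttentionWithRopeAndLoRA

rope = ...        # a RotaryEmbedding configured for the model (assumed)
kv_params = ...   # the model's KVCacheParams (assumed)

attn = AttentionWithRopeAndLoRA(
    rope=rope,
    num_attention_heads=32,
    num_key_value_heads=8,
    hidden_size=4096,
    kv_params=kv_params,
    devices=[DeviceRef.GPU()],
    dtype=DType.float32,
)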
fused_qkv_lora()
fused_qkv_lora(x, kv_collection, input_row_offsets, layer_idx)
Computes fused query, key, and value LoRAs with ragged input.
Parameters:
- x (TensorValue) – The input tensor of shape [total_tokens, hidden_dim].
- kv_collection (PagedKVCacheCollection) – The key/value cache collection structure.
- input_row_offsets (TensorValue) – 1D tensor indicating the start index of each sequence in x.
- layer_idx (TensorValue) – Index of the current transformer layer (used for caching).
Returns:
The query projections.
Return type:
TensorValue
Raises:
ValueError – If set_lora_batch_info() has not been called on the LoRAs.
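Example (illustrative; attn is an AttentionWithRopeAndLoRA instance, and x, kv_collection, input_row_offsets, layer_idx, lora_ids, lora_ranks, and lora_grouped_offsets are assumed to already exist as graph values):
# Batch info must be registered on the underlying LoRAs first; otherwise
# fused_qkv_lora raises ValueError.
for lora in attn.qkv_loras:
    lora.set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets)

xq = attn.fused_qkv_lora(
    x=x,
    kv_collection=kv_collection,
    input_row_offsets=input_row_offsets,
    layer_idx=layer_idx,
)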
qkv_loras
property qkv_loras: list[LinearLoRA]
The list of 3 LinearLoRA modules for the Q, K, and V projections.
rope
rope: RotaryEmbedding
LinearLoRA
class max.nn.lora.LinearLoRA(in_dim, out_dim, max_num_loras, max_lora_rank, dtype, device, has_bias=False, has_lora_bias=False, name=None, quantization_encoding=None, float8_config=None)
Applies a linear transformation and LoRA to the input: y = xW^T + b + x A^T B^T, where W and b are the base projection weights and A and B are the low-rank LoRA matrices.
Example:
from max.dtype import DType
from max.graph import DeviceRef, TensorValue
from max.nn.lora import LinearLoRA

linear_layer = LinearLoRA(
    in_dim=256,
    out_dim=128,
    max_lora_rank=16,
    max_num_loras=100,
    dtype=DType.float32,
    device=DeviceRef.GPU(),
    has_bias=True,
    has_lora_bias=True,
    name="lora_linear",
)

# Per-batch LoRA metadata, built elsewhere in the graph.
lora_ids: TensorValue           # shape: [max_num_loras,]
lora_ranks: TensorValue         # shape: [max_num_loras,]
input_row_offsets: TensorValue

linear_layer.set_lora_batch_info(lora_ids, lora_ranks, input_row_offsets)

input_tensor: TensorValue
output = linear_layer(input_tensor)
Parameters:
- in_dim (int)
- out_dim (int)
- max_num_loras (int)
- max_lora_rank (int)
- dtype (DType)
- device (DeviceRef)
- has_bias (bool)
- has_lora_bias (bool)
- name (str | None)
- quantization_encoding (QuantizationEncoding | None)
- float8_config (Float8Config | None)
apply_lora()
apply_lora(x)
Parameters:
x (TensorValue)
Return type:
TensorValue
set_lora_batch_info()
set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets)
Parameters:
- lora_ids (TensorValue)
- lora_ranks (TensorValue)
- lora_grouped_offsets (TensorValue)
Return type:
None
SupportsLoRA
class max.nn.lora.SupportsLoRA(*args, **kwargs)
Base class for supporting LoRA functionality in Modules.
apply_lora()
apply_lora(x)
Parameters:
x (TensorValue)
Return type:
TensorValue
set_lora_batch_info()
set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets)
Parameters:
- lora_ids (TensorValue)
- lora_ranks (TensorValue)
- lora_grouped_offsets (TensorValue)
Return type:
None
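Example (illustrative; shows how a caller might use the SupportsLoRA interface generically. The TensorValue return annotation on apply_lora and the helper function itself are assumptions for illustration, and the tensors are assumed to exist as graph values):
from max.graph import TensorValue
from max.nn.lora import SupportsLoRA

def run_with_lora(
    layer: SupportsLoRA,
    x: TensorValue,
    lora_ids: TensorValue,
    lora_ranks: TensorValue,
    lora_grouped_offsets: TensorValue,
) -> TensorValue:
    # Any layer implementing SupportsLoRA exposes the same two methods:
    # register per-batch LoRA metadata, then apply the LoRA transformation.
    layer.set_lora_batch_info(lora_ids, lora_ranks, lora_grouped_offsets)
    return layer.apply_lora(x)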