
Python module

rotary_embedding

The rotary position embedding (RoPE) used within the model.

DeepseekYarnRopeScalingParams

class max.nn.rotary_embedding.DeepseekYarnRopeScalingParams(scaling_factor: float, original_max_position_embeddings: int, beta_fast: int, beta_slow: int, mscale: float, mscale_all_dim: float)

beta_fast

beta_fast*: int*

Fast interpolation rate.

beta_slow

beta_slow*: int*

Slow interpolation rate.

mscale

mscale*: float*

Scaling factor for middle frequencies.

mscale_all_dim

mscale_all_dim*: float*

Scaling factor applied to all dimensions.

original_max_position_embeddings

original_max_position_embeddings*: int*

Original maximum sequence length during training.

scaling_factor

scaling_factor*: float*

Scaling factor for frequency interpolation.
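To make the fields above more concrete, here is a minimal pure-Python sketch of how YaRN-style parameters are typically consumed. It follows the formulas in DeepSeek's publicly released reference code (`yarn_get_mscale`, `find_correction_dim`, `find_correction_range`). It is an assumption that MAX uses exactly these formulas internally. `YarnParams` is a hypothetical stand-in with the same fields as `DeepseekYarnRopeScalingParams`, and the numeric values are illustrative only.

```python
import math
from dataclasses import dataclass


@dataclass
class YarnParams:
    """Hypothetical stand-in mirroring DeepseekYarnRopeScalingParams' fields."""
    scaling_factor: float
    original_max_position_embeddings: int
    beta_fast: int
    beta_slow: int
    mscale: float
    mscale_all_dim: float


def yarn_get_mscale(scale: float, mscale: float) -> float:
    # Magnitude correction from the DeepSeek reference: 0.1 * mscale * ln(scale) + 1.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def find_correction_dim(num_rotations: float, dim: int, base: float, max_pos: int) -> float:
    # Dimension index whose wavelength completes `num_rotations` turns over `max_pos` tokens.
    return (dim * math.log(max_pos / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(p: YarnParams, dim: int, base: float) -> tuple[int, int]:
    # beta_fast bounds the high-frequency end of the blend, beta_slow the low-frequency end.
    orig = p.original_max_position_embeddings
    low = math.floor(find_correction_dim(p.beta_fast, dim, base, orig))
    high = math.ceil(find_correction_dim(p.beta_slow, dim, base, orig))
    return max(low, 0), min(high, dim - 1)


def cos_sin_mscale(p: YarnParams) -> float:
    # Ratio the DeepSeek reference multiplies into its cos/sin tables.
    return yarn_get_mscale(p.scaling_factor, p.mscale) / yarn_get_mscale(
        p.scaling_factor, p.mscale_all_dim
    )


params = YarnParams(
    scaling_factor=40.0,
    original_max_position_embeddings=4096,
    beta_fast=32,
    beta_slow=1,
    mscale=1.0,
    mscale_all_dim=1.0,
)
print(find_correction_range(params, dim=64, base=10000.0))  # (10, 23)
print(cos_sin_mscale(params))
```

Dimensions below the returned `low` index keep their original frequencies, dimensions above `high` are fully interpolated, and the range in between is blended linearly.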

DeepseekYarnRotaryEmbedding

class max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: DeepseekYarnRopeScalingParams | None = None)

DeepSeek’s YaRN (Yet another RoPE extensioN method) Rotary Position Embedding layer.

Unlike Llama3RotaryEmbedding, the dim argument here is the RoPE dimension of the model, not the hidden dimension.

freqs_cis_base()

freqs_cis_base() → TensorValue

Computes the frequency tensor for complex exponentials (cis) for a given seq_len. The tensor is scaled with the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).

  • Returns:

    The frequency tensor for complex exponentials with shape (max_seq_len, rope_dim // 2, 2).
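A pure-Python sketch of how a YaRN frequency table with this shape can be built: original and interpolated inverse frequencies are blended with a linear ramp, then expanded into (cos, sin) pairs per position. The blending rule follows DeepSeek's public reference code; whether MAX's `freqs_cis_base()` matches it exactly, and whether it folds the mscale ratio into the table, are assumptions. The ramp bounds `low=10`, `high=23` and all other values are illustrative, and the helper names are hypothetical.

```python
import math


def linear_ramp(low: float, high: float, n: int) -> list[float]:
    # 0.0 below `low`, 1.0 above `high`, linear (clamped) in between.
    if low == high:
        high += 0.001  # avoid division by zero, as the DeepSeek reference does
    return [min(max((i - low) / (high - low), 0.0), 1.0) for i in range(n)]


def yarn_inv_freq(rope_dim: int, base: float, scaling_factor: float,
                  low: float, high: float) -> list[float]:
    half = rope_dim // 2
    extra = [1.0 / base ** (2 * i / rope_dim) for i in range(half)]  # original frequencies
    inter = [f / scaling_factor for f in extra]  # interpolated frequencies
    ramp = linear_ramp(low, high, half)
    # ramp == 0 keeps the original frequency; ramp == 1 uses the interpolated one.
    return [e * (1.0 - r) + s * r for e, s, r in zip(extra, inter, ramp)]


def freqs_cis_table(max_seq_len: int, inv_freq: list[float],
                    mscale: float = 1.0) -> list[list[list[float]]]:
    # Shape (max_seq_len, rope_dim // 2, 2); the last axis holds (cos, sin).
    return [
        [[mscale * math.cos(p * f), mscale * math.sin(p * f)] for f in inv_freq]
        for p in range(max_seq_len)
    ]


inv_freq = yarn_inv_freq(rope_dim=64, base=10000.0, scaling_factor=40.0, low=10, high=23)
table = freqs_cis_table(max_seq_len=16, inv_freq=inv_freq)
print(len(table), len(table[0]), len(table[0][0]))  # 16 32 2
```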

scaling_params

scaling_params*: DeepseekYarnRopeScalingParams | None* = None

LinearScalingParams

class max.nn.rotary_embedding.LinearScalingParams(factor: float)

factor

factor*: float*

Main scaling factor for the frequency components of RoPE.
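A minimal sketch of linear RoPE scaling (often called position interpolation), assuming `factor` divides every inverse frequency uniformly; the source does not state how MAX applies it. Dividing the frequencies by `factor` is equivalent to dividing positions by `factor`, so a model sees position `p * factor` as it previously saw position `p`. The function names are hypothetical.

```python
import math


def rope_inv_freq(head_dim: int, theta: float) -> list[float]:
    # Standard RoPE inverse frequencies: theta ** (-2i / head_dim).
    return [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


def linear_scaled_inv_freq(head_dim: int, theta: float, factor: float) -> list[float]:
    # Linear scaling: every frequency is divided by the same factor.
    return [f / factor for f in rope_inv_freq(head_dim, theta)]


base = rope_inv_freq(64, 10000.0)
scaled = linear_scaled_inv_freq(64, 10000.0, factor=4.0)
# The rotation angle at position 40 (scaled) matches position 10 (unscaled).
print(math.isclose(40 * scaled[5], 10 * base[5]))
```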

Llama3RopeScalingParams

class max.nn.rotary_embedding.Llama3RopeScalingParams(factor: float, low_freq_factor: float, high_freq_factor: float, orig_max_position: int)

factor

factor*: float*

Main scaling factor for the frequency components of RoPE.

high_freq_factor

high_freq_factor*: float*

Factor to scale the high-frequency components of RoPE.

low_freq_factor

low_freq_factor*: float*

Factor to scale the low-frequency components of RoPE.

orig_max_position

orig_max_position*: int*

The original maximum position length supported by the model.
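The four fields above are typically combined as in Meta's publicly released Llama 3.1 reference code: high-frequency components are kept, low-frequency components are divided by `factor`, and components in between are smoothly interpolated. This sketch assumes MAX follows the same rule. `Llama3Scaling` is a hypothetical stand-in for `Llama3RopeScalingParams`, and the values are illustrative.

```python
import math
from dataclasses import dataclass


@dataclass
class Llama3Scaling:
    """Hypothetical stand-in mirroring Llama3RopeScalingParams' fields."""
    factor: float
    low_freq_factor: float
    high_freq_factor: float
    orig_max_position: int


def llama3_scale_inv_freq(inv_freq: list[float], s: Llama3Scaling) -> list[float]:
    low_freq_wavelen = s.orig_max_position / s.low_freq_factor
    high_freq_wavelen = s.orig_max_position / s.high_freq_factor
    out = []
    for freq in inv_freq:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            out.append(freq)  # high frequency: left unchanged
        elif wavelen > low_freq_wavelen:
            out.append(freq / s.factor)  # low frequency: fully scaled
        else:
            # Smoothly blend between scaled and unscaled frequencies.
            smooth = (s.orig_max_position / wavelen - s.low_freq_factor) / (
                s.high_freq_factor - s.low_freq_factor
            )
            out.append((1 - smooth) * freq / s.factor + smooth * freq)
    return out


head_dim, theta = 128, 500000.0
base = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
scaling = Llama3Scaling(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0,
                        orig_max_position=8192)
scaled = llama3_scale_inv_freq(base, scaling)
print(scaled[0], scaled[-1] / base[-1])
```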

Llama3RotaryEmbedding

class max.nn.rotary_embedding.Llama3RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: Llama3RopeScalingParams | None = None)

A RotaryEmbedding for Llama 3 that accounts for RoPE scaling.

scaling_params

scaling_params*: Llama3RopeScalingParams | None* = None

Scaling parameters that enable Llama to handle longer context lengths.

OptimizedRotaryEmbedding

class max.nn.rotary_embedding.OptimizedRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)

Optimized version of RotaryEmbedding using 2D frequency tensor representation.
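The source does not spell out the 2D layout. One plausible reading, assumed here, is that the 3D `(seq_len, head_dim // 2, 2)` table is flattened so each row stores its (cos, sin) pairs in adjacent columns. The pure-Python sketch below shows that reshape; `to_2d` is a hypothetical helper, not part of the MAX API.

```python
import math


def to_2d(freqs_cis_3d: list[list[list[float]]]) -> list[list[float]]:
    # (seq_len, head_dim // 2, 2) -> (seq_len, head_dim):
    # each row becomes [cos0, sin0, cos1, sin1, ...].
    return [[v for pair in row for v in pair] for row in freqs_cis_3d]


head_dim, theta, seq_len = 8, 10000.0, 4
inv_freq = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
table3d = [[[math.cos(t * f), math.sin(t * f)] for f in inv_freq] for t in range(seq_len)]
flat = to_2d(table3d)
print(len(flat), len(flat[0]))  # 4 8
```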

freqs_cis

property freqs_cis

RotaryEmbedding

class max.nn.rotary_embedding.RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)

RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.

device

device*: DeviceRef*

dim

dim*: int*

freqs_cis

property freqs_cis*: TensorValue*

freqs_cis_base()

freqs_cis_base() → TensorValue

Computes the frequency tensor for complex exponentials (cis) for a given seq_len. The tensor is scaled with the theta parameter and is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).

  • Returns:

    The frequency tensor for complex exponentials with shape (max_seq_len * 2, head_dim // 2, 2).

head_dim

head_dim*: int | None* = None

Defaults to dim // n_heads if not specified in the config.

interleaved

interleaved*: bool* = True
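The source does not describe `interleaved`. A common convention, assumed here, is that `interleaved=True` treats adjacent elements (x0, x1), (x2, x3), ... as the real and imaginary parts of each complex pair, while `interleaved=False` pairs element i with element i + head_dim // 2. This sketch applies the rotation as a complex multiplication under that assumption; `apply_rope` is a hypothetical helper.

```python
import math


def apply_rope(x: list[float], pos: int, theta: float, interleaved: bool = True) -> list[float]:
    d = len(x)
    half = d // 2
    inv_freq = [1.0 / theta ** (2 * i / d) for i in range(half)]
    out = list(x)
    for i, f in enumerate(inv_freq):
        # Indices of the two components forming the i-th complex number (assumed layout).
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        c, s = math.cos(pos * f), math.sin(pos * f)
        # (x_a + i*x_b) * (cos + i*sin)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out


def dot(u: list[float], v: list[float]) -> float:
    return sum(p * q for p, q in zip(u, v))


q = [0.3, -1.2, 0.5, 2.0]
k = [1.1, 0.4, -0.7, 0.9]
# The query-key score depends only on the relative offset (5 - 3 == 12 - 10).
print(dot(apply_rope(q, 5, 1e4), apply_rope(k, 3, 1e4)))
print(dot(apply_rope(q, 12, 1e4), apply_rope(k, 10, 1e4)))
```

Either layout preserves vector norms and the relative-position property; they differ only in which elements are paired, so weights must be laid out to match the chosen convention.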

max_seq_len

max_seq_len*: int*

The maximum sequence length for the model’s input.

n_heads

n_heads*: int*

theta

theta*: float*

Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
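A short sketch of what theta controls, assuming the standard RoPE frequencies theta ** (-2i / head_dim): the first pair always rotates with wavelength 2π tokens, while a larger theta stretches the wavelengths of the later (low-frequency) pairs, which is why long-context models often use a larger theta. The function name is hypothetical and the values are illustrative.

```python
import math


def wavelengths(head_dim: int, theta: float) -> list[float]:
    # Wavelength in tokens of each rotary pair: 2*pi / theta ** (-2i / head_dim).
    return [2 * math.pi * theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


base_wl = wavelengths(128, 10000.0)
large_wl = wavelengths(128, 500000.0)
print(base_wl[-1], large_wl[-1])
```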