Python module
rotary_embedding
Rotary position embedding (RoPE) layers used within the model.
DeepseekYarnRopeScalingParams
class max.nn.rotary_embedding.DeepseekYarnRopeScalingParams(scaling_factor: float, original_max_position_embeddings: int, beta_fast: int, beta_slow: int, mscale: float, mscale_all_dim: float)
beta_fast
beta_fast*: int*
Fast interpolation rate.
beta_slow
beta_slow*: int*
Slow interpolation rate.
mscale
mscale*: float*
Scaling factor for middle frequencies.
mscale_all_dim
mscale_all_dim*: float*
Scaling factor applied to all dimensions.
original_max_position_embeddings
original_max_position_embeddings*: int*
Original maximum sequence length during training.
scaling_factor
scaling_factor*: float*
Scaling factor for frequency interpolation.
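Here is one way `mscale` and `mscale_all_dim` can be consumed. The sketch reproduces two helpers from the public DeepSeek-V2 reference implementation (Hugging Face modeling code): the YaRN magnitude correction `yarn_get_mscale`, and the ratio of the `mscale` and `mscale_all_dim` corrections that the reference multiplies into its cos/sin tables. These helper names come from that reference, not from `max.nn`. Whether `DeepseekYarnRotaryEmbedding` applies the ratio at exactly this point is an assumption.

```python
import math


def yarn_get_mscale(scale: float, mscale: float = 1.0) -> float:
    """YaRN magnitude correction: 0.1 * mscale * ln(scale) + 1 when scale > 1."""
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def cos_sin_mscale(scaling_factor: float, mscale: float, mscale_all_dim: float) -> float:
    """Factor applied to the cos/sin tables in the DeepSeek-V2 reference (assumed here)."""
    return yarn_get_mscale(scaling_factor, mscale) / yarn_get_mscale(
        scaling_factor, mscale_all_dim
    )


# Illustrative values only: a 40x context extension.
print(yarn_get_mscale(40.0, 1.0))      # about 1.3689
print(cos_sin_mscale(40.0, 1.0, 1.0))  # 1.0 when mscale == mscale_all_dim
```

When `mscale` equals `mscale_all_dim`, the two corrections cancel in the cos/sin tables.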
DeepseekYarnRotaryEmbedding
class max.nn.rotary_embedding.DeepseekYarnRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: DeepseekYarnRopeScalingParams | None = None)
Deepseek’s YaRN (Yet another RoPE extensioN method) Rotary Position Embedding layer.
Unlike Llama3RotaryEmbedding, the dim argument here is the rope dimension of the model, not the hidden dimension.
freqs_cis_base()
freqs_cis_base() → TensorValue
Computes the frequency tensor for complex exponentials (cis) for every position up to max_seq_len. The frequencies are derived from the theta parameter. This tensor is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len, rope_dim // 2, 2).
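The following sketch shows how the YaRN scaling parameters could shape this table. It is modeled on the public DeepSeek-V2 reference implementation, not on MAX's source. Frequency indices below the `beta_fast` correction point keep their original value. Indices above the `beta_slow` point are divided by `scaling_factor`. A linear ramp blends the two in between. All function names are hypothetical, and the `mscale` correction is omitted.

```python
import math


def yarn_correction_dim(num_rotations, dim, base, orig_max_pos):
    # Dimension index whose wavelength completes `num_rotations` turns over orig_max_pos.
    return (dim * math.log(orig_max_pos / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )


def yarn_inv_freqs(dim, base, scaling_factor, orig_max_pos, beta_fast, beta_slow):
    """Blend original and interpolated inverse frequencies for a rope_dim of `dim`."""
    low = max(math.floor(yarn_correction_dim(beta_fast, dim, base, orig_max_pos)), 0)
    high = min(math.ceil(yarn_correction_dim(beta_slow, dim, base, orig_max_pos)), dim - 1)
    if low == high:
        high += 0.001  # avoid division by zero in the ramp
    inv_freqs = []
    for i in range(dim // 2):
        extra = 1.0 / (base ** (2 * i / dim))  # original (extrapolated) frequency
        inter = extra / scaling_factor  # interpolated frequency
        ramp = min(max((i - low) / (high - low), 0.0), 1.0)
        keep = 1.0 - ramp  # 1.0 keeps the original, 0.0 fully interpolates
        inv_freqs.append(inter * (1.0 - keep) + extra * keep)
    return inv_freqs


def cis_table(inv_freqs, max_seq_len):
    # Shape (max_seq_len, rope_dim // 2, 2): the last axis holds (cos, sin).
    return [
        [[math.cos(t * f), math.sin(t * f)] for f in inv_freqs]
        for t in range(max_seq_len)
    ]


# Illustrative parameters: rope_dim=64, 40x scaling, 4096-token original context.
freqs = yarn_inv_freqs(64, 10000.0, 40.0, 4096, beta_fast=32, beta_slow=1)
table = cis_table(freqs, max_seq_len=8)
```

The highest-frequency pairs stay unscaled, and the lowest-frequency pairs end up divided by the full `scaling_factor`.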
scaling_params
scaling_params*: DeepseekYarnRopeScalingParams | None* = None
LinearScalingParams
class max.nn.rotary_embedding.LinearScalingParams(factor: float)
factor
factor*: float*
Main scaling factor for the frequency components of the RoPE.
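As an illustration, this sketch assumes `LinearScalingParams` implements standard linear position interpolation, the scheme Hugging Face calls `"linear"` rope scaling. That mapping is an assumption. Under that scheme, every inverse frequency is divided by `factor`, which is equivalent to dividing each position index by `factor`. The helper names are hypothetical.

```python
import math


def inv_freqs(head_dim: int, theta: float) -> list[float]:
    # Standard RoPE inverse frequencies: theta ** (-2i / head_dim).
    return [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


def linear_scaled_inv_freqs(head_dim: int, theta: float, factor: float) -> list[float]:
    # Linear scaling (assumed): divide every frequency by `factor`.
    return [f / factor for f in inv_freqs(head_dim, theta)]


base = inv_freqs(8, 10000.0)
scaled = linear_scaled_inv_freqs(8, 10000.0, 2.0)
# With factor 2, position 4 produces the same rotation angles as position 2 unscaled.
print(all(math.isclose(4 * s, 2 * b) for s, b in zip(scaled, base)))  # True
```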
Llama3RopeScalingParams
class max.nn.rotary_embedding.Llama3RopeScalingParams(factor: float, low_freq_factor: float, high_freq_factor: float, orig_max_position: int)
factor
factor*: float*
Main scaling factor for the frequency components of the RoPE.
high_freq_factor
high_freq_factor*: float*
Factor to scale the high-frequency components of the RoPE.
low_freq_factor
low_freq_factor*: float*
Factor to scale the low-frequency components of the RoPE.
orig_max_position
orig_max_position*: int*
The original maximum position length supported by the model.
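These four fields match the inputs of the `apply_scaling` function in Meta's public Llama 3.1 reference code. The sketch below reproduces that algorithm in pure Python, with `orig_max_position` standing in for the reference's original context length. That `Llama3RotaryEmbedding` matches it exactly is an assumption. Components whose wavelength is shorter than `orig_max_position / high_freq_factor` are left unchanged. Components whose wavelength is longer than `orig_max_position / low_freq_factor` are divided by `factor`. Components in between are smoothly blended. The example values mirror the published Llama 3.1 configuration.

```python
import math


def llama3_scale_inv_freqs(
    inv_freqs: list[float],
    factor: float,
    low_freq_factor: float,
    high_freq_factor: float,
    orig_max_position: int,
) -> list[float]:
    low_freq_wavelen = orig_max_position / low_freq_factor
    high_freq_wavelen = orig_max_position / high_freq_factor
    scaled = []
    for freq in inv_freqs:
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency component: left unchanged.
            scaled.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency component: divided by `factor`.
            scaled.append(freq / factor)
        else:
            # Between the two thresholds: interpolate smoothly.
            smooth = (orig_max_position / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            scaled.append((1 - smooth) * freq / factor + smooth * freq)
    return scaled


head_dim, theta = 128, 500000.0
base = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
scaled = llama3_scale_inv_freqs(base, 8.0, 1.0, 4.0, 8192)
```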
Llama3RotaryEmbedding
class max.nn.rotary_embedding.Llama3RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True, scaling_params: Llama3RopeScalingParams | None = None)
RotaryEmbedding for Llama 3 that applies RoPE frequency scaling (configured through Llama3RopeScalingParams).
scaling_params
scaling_params*: Llama3RopeScalingParams | None* = None
Scaling parameters that let Llama models operate at context lengths longer than the original training length.
OptimizedRotaryEmbedding
class max.nn.rotary_embedding.OptimizedRotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)
Optimized version of RotaryEmbedding using 2D frequency tensor representation.
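One plausible reading of "2D frequency tensor representation" is a row-major reshape of the 3D `(positions, head_dim // 2, 2)` table into `(positions, head_dim)`. In that layout, each row stores `cos_0, sin_0, cos_1, sin_1, ...`. This interpretation is an assumption. The pure-Python sketch below only illustrates the reshape, and `to_2d` is a hypothetical helper.

```python
def to_2d(freqs_3d: list[list[list[float]]]) -> list[list[float]]:
    """Row-major reshape (S, D // 2, 2) -> (S, D).

    Each output row is laid out as cos_0, sin_0, cos_1, sin_1, ...
    """
    return [[value for pair in row for value in pair] for row in freqs_3d]


# A toy (2, 2, 2) table: 2 positions, 2 frequency pairs of (cos, sin).
toy = [
    [[1.0, 0.0], [1.0, 0.0]],
    [[0.54, 0.84], [0.99, 0.01]],
]
flat = to_2d(toy)  # [[1.0, 0.0, 1.0, 0.0], [0.54, 0.84, 0.99, 0.01]]
```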
freqs_cis
property freqs_cis
RotaryEmbedding
class max.nn.rotary_embedding.RotaryEmbedding(dim: int, n_heads: int, theta: float, max_seq_len: int, device: DeviceRef, head_dim: int | None = None, _freqs_cis: Value | BufferValue | TensorValue | Shape | Dim | int | float | integer | floating | ndarray | None = None, interleaved: bool = True)
RotaryEmbedding layer to calculate and apply the frequency tensor for complex exponentials.
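This minimal pure-Python sketch shows what applying the frequency tensor means for a single head vector, using the interleaved pair layout. Each pair `(x[2i], x[2i+1])` is treated as a complex number and multiplied by `exp(j * pos * theta**(-2i/d))`. The function names are hypothetical and not part of `max.nn`, which performs this computation with graph ops.

```python
import cmath


def apply_rope_interleaved(x: list[float], pos: int, theta: float) -> list[float]:
    """Rotate one head vector whose RoPE pairs are stored as (x0, x1), (x2, x3), ..."""
    d = len(x)
    out: list[float] = []
    for i in range(d // 2):
        angle = pos / theta ** (2 * i / d)  # pos * theta**(-2i/d)
        # Multiplying by the unit complex number exp(j*angle) rotates the pair.
        z = complex(x[2 * i], x[2 * i + 1]) * cmath.exp(1j * angle)
        out.extend([z.real, z.imag])
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(u * v for u, v in zip(a, b))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]
# The query/key score depends only on the offset between positions (2 in both cases).
s1 = dot(apply_rope_interleaved(q, 5, 10000.0), apply_rope_interleaved(k, 3, 10000.0))
s2 = dot(apply_rope_interleaved(q, 2, 10000.0), apply_rope_interleaved(k, 0, 10000.0))
```

Because each pair is rotated rather than rescaled, vector norms are preserved. The attention score ends up encoding relative position, which is the property RoPE is designed to provide.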
device
device*: DeviceRef*
dim
dim*: int*
freqs_cis
property freqs_cis*: TensorValue*
freqs_cis_base()
freqs_cis_base() → TensorValue
Computes the frequency tensor for complex exponentials (cis) for the supported positions. The frequencies are derived from the theta parameter. This tensor is required to apply Rotary Position Embedding (RoPE) to a tensor. See ‘RoFormer: Enhanced Transformer with Rotary Position Embedding’ (arxiv.org/pdf/2104.09864).
Returns:
The frequency tensor for complex exponentials with shape (max_seq_len * 2, head_dim // 2, 2).
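This pure-Python sketch builds a table with the documented shape, using the standard RoPE inverse frequencies `theta ** (-2i / head_dim)` and storing `(cos, sin)` on the last axis. The `max_seq_len * 2` row count follows the shape stated above. The function name is hypothetical, and the actual method returns a graph `TensorValue` rather than nested lists.

```python
import math


def rope_freqs_cis_base(max_seq_len: int, head_dim: int, theta: float) -> list:
    """Nested-list table with shape (max_seq_len * 2, head_dim // 2, 2)."""
    inv_freqs = [1.0 / theta ** (2 * i / head_dim) for i in range(head_dim // 2)]
    return [
        [[math.cos(t * f), math.sin(t * f)] for f in inv_freqs]
        for t in range(max_seq_len * 2)
    ]


table = rope_freqs_cis_base(max_seq_len=4, head_dim=8, theta=10000.0)
print(len(table), len(table[0]), len(table[0][0]))  # 8 4 2
```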
head_dim
Dimension of each attention head. Defaults to dim // n_heads if not specified.
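The defaulting rule amounts to the following arithmetic. The helper name and the sizes are illustrative only.

```python
from typing import Optional


def default_head_dim(dim: int, n_heads: int, head_dim: Optional[int] = None) -> int:
    # An explicit head_dim wins; otherwise split the model dimension evenly across heads.
    return head_dim if head_dim is not None else dim // n_heads


print(default_head_dim(4096, 32))      # 128
print(default_head_dim(4096, 32, 64))  # 64
```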
interleaved
interleaved*: bool* = True
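The page does not say what `interleaved=False` means. The sketch below assumes the common convention used by GPT-NeoX-style models, in which element `i` is paired with element `i + head_dim // 2`. With `interleaved=True`, adjacent elements form each rotated pair. The function is hypothetical and only lists which indices rotate together.

```python
def rope_pairs(head_dim: int, interleaved: bool) -> list[tuple[int, int]]:
    """Index pairs (real, imag) that are rotated together within one head."""
    half = head_dim // 2
    if interleaved:
        # Adjacent elements form a pair: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # Assumed non-interleaved layout: the first half pairs with the second half.
    return [(i, i + half) for i in range(half)]


print(rope_pairs(8, True))   # [(0, 1), (2, 3), (4, 5), (6, 7)]
print(rope_pairs(8, False))  # [(0, 4), (1, 5), (2, 6), (3, 7)]
```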
max_seq_len
max_seq_len*: int*
The maximum sequence length of the model’s input.
n_heads
n_heads*: int*
theta
theta*: float*
Hyperparameter used to control the frequency scaling of the sinusoidal components of the embeddings.
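Concretely, pair `i` rotates by `theta ** (-2i / head_dim)` radians per position, so it completes one full turn every `2 * pi * theta ** (2i / head_dim)` positions. The sketch compares these wavelengths for two illustrative theta values. The fastest pair always has wavelength `2 * pi`, while a larger theta stretches the slower pairs.

```python
import math


def wavelengths(head_dim: int, theta: float) -> list[float]:
    """Positions needed for each rotary pair to complete one turn: 2*pi*theta**(2i/d)."""
    return [2 * math.pi * theta ** (2 * i / head_dim) for i in range(head_dim // 2)]


w_small = wavelengths(128, 10000.0)
w_large = wavelengths(128, 500000.0)
print(w_small[-1] < w_large[-1])  # True
```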