Python module

moe

Mixture of Experts (MoE) module.

MoE

class max.nn.moe.MoE(devices, hidden_dim, num_experts, num_experts_per_token, moe_dim, gate_cls=MoEGate, has_shared_experts=False, shared_experts_dim=0, ep_size=1, dtype=DType.bfloat16, apply_router_weight_first=False)

Implementation of Mixture of Experts (MoE).

Parameters:

  • devices (list[DeviceRef]) – List of devices to use for the MoE.
  • hidden_dim (int) – The dimension of the hidden state.
  • num_experts (int) – The number of experts.
  • num_experts_per_token (int) – The number of experts each token is routed to.
  • moe_dim (int) – The intermediate dimension of each expert.
  • gate_cls (Callable[..., MoEGate]) – The model-specific gate implementation.
  • has_shared_experts (bool) – Whether to use shared experts.
  • shared_experts_dim (int) – The dimension of the shared experts.
  • ep_size (int) – The expert parallelism size.
  • dtype (DType) – The data type of the MoE.
  • apply_router_weight_first (bool) – Whether to apply the router weight to each token's hidden state before the expert computation, rather than to the expert output afterward.
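
The following is a minimal conceptual sketch, in plain Python rather than the MAX graph API, of how the outputs of the selected experts might be combined for one token. The function `moe_combine`, the toy scalar experts, and the routing values are hypothetical, and the interpretation of `apply_router_weight_first` (scaling the expert input instead of the expert output) is an assumption, not a statement about the library's internals.

```python
def moe_combine(x, expert_indices, router_weights, experts,
                apply_router_weight_first=False):
    """Combine the selected experts' outputs for a single scalar token x.

    expert_indices and router_weights stand in for what a gate would return;
    experts is a list of callables acting as toy expert networks.
    """
    out = 0.0
    for idx, weight in zip(expert_indices, router_weights):
        if apply_router_weight_first:
            # Assumed meaning of the flag: scale the expert input.
            out += experts[idx](weight * x)
        else:
            # Default: scale the expert output.
            out += weight * experts[idx](x)
    return out


# Four toy experts (num_experts=4); the token is routed to two of them
# (num_experts_per_token=2).
experts = [lambda v: v, lambda v: 2.0 * v, lambda v: v * v, lambda v: -v]

y_last = moe_combine(3.0, [1, 2], [0.75, 0.25], experts)
y_first = moe_combine(3.0, [1, 2], [0.75, 0.25], experts,
                      apply_router_weight_first=True)

print(y_last)   # 6.75
print(y_first)  # 5.0625
```

The two results differ only because expert 2 is nonlinear; for a purely linear expert, the order in which the router weight is applied would not matter.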

down_proj

property down_proj: TensorValue

gate_up_proj

property gate_up_proj: TensorValue

shard()

shard(devices)

Create sharded views of this MoE module across multiple devices.

Parameters:

  • devices (Iterable[DeviceRef]) – Iterable of devices to place the shards on.

Returns:

List of sharded MoE instances, one for each device.

Return type:

list[MoE]
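
As an illustration of the one-shard-per-device idea, here is a plain-Python sketch that assigns contiguous blocks of experts to devices. The helper `partition_experts` and the string device names are hypothetical; the layout that `shard()` actually produces depends on the module's `sharding_strategy` and may differ from this even split.

```python
def partition_experts(num_experts, devices):
    """Assign a contiguous block of expert indices to each device.

    This is only one possible layout, assumed for illustration.
    """
    devices = list(devices)  # accept any iterable, as shard() does
    if num_experts % len(devices) != 0:
        raise ValueError("num_experts must be divisible by the device count")
    per_device = num_experts // len(devices)
    return {
        device: list(range(i * per_device, (i + 1) * per_device))
        for i, device in enumerate(devices)
    }


placement = partition_experts(8, ["gpu:0", "gpu:1"])
print(placement)  # {'gpu:0': [0, 1, 2, 3], 'gpu:1': [4, 5, 6, 7]}
```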

sharding_strategy

property sharding_strategy: ShardingStrategy | None

Get the sharding strategy for the module.

MoEGate

class max.nn.moe.MoEGate(devices, hidden_dim, num_experts, num_experts_per_token, dtype)

Gate module for MoE.

Parameters:

  • devices (list[DeviceRef]) – List of devices to use for the MoEGate.
  • hidden_dim (int) – The dimension of the hidden state.
  • num_experts (int) – The number of experts.
  • num_experts_per_token (int) – The number of experts each token is routed to.
  • dtype (DType) – The data type of the MoEGate.
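
To make the roles of `hidden_dim`, `num_experts`, and `num_experts_per_token` concrete, here is a minimal plain-Python sketch of a top-k gate. The function `toy_gate` and its weight layout are hypothetical, and normalizing with a softmax over only the selected scores is an assumption; model-specific gates passed as `gate_cls` may score and normalize differently.

```python
import math


def toy_gate(hidden, weight, num_experts_per_token):
    """Score every expert for one token and keep the top-k.

    hidden: list of hidden_dim floats.
    weight: num_experts rows, each with hidden_dim floats.
    Returns (expert_indices, router_weights).
    """
    # One logit per expert: a dot product between the token and the expert row.
    logits = [sum(w * h for w, h in zip(row, hidden)) for row in weight]

    # Indices of the k highest-scoring experts, best first.
    top = sorted(range(len(logits)), key=lambda i: logits[i],
                 reverse=True)[:num_experts_per_token]

    # Softmax over the selected logits only (an assumed normalization).
    peak = max(logits[i] for i in top)
    exps = [math.exp(logits[i] - peak) for i in top]
    total = sum(exps)
    return top, [e / total for e in exps]


# hidden_dim=2, num_experts=4, num_experts_per_token=2
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, -1.0]]
indices, weights = toy_gate([0.5, 2.0], weight, 2)
print(indices)  # [2, 1]
```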