8 changes: 4 additions & 4 deletions pyproject.toml
@@ -41,13 +41,13 @@ parallax = "parallax.cli:main"

 mac = [
     "torch==2.8.0",
-    "mlx-lm==0.28.0",
-    "mlx==0.29.1",
+    "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d",
+    "mlx==0.29.3",
 ]
 
 gpu = [
-    "mlx-lm==0.28.0",
-    "mlx[cpu]==0.29.1",
+    "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d",
+    "mlx[cpu]==0.29.3",
     "sglang[all]==0.5.4.post1",
 ]

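A quick sanity check after reinstalling (a minimal sketch, assuming the `mac` or `gpu` extra was reinstalled after this change; note that a git install of mlx-lm reports its package version, not the pinned commit):

# Minimal check that the pinned dependencies resolved as expected.
import importlib.metadata as md

assert md.version("mlx").startswith("0.29.3")
print("mlx-lm:", md.version("mlx-lm"))  # package version only; the commit pin is not visible here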
142 changes: 5 additions & 137 deletions src/parallax/models/minimax.py
@@ -1,144 +1,12 @@
 # Copyright © 2025 Apple Inc.
 
-from dataclasses import dataclass
-from typing import Any, Optional, Tuple
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
-from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention
-from mlx_lm.models.switch_layers import SwitchGLU
-
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    model_type: str
-    hidden_size: int
-    intermediate_size: int
-    num_attention_heads: int
-    num_key_value_heads: int
-    max_position_embeddings: int
-    num_experts_per_tok: int
-    num_local_experts: int
-    shared_intermediate_size: int
-    num_hidden_layers: int
-    rms_norm_eps: float
-    rope_theta: float
-    rotary_dim: int
-    vocab_size: int
-    tie_word_embeddings: bool = False
-    scoring_func: str = "sigmoid"
-    head_dim: Optional[int] = None
-    use_qk_norm: bool = True
-
-
-class MLXMiniMaxAttention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        self.hidden_dim = hidden_size = args.hidden_size
-
-        self.num_attention_heads = args.num_attention_heads
-        self.num_key_value_heads = args.num_key_value_heads
-        self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False)
-
-        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
-        if self.use_qk_norm:
-            self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps)
-            self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps)
-
-        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        if self.use_qk_norm:
-            queries = self.q_norm(queries)
-            keys = self.k_norm(keys)
-
-        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
-        )
-
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.o_proj(output)
-
-
-class MLXMiniMaxSparseMoeBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.num_experts_per_tok = args.num_experts_per_tok
-
-        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
-        self.switch_mlp = SwitchGLU(
-            args.hidden_size, args.intermediate_size, args.num_local_experts
-        )
-        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))
-
-    def __call__(self, x: mx.array) -> mx.array:
-        gates = self.gate(x.astype(mx.float32))
-
-        scores = mx.sigmoid(gates)
-        orig_scores = scores
-        scores = scores + self.e_score_correction_bias
-
-        k = self.num_experts_per_tok
-        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
-        scores = mx.take_along_axis(orig_scores, inds, axis=-1)
-
-        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
-        scores = scores.astype(x.dtype)
-
-        y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
-        return y
-
-
-class MLXMiniMaxBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        self.self_attn = MLXMiniMaxAttention(args)
-
-        self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args)
-
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
-        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
-        return r
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention
+from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock
+from mlx_lm.models.minimax import ModelArgs
 
 
 class ParallaxMiniMaxAttention(MLXMiniMaxAttention):
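For reference, the routing math in the deleted MLXMiniMaxSparseMoeBlock (now provided upstream by mlx_lm.models.minimax) boils down to the following standalone sketch; the function name and shapes are illustrative only:

import mlx.core as mx

def route(gates: mx.array, bias: mx.array, k: int):
    """Top-k expert selection as in the deleted block above."""
    scores = mx.sigmoid(gates)            # sigmoid scoring, per scoring_func="sigmoid"
    biased = scores + bias                # e_score_correction_bias steers selection only
    inds = mx.argpartition(-biased, kth=k - 1, axis=-1)[..., :k]
    top = mx.take_along_axis(scores, inds, axis=-1)  # weights use the unbiased scores
    weights = top / (mx.sum(top, axis=-1, keepdims=True) + 1e-20)
    return inds, weights

# Example: 4 tokens routed across 8 experts, 2 experts per token.
inds, weights = route(mx.random.normal((4, 8)), mx.zeros((8,)), k=2)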
4 changes: 2 additions & 2 deletions src/parallax/server/http_server.py
@@ -30,7 +30,7 @@
 import zmq.asyncio
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from mlx_lm.tokenizer_utils import StreamingDetokenizer
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config
 from pydantic import BaseModel
 from starlette.datastructures import State

@@ -105,7 +105,7 @@ def __init__(
         self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True)
         self.processing_requests: Dict[str, HTTPRequestInfo] = {}
         # Load tokenizer for separate detokenizers
-        model_path = get_model_path(model_path_str)[0]
+        model_path = _download(model_path_str)
         config = load_config(model_path)
         self.model_path_str = model_path_str
         self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
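Since _download only exists in newer mlx-lm (and is a private, underscore-prefixed API), code that must run against both old and new versions could use a fallback shim like the hypothetical helper below; resolve_model_path is not part of this PR, just an illustration of the two call shapes seen in this diff:

# Hypothetical helper (not in this PR): resolve a repo id to a local model path
# on both old and new mlx-lm versions.
def resolve_model_path(repo_id: str):
    try:
        # Newer mlx-lm (the git commit pinned above): returns the local path directly.
        from mlx_lm.utils import _download
        return _download(repo_id)
    except ImportError:
        # Older releases such as 0.28.0: returns a (path, hf_repo) tuple.
        from mlx_lm.utils import get_model_path
        return get_model_path(repo_id)[0]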
4 changes: 2 additions & 2 deletions src/parallax/server/shard_loader.py
@@ -10,7 +10,7 @@
 import mlx.core as mx
 import safetensors
 from mlx import nn
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config
 
 from parallax.server.model import ShardedModel
 from parallax.utils.tokenizer_utils import load_tokenizer
@@ -99,7 +99,7 @@ def load(
         Returns:
             A tuple containing the loaded sharded MLX model and its configuration dictionary.
         """
-        model_path = get_model_path(self.model_path_str)[0]
+        model_path = _download(self.model_path_str)
         config = load_config(model_path)
         tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))

7 changes: 3 additions & 4 deletions src/parallax/sglang/model_runner.py
@@ -11,7 +11,7 @@
 import sglang
 import sglang.srt.distributed.parallel_state
 import torch
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
@@ -34,7 +34,6 @@
     monkey_patch_p2p_access_check,
 )
 
-from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch
 from parallax.utils.tokenizer_utils import load_tokenizer
 
 logger = logging.getLogger(__name__)
@@ -229,8 +228,8 @@ def initialize_sgl_model_runner(
     - config: model config driven by mlx-lm
     - tokenizer: tokenizer driven by mlx-lm
     """
-    apply_parallax_sglang_monkey_patch()
-    model_path = get_model_path(original_model_path)[0]
+    apply_parallax_monkey_patch()
+    model_path = _download(original_model_path)
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
4 changes: 2 additions & 2 deletions tests/test_executor.py
@@ -4,15 +4,15 @@

 import pytest
 from mlx_lm.generate import generate
-from mlx_lm.utils import get_model_path, load_model
+from mlx_lm.utils import _download, load_model
 
 from parallax.server.executor import Executor
 from parallax.server.request import InitialRequest
 from parallax.utils.tokenizer_utils import load_tokenizer
 
 MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16"
 
-model_path = get_model_path(MODEL_REPO)[0]
+model_path = _download(MODEL_REPO)
 ref_model, ref_config = load_model(model_path)
 ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None))

4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -7,7 +7,7 @@
 import mlx.core as mx
 import pytest
 from mlx_lm.models.base import create_attention_mask
-from mlx_lm.utils import get_model_path, load_model
+from mlx_lm.utils import _download, load_model
 
 from parallax.server.server_info import ShardedModelInfo
 from parallax.server.shard_loader import MLXModelLoader
@@ -18,7 +18,7 @@
 TOTAL_LAYERS = 28
 
 
-model_path = get_model_path(REPO_ID)[0]
+model_path = _download(REPO_ID)
 ref_model, ref_config = load_model(model_path)
 ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None))
