From d3c1d50cc116242b4498c4afde5c95202ac36af7 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 31 Oct 2025 12:46:59 +0800 Subject: [PATCH 1/6] update --- pyproject.toml | 6 +- src/parallax/models/minimax.py | 136 +--------------------------- src/parallax/server/http_server.py | 4 +- src/parallax/server/shard_loader.py | 4 +- src/parallax/sglang/model_runner.py | 4 +- tests/test_executor.py | 4 +- tests/test_model.py | 4 +- 7 files changed, 16 insertions(+), 146 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5db523c5..9112ad35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,13 +41,13 @@ parallax = "parallax.cli:main" mac = [ "torch==2.8.0", - "mlx-lm==0.28.0", - "mlx==0.29.1", + "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d", + "mlx==0.29.3", ] gpu = [ "mlx-lm==0.28.0", - "mlx[cpu]==0.29.1", + "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d", "sglang[all]==0.5.4.post1", ] diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 42912b5d..7d88cf99 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -6,139 +6,9 @@ import mlx.core as mx import mlx.nn as nn from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention -from mlx_lm.models.switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_key_value_heads: int - max_position_embeddings: int - num_experts_per_tok: int - num_local_experts: int - shared_intermediate_size: int - num_hidden_layers: int - rms_norm_eps: float - rope_theta: float - rotary_dim: int - vocab_size: int - tie_word_embeddings: bool = False - scoring_func: str = "sigmoid" - head_dim: Optional[int] = None - use_qk_norm: bool = True - - -class MLXMiniMaxAttention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.hidden_dim = hidden_size = args.hidden_size - - self.num_attention_heads = args.num_attention_heads - self.num_key_value_heads = args.num_key_value_heads - self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * head_dim, bias=False) - self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False) - self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False) - self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False) - - self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False - if self.use_qk_norm: - self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps) - self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps) - - self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - if self.use_qk_norm: - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.o_proj(output) - - -class MLXMiniMaxSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_experts_per_tok = args.num_experts_per_tok - - self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) - self.switch_mlp = SwitchGLU( - args.hidden_size, args.intermediate_size, args.num_local_experts - ) - self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) - - def __call__(self, x: mx.array) -> mx.array: - gates = self.gate(x.astype(mx.float32)) - - scores = mx.sigmoid(gates) - orig_scores = scores - scores = scores + self.e_score_correction_bias - - k = self.num_experts_per_tok - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(orig_scores, inds, axis=-1) - - scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20) - scores = scores.astype(x.dtype) - - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - return y - - -class MLXMiniMaxBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.self_attn = MLXMiniMaxAttention(args) - - self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args) - - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = x + self.self_attn(self.input_layernorm(x), mask, cache) - r = r + self.block_sparse_moe(self.post_attention_layernorm(r)) - return r +from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention +from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock +from mlx_lm.models.minimax import ModelArgs class ParallaxMiniMaxAttention(MLXMiniMaxAttention): diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index fd815688..781db486 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -30,7 +30,7 @@ import zmq.asyncio from fastapi.responses import ORJSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from pydantic import BaseModel from starlette.datastructures import State @@ -105,7 +105,7 @@ def __init__( self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} # Load tokenizer for separate detokenizers - model_path = get_model_path(model_path_str)[0] + model_path = _download(model_path_str) config = load_config(model_path) self.model_path_str = model_path_str self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index be03e00f..9ac062e0 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -10,7 +10,7 @@ import mlx.core as mx import safetensors from mlx import nn -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from parallax.server.model import ShardedModel from parallax.utils.tokenizer_utils import load_tokenizer @@ -99,7 +99,7 @@ def load( Returns: A tuple containing the loaded sharded MLX model and its configuration dictionary. """ - model_path = get_model_path(self.model_path_str)[0] + model_path = _download(self.model_path_str) config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 0692dd34..4554d80c 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -12,7 +12,7 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( get_tp_group, @@ -555,7 +555,7 @@ def initialize_sgl_model_runner( - tokenizer: tokenizer driven by mlx-lm """ apply_parallax_monkey_patch() - model_path = get_model_path(original_model_path)[0] + model_path = _download(original_model_path) config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") diff --git a/tests/test_executor.py b/tests/test_executor.py index 1dd16093..8c05abca 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,7 +4,7 @@ import pytest from mlx_lm.generate import generate -from mlx_lm.utils import get_model_path, load_model +from mlx_lm.utils import _download, load_model from parallax.server.executor import Executor from parallax.server.request import InitialRequest @@ -12,7 +12,7 @@ MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16" -model_path = get_model_path(MODEL_REPO)[0] +model_path = _download(MODEL_REPO) ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) diff --git a/tests/test_model.py b/tests/test_model.py index a62dd2db..45634d6d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,7 +7,7 @@ import mlx.core as mx import pytest from mlx_lm.models.base import create_attention_mask -from mlx_lm.utils import get_model_path, load_model +from mlx_lm.utils import _download, load_model from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader @@ -18,7 +18,7 @@ TOTAL_LAYERS = 28 -model_path = get_model_path(REPO_ID)[0] +model_path = _download(REPO_ID) ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) From f2eba32b4d64ff69e1557cd97725c5ca3e4b11f8 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 31 Oct 2025 12:50:37 +0800 Subject: [PATCH 2/6] update --- src/parallax/models/minimax.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 7d88cf99..60ee236f 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -1,11 +1,9 @@ # Copyright © 2025 Apple Inc. -from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Optional, Tuple import mlx.core as mx -import mlx.nn as nn -from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention +from mlx_lm.models.base import scaled_dot_product_attention from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock from mlx_lm.models.minimax import ModelArgs From 02159d0b9e6912a1046414df9c68193f22fe1d53 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Fri, 31 Oct 2025 13:50:38 +0800 Subject: [PATCH 3/6] update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9112ad35..a094ddd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,8 @@ mac = [ ] gpu = [ - "mlx-lm==0.28.0", "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d", + "mlx[cpu]==0.29.3", "sglang[all]==0.5.4.post1", ] From a389633bbbe60c08f03a41658438c7d5074006ee Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 3 Nov 2025 20:16:22 +0800 Subject: [PATCH 4/6] update --- src/parallax/sglang/model_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 8afc7a59..29e00cd9 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -34,7 +34,6 @@ monkey_patch_p2p_access_check, ) -from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch from parallax.utils.tokenizer_utils import load_tokenizer logger = logging.getLogger(__name__) From 7dd4a4a0b6ca6663292ed818f719918e5a995f5e Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 3 Nov 2025 21:17:29 +0800 Subject: [PATCH 5/6] update null --- tests/test_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_model.py b/tests/test_model.py index 45634d6d..ef2bbaad 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,6 +9,7 @@ from mlx_lm.models.base import create_attention_mask from mlx_lm.utils import _download, load_model + from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader from parallax.utils.tokenizer_utils import load_tokenizer From c82b084e00ac7c5fcb328bb0ce3092731d05cc79 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 3 Nov 2025 21:50:36 +0800 Subject: [PATCH 6/6] rebase --- tests/test_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index ef2bbaad..45634d6d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,6 @@ from mlx_lm.models.base import create_attention_mask from mlx_lm.utils import _download, load_model - from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader from parallax.utils.tokenizer_utils import load_tokenizer