diff --git a/pyproject.toml b/pyproject.toml
index 5db523c5..a094ddd7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,13 +41,13 @@ parallax = "parallax.cli:main"
 mac = [
     "torch==2.8.0",
-    "mlx-lm==0.28.0",
-    "mlx==0.29.1",
+    "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d",
+    "mlx==0.29.3",
 ]
 gpu = [
-    "mlx-lm==0.28.0",
-    "mlx[cpu]==0.29.1",
+    "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git@f2b0262824c2a2af82acd783decf638b10710b6d",
+    "mlx[cpu]==0.29.3",
     "sglang[all]==0.5.4.post1",
 ]
diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py
index 42912b5d..60ee236f 100644
--- a/src/parallax/models/minimax.py
+++ b/src/parallax/models/minimax.py
@@ -1,144 +1,12 @@
 # Copyright © 2025 Apple Inc.

-from dataclasses import dataclass
-from typing import Any, Optional, Tuple
+from typing import Optional, Tuple

 import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention
-from mlx_lm.models.switch_layers import SwitchGLU
-
-
-@dataclass
-class ModelArgs(BaseModelArgs):
-    model_type: str
-    hidden_size: int
-    intermediate_size: int
-    num_attention_heads: int
-    num_key_value_heads: int
-    max_position_embeddings: int
-    num_experts_per_tok: int
-    num_local_experts: int
-    shared_intermediate_size: int
-    num_hidden_layers: int
-    rms_norm_eps: float
-    rope_theta: float
-    rotary_dim: int
-    vocab_size: int
-    tie_word_embeddings: bool = False
-    scoring_func: str = "sigmoid"
-    head_dim: Optional[int] = None
-    use_qk_norm: bool = True
-
-
-class MLXMiniMaxAttention(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        self.hidden_dim = hidden_size = args.hidden_size
-
-        self.num_attention_heads = args.num_attention_heads
-        self.num_key_value_heads = args.num_key_value_heads
-        self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * head_dim, bias=False)
-        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
-        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False)
-
-        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
-        if self.use_qk_norm:
-            self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps)
-            self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps)
-
-        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        B, L, D = x.shape
-
-        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
-
-        if self.use_qk_norm:
-            queries = self.q_norm(queries)
-            keys = self.k_norm(keys)
-
-        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
-
-        if cache is not None:
-            queries = self.rope(queries, offset=cache.offset)
-            keys = self.rope(keys, offset=cache.offset)
-            keys, values = cache.update_and_fetch(keys, values)
-        else:
-            queries = self.rope(queries)
-            keys = self.rope(keys)
-
-        output = scaled_dot_product_attention(
-            queries, keys, values, cache=cache, scale=self.scale, mask=mask
-        )
-
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.o_proj(output)
-
-
-class MLXMiniMaxSparseMoeBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-        self.num_experts_per_tok = args.num_experts_per_tok
-
-        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
-        self.switch_mlp = SwitchGLU(
-            args.hidden_size, args.intermediate_size, args.num_local_experts
-        )
-        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))
-
-    def __call__(self, x: mx.array) -> mx.array:
-        gates = self.gate(x.astype(mx.float32))
-
-        scores = mx.sigmoid(gates)
-        orig_scores = scores
-        scores = scores + self.e_score_correction_bias
-
-        k = self.num_experts_per_tok
-        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
-        scores = mx.take_along_axis(orig_scores, inds, axis=-1)
-
-        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
-        scores = scores.astype(x.dtype)
-
-        y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
-        return y
-
-
-class MLXMiniMaxBlock(nn.Module):
-    def __init__(self, args: ModelArgs):
-        super().__init__()
-
-        self.self_attn = MLXMiniMaxAttention(args)
-
-        self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args)
-
-        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-
-    def __call__(
-        self,
-        x: mx.array,
-        mask: Optional[mx.array] = None,
-        cache: Optional[Any] = None,
-    ) -> mx.array:
-        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
-        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
-        return r
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention
+from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock
+from mlx_lm.models.minimax import ModelArgs


 class ParallaxMiniMaxAttention(MLXMiniMaxAttention):
diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py
index fd815688..781db486 100644
--- a/src/parallax/server/http_server.py
+++ b/src/parallax/server/http_server.py
@@ -30,7 +30,7 @@
 import zmq.asyncio
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from mlx_lm.tokenizer_utils import StreamingDetokenizer
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config
 from pydantic import BaseModel
 from starlette.datastructures import State
@@ -105,7 +105,7 @@ def __init__(
         self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True)
         self.processing_requests: Dict[str, HTTPRequestInfo] = {}
         # Load tokenizer for separate detokenizers
-        model_path = get_model_path(model_path_str)[0]
+        model_path = _download(model_path_str)
         config = load_config(model_path)
         self.model_path_str = model_path_str
         self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py
index be03e00f..9ac062e0 100644
--- a/src/parallax/server/shard_loader.py
+++ b/src/parallax/server/shard_loader.py
@@ -10,7 +10,7 @@
 import mlx.core as mx
 import safetensors
 from mlx import nn
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config

 from parallax.server.model import ShardedModel
 from parallax.utils.tokenizer_utils import load_tokenizer
@@ -99,7 +99,7 @@ def load(
         Returns:
             A tuple containing the loaded sharded MLX model and its configuration dictionary.
         """
-        model_path = get_model_path(self.model_path_str)[0]
+        model_path = _download(self.model_path_str)
         config = load_config(model_path)
         tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 2bf8a0f5..29e00cd9 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -11,7 +11,7 @@
 import sglang
 import sglang.srt.distributed.parallel_state
 import torch
-from mlx_lm.utils import get_model_path, load_config
+from mlx_lm.utils import _download, load_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import (
     get_tp_group,
@@ -34,7 +34,6 @@
     monkey_patch_p2p_access_check,
 )
-from parallax.sglang.monkey_patch import apply_parallax_sglang_monkey_patch
 from parallax.utils.tokenizer_utils import load_tokenizer

 logger = logging.getLogger(__name__)
@@ -229,8 +228,8 @@ def initialize_sgl_model_runner(
     - config: model config driven by mlx-lm
     - tokenizer: tokenizer driven by mlx-lm
     """
-    apply_parallax_sglang_monkey_patch()
-    model_path = get_model_path(original_model_path)[0]
+    apply_parallax_monkey_patch()
+    model_path = _download(original_model_path)
     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")
diff --git a/tests/test_executor.py b/tests/test_executor.py
index bbb117b4..4881448f 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -4,7 +4,7 @@
 import pytest
 from mlx_lm.generate import generate
-from mlx_lm.utils import get_model_path, load_model
+from mlx_lm.utils import _download, load_model

 from parallax.server.executor import Executor
 from parallax.server.request import InitialRequest
@@ -12,7 +12,7 @@
 MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16"

-model_path = get_model_path(MODEL_REPO)[0]
+model_path = _download(MODEL_REPO)
 ref_model, ref_config = load_model(model_path)
 ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None))
diff --git a/tests/test_model.py b/tests/test_model.py
index a62dd2db..45634d6d 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -7,7 +7,7 @@
 import mlx.core as mx
 import pytest
 from mlx_lm.models.base import create_attention_mask
-from mlx_lm.utils import get_model_path, load_model
+from mlx_lm.utils import _download, load_model

 from parallax.server.server_info import ShardedModelInfo
 from parallax.server.shard_loader import MLXModelLoader
@@ -18,7 +18,7 @@
 TOTAL_LAYERS = 28

-model_path = get_model_path(REPO_ID)[0]
+model_path = _download(REPO_ID)
 ref_model, ref_config = load_model(model_path)
 ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None))
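
For reference, a minimal sketch of the call-site pattern this patch applies across http_server.py, shard_loader.py, model_runner.py, and the tests. It assumes the pinned mlx-lm revision exposes _download(path_or_repo) returning a local model directory, replacing get_model_path(...), whose first tuple element was used previously; resolve_model is an illustrative name, not a helper that exists in the repository.

# Minimal sketch only; `resolve_model` is a hypothetical helper used for
# illustration and is not part of this patch.
from mlx_lm.utils import _download, load_config

from parallax.utils.tokenizer_utils import load_tokenizer


def resolve_model(model_path_str: str):
    # Previously: model_path = get_model_path(model_path_str)[0]
    model_path = _download(model_path_str)  # local path to the downloaded model files
    config = load_config(model_path)
    tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
    return model_path, config, tokenizer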