Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
# release builds of dstack-gateway depend on a PyPI version of dstack instead
"dstack[gateway] @ git+https://github.com/dstackai/dstack.git@master",
"dstack[gateway] @ git+https://github.com/Bihan/dstack.git@add_sglang_router_support",
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.2"]

[tool.setuptools.package-data]
"dstack.gateway" = [
"resources/systemd/*",
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def create_gateway(
image_id=aws_resources.get_gateway_image_id(ec2_client),
instance_type="t3.micro",
iam_instance_profile=None,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
tags=tags,
security_group_id=security_group_id,
spot=False,
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def create_gateway(
image_reference=_get_gateway_image_ref(),
vm_size="Standard_B1ms",
instance_name=instance_name,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
ssh_pub_keys=[configuration.ssh_key_pub],
spot=False,
disk_size=30,
Expand Down
19 changes: 14 additions & 5 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -876,7 +877,7 @@ def get_run_shim_script(
]


def get_gateway_user_data(authorized_key: str) -> str:
def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
return get_cloud_config(
package_update=True,
packages=[
Expand All @@ -892,7 +893,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
"s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/",
"/etc/nginx/nginx.conf",
],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router))],
],
ssh_authorized_keys=[authorized_key],
)
Expand Down Expand Up @@ -1021,16 +1022,24 @@ def get_dstack_gateway_wheel(build: str) -> str:
r.raise_for_status()
build = r.text.strip()
logger.debug("Found the latest gateway build: %s", build)
return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
# return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm certain it's not supposed to be hard-coded, we need to get the dynamic URL back.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you are right. In production, it will not be hardcoded and the hardcoded URL will be replaced by return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl



def get_dstack_gateway_commands() -> List[str]:
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
build = get_dstack_runner_version()
wheel = get_dstack_gateway_wheel(build)
# Use router type directly as pip extra
if router:
gateway_package = f"dstack-gateway[{router.type}]"
else:
gateway_package = "dstack-gateway"
return [
"mkdir -p /home/ubuntu/dstack",
"python3 -m venv /home/ubuntu/dstack/blue",
"python3 -m venv /home/ubuntu/dstack/green",
f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}",
f"/home/ubuntu/dstack/blue/bin/pip install {wheel}",
f"/home/ubuntu/dstack/blue/bin/pip install --upgrade '{gateway_package}'",
"sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run",
]

Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,9 @@ def create_gateway(
machine_type="e2-medium",
accelerators=[],
spot=False,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
authorized_keys=[configuration.ssh_key_pub],
labels=labels,
tags=[gcp_resources.DSTACK_GATEWAY_TAG],
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.utils.tags import tags_validator


Expand Down Expand Up @@ -50,6 +51,10 @@ class GatewayConfiguration(CoreModel):
default: Annotated[bool, Field(description="Make the gateway default")] = False
backend: Annotated[BackendType, Field(description="The gateway backend")]
region: Annotated[str, Field(description="The gateway region")]
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
] = None
Expand Down Expand Up @@ -113,6 +118,7 @@ class GatewayComputeConfiguration(CoreModel):
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
tags: Optional[Dict[str, str]] = None
router: Optional[AnyRouterConfig] = None


class GatewayProvisioningData(CoreModel):
Expand Down
27 changes: 27 additions & 0 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from enum import Enum
from typing import Union

from pydantic import Field
from typing_extensions import Annotated, Literal

from dstack._internal.core.models.common import CoreModel


class RouterType(str, Enum):
SGLANG = "sglang"
VLLM = "vllm"


class SGLangRouterConfig(CoreModel):
type: Literal["sglang"] = "sglang"
policy: str = "cache_aware"


class VLLMRouterConfig(CoreModel):
type: Literal["vllm"] = "vllm"
policy: str = "cache_aware"


AnyRouterConfig = Annotated[
Union[SGLangRouterConfig, VLLMRouterConfig], Field(discriminator="type")
]
62 changes: 62 additions & 0 deletions src/dstack/_internal/proxy/gateway/model_routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Dict, List, Optional, Type

from dstack._internal.core.models.routers import AnyRouterConfig, RouterType
from dstack._internal.utils.logging import get_logger

from .base import Replica, Router, RouterContext

logger = get_logger(__name__)

"""This provides a registry of available router implementations."""

_ROUTER_CLASSES: List[Type[Router]] = []

try:
from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter

_ROUTER_CLASSES.append(SglangRouter)
logger.debug("Registered SglangRouter")
except ImportError as e:
logger.warning("SGLang router not available: %s", e)

_ROUTER_TYPE_TO_CLASS_MAP: Dict[RouterType, Type[Router]] = {}

for router_class in _ROUTER_CLASSES:
router_type_str = getattr(router_class, "TYPE", None)
if router_type_str is None:
logger.warning(f"Router class {router_class.__name__} missing TYPE attribute, skipping")
continue
router_type = RouterType(router_type_str)
_ROUTER_TYPE_TO_CLASS_MAP[router_type] = router_class

_AVAILABLE_ROUTER_TYPES = list(_ROUTER_TYPE_TO_CLASS_MAP.keys())


def get_router_class(router_type: RouterType) -> Optional[Type[Router]]:
"""Get the router class for a given router type."""
return _ROUTER_TYPE_TO_CLASS_MAP.get(router_type)


def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router:
"""Factory function to create a router instance from router configuration."""
router_type = RouterType(router.type)
router_class = get_router_class(router_type)

if router_class is None:
available_types = [rt.value for rt in _AVAILABLE_ROUTER_TYPES]
raise ValueError(
f"Router type '{router_type.value}' is not available. "
f"Available types: {available_types}"
)

# Router implementations may have different constructor signatures
# SglangRouter takes (router, context), others might differ
return router_class(router=router, context=context)


__all__ = [
"Router",
"RouterContext",
"Replica",
"get_router",
]
147 changes: 147 additions & 0 deletions src/dstack/_internal/proxy/gateway/model_routers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Literal, Optional

from pydantic import BaseModel

from dstack._internal.core.models.routers import AnyRouterConfig


class RouterContext(BaseModel):
"""Context for router initialization and configuration."""

class Config:
frozen = True

host: str = "127.0.0.1"
port: int = 3000
log_dir: Path = Path("./router_logs")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) These defaults here are no longer relevant after switching to per-service routers and dynamic port allocation.

I can suggest to remove the defaults to avoid misleading the devs. And then explicit RouterContext will also become required in Router.__init__

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved

log_level: Literal["debug", "info", "warning", "error"] = "info"


class Replica(BaseModel):
"""Represents a single replica (worker) endpoint managed by the router.
The model field identifies which model this replica serves.
In SGLang, model = model_id (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct").
"""

url: str # HTTP URL where the replica is accessible (e.g., "http://127.0.0.1:10001")
model: str # (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct")


class Router(ABC):
"""Abstract base class for router implementations (e.g., SGLang, vLLM).
A router manages the lifecycle of worker replicas and handles request routing.
Different router implementations may have different mechanisms for managing
replicas.
"""

def __init__(
self,
router: Optional[AnyRouterConfig] = None,
context: Optional[RouterContext] = None,
):
"""Initialize router with context.
Args:
router: Optional router configuration (implementation-specific)
context: Runtime context for the router (host, port, logging, etc.)
"""
self.context = context or RouterContext()

@abstractmethod
def start(self) -> None:
"""Start the router process.
Raises:
Exception: If the router fails to start.
"""
...

@abstractmethod
def stop(self) -> None:
"""Stop the router process.
Raises:
Exception: If the router fails to stop.
"""
...

@abstractmethod
def is_running(self) -> bool:
"""Check if the router is currently running and responding.
Returns:
True if the router is running and healthy, False otherwise.
"""
...

@abstractmethod
def register_replicas(
self, domain: str, num_replicas: int, model_id: Optional[str] = None
) -> List[Replica]:
"""Register replicas to a domain (allocate ports/URLs for workers).
Args:
domain: The domain name for this service.
num_replicas: The number of replicas to allocate for this domain.
model_id: Optional model identifier (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct").
Required only for routers that support IGW (Inference Gateway) mode for multi-model serving.
Returns:
List of Replica objects with allocated URLs and model_id set (if provided).
Raises:
Exception: If allocation fails.
"""
...

@abstractmethod
def unregister_replicas(self, domain: str) -> None:
"""Unregister replicas for a domain (remove model and unassign all its replicas).
Args:
domain: The domain name for this service.
Raises:
Exception: If removal fails or domain is not found.
"""
...

@abstractmethod
def add_replicas(self, replicas: List[Replica]) -> None:
"""Register replicas with the router (actual API calls to add workers).
Args:
replicas: The list of replicas to add to router.
Raises:
Exception: If adding replicas fails.
"""
...

@abstractmethod
def remove_replicas(self, replicas: List[Replica]) -> None:
"""Unregister replicas from the router (actual API calls to remove workers).
Args:
replicas: The list of replicas to remove from router.
Raises:
Exception: If removing replicas fails.
"""
...

@abstractmethod
def update_replicas(self, replicas: List[Replica]) -> None:
"""Update replicas for service, replacing the current set.
Args:
replicas: The new list of replicas for this service.
Raises:
Exception: If updating replicas fails.
"""
...
Loading