Skip to content

Commit a795f8a

Browse files
authored
Fix working_dir compatibility with pre-0.19.27 clients (#3231)
Fixes: #3225
1 parent fc5aecd commit a795f8a

File tree

8 files changed

+114
-41
lines changed

8 files changed

+114
-41
lines changed

src/dstack/_internal/server/app.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from contextlib import asynccontextmanager
77
from pathlib import Path
8-
from typing import Awaitable, Callable, List
8+
from typing import Awaitable, Callable, List, Optional
99

1010
import sentry_sdk
1111
from fastapi import FastAPI, Request, Response, status
@@ -62,6 +62,7 @@
6262
CustomORJSONResponse,
6363
check_client_server_compatibility,
6464
error_detail,
65+
get_client_version,
6566
get_server_client_error_details,
6667
)
6768
from dstack._internal.settings import DSTACK_VERSION
@@ -319,8 +320,19 @@ async def check_client_version(request: Request, call_next):
319320
or request.url.path in _NO_API_VERSION_CHECK_ROUTES
320321
):
321322
return await call_next(request)
323+
try:
324+
client_version = get_client_version(request)
325+
except ValueError as e:
326+
return CustomORJSONResponse(
327+
status_code=status.HTTP_400_BAD_REQUEST,
328+
content={"detail": [error_detail(str(e))]},
329+
)
330+
client_release: Optional[tuple[int, ...]] = None
331+
if client_version is not None:
332+
client_release = client_version.release
333+
request.state.client_release = client_release
322334
response = check_client_server_compatibility(
323-
client_version=request.headers.get("x-api-version"),
335+
client_version=client_version,
324336
server_version=DSTACK_VERSION,
325337
)
326338
if response is not None:

src/dstack/_internal/server/routers/runs.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import List, Tuple
1+
from typing import Annotated, List, Optional, Tuple, cast
22

3-
from fastapi import APIRouter, Depends
3+
from fastapi import APIRouter, Depends, Request
44
from sqlalchemy.ext.asyncio import AsyncSession
55

66
from dstack._internal.core.errors import ResourceNotExistsError
@@ -35,6 +35,11 @@
3535
)
3636

3737

38+
def use_legacy_default_working_dir(request: Request) -> bool:
39+
client_release = cast(Optional[tuple[int, ...]], request.state.client_release)
40+
return client_release is not None and client_release < (0, 19, 27)
41+
42+
3843
@root_router.post(
3944
"/list",
4045
response_model=List[Run],
@@ -103,8 +108,9 @@ async def get_run(
103108
)
104109
async def get_plan(
105110
body: GetRunPlanRequest,
106-
session: AsyncSession = Depends(get_session),
107-
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
111+
session: Annotated[AsyncSession, Depends(get_session)],
112+
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
113+
legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)],
108114
):
109115
"""
110116
Returns a run plan for the given run spec.
@@ -119,6 +125,7 @@ async def get_plan(
119125
user=user,
120126
run_spec=body.run_spec,
121127
max_offers=body.max_offers,
128+
legacy_default_working_dir=legacy_default_working_dir,
122129
)
123130
return CustomORJSONResponse(run_plan)
124131

@@ -129,8 +136,9 @@ async def get_plan(
129136
)
130137
async def apply_plan(
131138
body: ApplyRunPlanRequest,
132-
session: AsyncSession = Depends(get_session),
133-
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
139+
session: Annotated[AsyncSession, Depends(get_session)],
140+
user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())],
141+
legacy_default_working_dir: Annotated[bool, Depends(use_legacy_default_working_dir)],
134142
):
135143
"""
136144
Creates a new run or updates an existing run.
@@ -148,6 +156,7 @@ async def apply_plan(
148156
project=project,
149157
plan=body.plan,
150158
force=body.force,
159+
legacy_default_working_dir=legacy_default_working_dir,
151160
)
152161
)
153162

src/dstack/_internal/server/services/runs.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535
from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData
3636
from dstack._internal.core.models.runs import (
37+
LEGACY_REPO_DIR,
3738
ApplyRunPlanInput,
3839
Job,
3940
JobPlan,
@@ -308,6 +309,7 @@ async def get_plan(
308309
user: UserModel,
309310
run_spec: RunSpec,
310311
max_offers: Optional[int],
312+
legacy_default_working_dir: bool = False,
311313
) -> RunPlan:
312314
# Spec must be copied by parsing to calculate merged_profile
313315
effective_run_spec = RunSpec.parse_obj(run_spec.dict())
@@ -317,7 +319,11 @@ async def get_plan(
317319
spec=effective_run_spec,
318320
)
319321
effective_run_spec = RunSpec.parse_obj(effective_run_spec.dict())
320-
_validate_run_spec_and_set_defaults(user, effective_run_spec)
322+
_validate_run_spec_and_set_defaults(
323+
user=user,
324+
run_spec=effective_run_spec,
325+
legacy_default_working_dir=legacy_default_working_dir,
326+
)
321327

322328
profile = effective_run_spec.merged_profile
323329
creation_policy = profile.creation_policy
@@ -413,6 +419,7 @@ async def apply_plan(
413419
project: ProjectModel,
414420
plan: ApplyRunPlanInput,
415421
force: bool,
422+
legacy_default_working_dir: bool = False,
416423
) -> Run:
417424
run_spec = plan.run_spec
418425
run_spec = await apply_plugin_policies(
@@ -422,7 +429,9 @@ async def apply_plan(
422429
)
423430
# Spec must be copied by parsing to calculate merged_profile
424431
run_spec = RunSpec.parse_obj(run_spec.dict())
425-
_validate_run_spec_and_set_defaults(user, run_spec)
432+
_validate_run_spec_and_set_defaults(
433+
user=user, run_spec=run_spec, legacy_default_working_dir=legacy_default_working_dir
434+
)
426435
if run_spec.run_name is None:
427436
return await submit_run(
428437
session=session,
@@ -985,7 +994,9 @@ def _get_job_submission_cost(job_submission: JobSubmission) -> float:
985994
return job_submission.job_provisioning_data.price * duration_hours
986995

987996

988-
def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
997+
def _validate_run_spec_and_set_defaults(
998+
user: UserModel, run_spec: RunSpec, legacy_default_working_dir: bool = False
999+
):
9891000
# This function may set defaults for null run_spec values,
9901001
# although most defaults are resolved when building job_spec
9911002
# so that we can keep both the original user-supplied value (null in run_spec)
@@ -1040,6 +1051,8 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
10401051
run_spec.ssh_key_pub = user.ssh_public_key
10411052
else:
10421053
raise ServerClientError("ssh_key_pub must be set if the user has no ssh_public_key")
1054+
if run_spec.configuration.working_dir is None and legacy_default_working_dir:
1055+
run_spec.configuration.working_dir = LEGACY_REPO_DIR
10431056

10441057

10451058
_UPDATABLE_SPEC_FIELDS = ["configuration_path", "configuration"]

src/dstack/_internal/server/utils/routers.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import Any, Dict, List, Optional
22

33
import orjson
4+
import packaging.version
45
from fastapi import HTTPException, Request, Response, status
5-
from packaging import version
66

77
from dstack._internal.core.errors import ServerClientError, ServerClientErrorCode
88
from dstack._internal.core.models.common import CoreModel
99
from dstack._internal.utils.json_utils import get_orjson_default_options, orjson_default
10+
from dstack._internal.utils.version import parse_version
1011

1112

1213
class CustomORJSONResponse(Response):
@@ -122,8 +123,15 @@ def get_request_size(request: Request) -> int:
122123
return int(request.headers["content-length"])
123124

124125

126+
def get_client_version(request: Request) -> Optional[packaging.version.Version]:
127+
version = request.headers.get("x-api-version")
128+
if version is None:
129+
return None
130+
return parse_version(version)
131+
132+
125133
def check_client_server_compatibility(
126-
client_version: Optional[str],
134+
client_version: Optional[packaging.version.Version],
127135
server_version: Optional[str],
128136
) -> Optional[CustomORJSONResponse]:
129137
"""
@@ -132,28 +140,18 @@ def check_client_server_compatibility(
132140
"""
133141
if client_version is None or server_version is None:
134142
return None
135-
parsed_server_version = version.parse(server_version)
136-
# latest allows client to bypass compatibility check (e.g. frontend)
137-
if client_version == "latest":
143+
parsed_server_version = parse_version(server_version)
144+
if parsed_server_version is None:
138145
return None
139-
try:
140-
parsed_client_version = version.parse(client_version)
141-
except version.InvalidVersion:
142-
return CustomORJSONResponse(
143-
status_code=status.HTTP_400_BAD_REQUEST,
144-
content={
145-
"detail": get_server_client_error_details(
146-
ServerClientError("Bad API version specified")
147-
)
148-
},
149-
)
150146
# We preserve full client backward compatibility across patch releases.
151147
# Server is always partially backward-compatible (so no check).
152-
if parsed_client_version > parsed_server_version and (
153-
parsed_client_version.major > parsed_server_version.major
154-
or parsed_client_version.minor > parsed_server_version.minor
148+
if client_version > parsed_server_version and (
149+
client_version.major > parsed_server_version.major
150+
or client_version.minor > parsed_server_version.minor
155151
):
156-
return error_incompatible_versions(client_version, server_version, ask_cli_update=False)
152+
return error_incompatible_versions(
153+
str(client_version), server_version, ask_cli_update=False
154+
)
157155
return None
158156

159157

src/dstack/_internal/settings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22

33
from dstack import version
4+
from dstack._internal.utils.version import parse_version
45

56
DSTACK_VERSION = os.getenv("DSTACK_VERSION", version.__version__)
6-
if DSTACK_VERSION == "0.0.0":
7+
if parse_version(DSTACK_VERSION) is None:
78
# The build backend (hatching) requires not None for versions,
89
# but the code currently treats None as dev version.
910
# TODO: update the code to treat 0.0.0 as dev version.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import Optional
2+
3+
import packaging.version
4+
5+
6+
def parse_version(version_string: str) -> Optional[packaging.version.Version]:
7+
"""
8+
Returns a `packaging.version.Version` instance or `None` if the version is dev/latest.
9+
10+
Values parsed as the dev/latest version:
11+
* the "latest" literal
12+
* any "0.0.0" release, e.g., "0.0.0", "0.0.0a1", "0.0.0.dev0"
13+
"""
14+
if version_string == "latest":
15+
return None
16+
try:
17+
version = packaging.version.parse(version_string)
18+
except packaging.version.InvalidVersion as e:
19+
raise ValueError(f"Invalid version: {version_string}") from e
20+
if version.release == (0, 0, 0):
21+
return None
22+
return version

src/tests/_internal/server/utils/test_routers.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from typing import Optional
22

3+
import packaging.version
34
import pytest
45

56
from dstack._internal.server.utils.routers import check_client_server_compatibility
67

78

89
class TestCheckClientServerCompatibility:
9-
@pytest.mark.parametrize("client_version", ["12.12.12", None])
10-
def test_returns_none_if_server_version_is_none(self, client_version: Optional[str]):
10+
@pytest.mark.parametrize("client_version", [packaging.version.parse("12.12.12"), None])
11+
def test_returns_none_if_server_version_is_none(
12+
self, client_version: Optional[packaging.version.Version]
13+
):
1114
assert (
1215
check_client_server_compatibility(
1316
client_version=client_version,
@@ -27,12 +30,10 @@ def test_returns_none_if_server_version_is_none(self, client_version: Optional[s
2730
("1.0.5", "1.0.6"),
2831
],
2932
)
30-
def test_returns_none_if_compatible(
31-
self, client_version: Optional[str], server_version: Optional[str]
32-
):
33+
def test_returns_none_if_compatible(self, client_version: str, server_version: str):
3334
assert (
3435
check_client_server_compatibility(
35-
client_version=client_version,
36+
client_version=packaging.version.parse(client_version),
3637
server_version=server_version,
3738
)
3839
is None
@@ -46,10 +47,10 @@ def test_returns_none_if_compatible(
4647
],
4748
)
4849
def test_returns_error_if_client_version_larger(
49-
self, client_version: Optional[str], server_version: Optional[str]
50+
self, client_version: str, server_version: str
5051
):
5152
res = check_client_server_compatibility(
52-
client_version=client_version,
53+
client_version=packaging.version.parse(client_version),
5354
server_version=server_version,
5455
)
5556
assert res is not None
@@ -63,7 +64,7 @@ def test_returns_error_if_client_version_larger(
6364
)
6465
def test_returns_none_if_client_version_is_latest(self, server_version: Optional[str]):
6566
res = check_client_server_compatibility(
66-
client_version="latest",
67+
client_version=None,
6768
server_version=server_version,
6869
)
6970
assert res is None
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import packaging.version
2+
import pytest
3+
4+
from dstack._internal.utils.version import parse_version
5+
6+
7+
class TestParseVersion:
8+
@pytest.mark.parametrize("version", ["0.0.0", "0.0.0.dev0", "0.0.0alpha", "latest"])
9+
def test_latest(self, version: str):
10+
assert parse_version(version) is None
11+
12+
def test_release(self):
13+
assert parse_version("0.19.27") == packaging.version.parse("0.19.27")
14+
15+
def test_error_invalid_version(self):
16+
with pytest.raises(ValueError, match=r"Invalid version: 0\.0invalid"):
17+
parse_version("0.0invalid")

0 commit comments

Comments
 (0)