Skip to content

Commit 2e3152c

Browse files
authored
Forbid running multinode tasks on non-cluster fleets (#3277)
* Forbid running multinode tasks on non-cluster fleets * Guarantee fleet cluster placement with concurrent provisioning * Fix tests * Fix fleet lock on sqlite * Simplify fleet selection code * Recommend AsyncExitStack * Use is_db_sqlite() and is_db_postgres()
1 parent f9576a6 commit 2e3152c

File tree

8 files changed

+336
-122
lines changed

8 files changed

+336
-122
lines changed

contributing/LOCKING.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,31 @@ Note that:
108108

109109
* This pattern works assuming that Postgres is using default isolation level Read Committed. By the time a transaction acquires the advisory lock, all other transactions that can take the name have committed, so their changes can be seen and a unique name is taken.
110110
* SQLite needs a commit before selecting taken names due to Snapshot Isolation as noted above.
111+
112+
**Use `AsyncExitStack`**
113+
114+
In-memory locking typically requires holding the lock for a long time (until commit).
115+
Using lock context managers for in-memory locking is often hard because the lock's lifetime is tied to a single block:
116+
117+
```python
118+
if something:
119+
# Can't do this because the lock will be released before commit. How to lock?
120+
async with get_locker(get_db().dialect_name).lock_ctx(...):
121+
# ...
122+
# ...
123+
await session.commit()
124+
```
125+
126+
Use [`contextlib.AsyncExitStack`](https://docs.python.org/3/library/contextlib.html#contextlib.AsyncExitStack):
127+
128+
```python
129+
async with AsyncExitStack() as exit_stack:
130+
if something:
131+
# The lock will be released only on stack exit, so it's ok.
132+
await exit_stack.enter_async_context(
133+
get_locker(get_db().dialect_name).lock_ctx(...)
134+
)
135+
# ...
136+
# ...
137+
await session.commit()
138+
```

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 224 additions & 105 deletions
Large diffs are not rendered by default.

src/dstack/_internal/server/db.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,23 @@ async def new_func(*args, **kwargs):
103103
return new_func
104104

105105

106+
def is_db_sqlite() -> bool:
107+
return get_db().dialect_name == "sqlite"
108+
109+
110+
def is_db_postgres() -> bool:
111+
return get_db().dialect_name == "postgresql"
112+
113+
114+
async def sqlite_commit(session: AsyncSession):
115+
"""
116+
Commit an sqlite transaction.
117+
Should be used before taking locks in active sessions to see committed changes.
118+
"""
119+
if is_db_sqlite():
120+
await session.commit()
121+
122+
106123
def _run_alembic_upgrade(connection):
107124
alembic_cfg = config.Config()
108125
alembic_cfg.set_main_option("script_location", settings.ALEMBIC_MIGRATIONS_LOCATION)

src/dstack/_internal/server/services/fleets.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from dstack._internal.core.models.users import GlobalRole
4646
from dstack._internal.core.services import validate_dstack_resource_name
4747
from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models
48-
from dstack._internal.server.db import get_db
48+
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
4949
from dstack._internal.server.models import (
5050
FleetModel,
5151
InstanceModel,
@@ -675,14 +675,13 @@ async def _create_fleet(
675675
spec: FleetSpec,
676676
) -> Fleet:
677677
lock_namespace = f"fleet_names_{project.name}"
678-
if get_db().dialect_name == "sqlite":
678+
if is_db_sqlite():
679679
# Start new transaction to see committed changes after lock
680680
await session.commit()
681-
elif get_db().dialect_name == "postgresql":
681+
elif is_db_postgres():
682682
await session.execute(
683683
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
684684
)
685-
686685
lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
687686
async with lock:
688687
if spec.configuration.name is not None:

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from dstack._internal.core.services import validate_dstack_resource_name
4040
from dstack._internal.server import settings
41-
from dstack._internal.server.db import get_db
41+
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
4242
from dstack._internal.server.models import (
4343
GatewayComputeModel,
4444
GatewayModel,
@@ -148,14 +148,13 @@ async def create_gateway(
148148
)
149149

150150
lock_namespace = f"gateway_names_{project.name}"
151-
if get_db().dialect_name == "sqlite":
151+
if is_db_sqlite():
152152
# Start new transaction to see committed changes after lock
153153
await session.commit()
154-
elif get_db().dialect_name == "postgresql":
154+
elif is_db_postgres():
155155
await session.execute(
156156
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
157157
)
158-
159158
lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
160159
async with lock:
161160
if configuration.name is None:

src/dstack/_internal/server/services/runs.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from dstack._internal.core.services import validate_dstack_resource_name
5959
from dstack._internal.core.services.diff import diff_models
6060
from dstack._internal.server import settings
61-
from dstack._internal.server.db import get_db
61+
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
6262
from dstack._internal.server.models import (
6363
FleetModel,
6464
JobModel,
@@ -510,14 +510,13 @@ async def submit_run(
510510
)
511511

512512
lock_namespace = f"run_names_{project.name}"
513-
if get_db().dialect_name == "sqlite":
513+
if is_db_sqlite():
514514
# Start new transaction to see committed changes after lock
515515
await session.commit()
516-
elif get_db().dialect_name == "postgresql":
516+
elif is_db_postgres():
517517
await session.execute(
518518
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
519519
)
520-
521520
lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
522521
async with lock:
523522
# FIXME: delete_runs commits, so Postgres lock is released too early.

src/dstack/_internal/server/services/volumes.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
VolumeStatus,
2525
)
2626
from dstack._internal.core.services import validate_dstack_resource_name
27-
from dstack._internal.server.db import get_db
27+
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
2828
from dstack._internal.server.models import (
2929
InstanceModel,
3030
ProjectModel,
@@ -215,14 +215,13 @@ async def create_volume(
215215
_validate_volume_configuration(configuration)
216216

217217
lock_namespace = f"volume_names_{project.name}"
218-
if get_db().dialect_name == "sqlite":
218+
if is_db_sqlite():
219219
# Start new transaction to see committed changes after lock
220220
await session.commit()
221-
elif get_db().dialect_name == "postgresql":
221+
elif is_db_postgres():
222222
await session.execute(
223223
select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace)))
224224
)
225-
226225
lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace)
227226
async with lock:
228227
if configuration.name is not None:

src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from dstack._internal.core.models.backends.base import BackendType
1010
from dstack._internal.core.models.common import NetworkMode
1111
from dstack._internal.core.models.configurations import TaskConfiguration
12-
from dstack._internal.core.models.fleets import FleetNodesSpec
12+
from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement
1313
from dstack._internal.core.models.health import HealthStatus
1414
from dstack._internal.core.models.instances import (
1515
InstanceAvailability,
@@ -546,6 +546,7 @@ async def test_assigns_multi_node_job_to_shared_instance(self, test_db, session:
546546
)
547547
offer = get_instance_offer_with_availability(gpu_count=8, cpu_count=64, memory_gib=128)
548548
fleet_spec = get_fleet_spec()
549+
fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
549550
fleet_spec.configuration.nodes = FleetNodesSpec(min=1, target=1, max=None)
550551
fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
551552
instance = await create_instance(
@@ -1189,6 +1190,59 @@ async def test_provisions_compute_group(self, test_db, session: AsyncSession):
11891190
res = await session.execute(select(ComputeGroupModel))
11901191
assert res.scalar() is not None
11911192

1193+
@pytest.mark.asyncio
1194+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
1195+
async def test_provisioning_master_job_respects_cluster_placement_in_non_empty_fleet(
1196+
self, test_db, session: AsyncSession
1197+
):
1198+
project = await create_project(session)
1199+
user = await create_user(session)
1200+
repo = await create_repo(session=session, project_id=project.id)
1201+
fleet_spec = get_fleet_spec()
1202+
fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER
1203+
fleet_spec.configuration.nodes = FleetNodesSpec(min=0, target=0, max=None)
1204+
fleet = await create_fleet(session=session, project=project, spec=fleet_spec)
1205+
await create_instance(
1206+
session=session,
1207+
project=project,
1208+
fleet=fleet,
1209+
status=InstanceStatus.BUSY,
1210+
backend=BackendType.AWS,
1211+
job_provisioning_data=get_job_provisioning_data(region="eu-west-1"),
1212+
)
1213+
configuration = TaskConfiguration(image="debian", nodes=2)
1214+
run_spec = get_run_spec(run_name="run", repo_id=repo.name, configuration=configuration)
1215+
run = await create_run(
1216+
session=session,
1217+
run_name="run",
1218+
project=project,
1219+
repo=repo,
1220+
user=user,
1221+
run_spec=run_spec,
1222+
)
1223+
job = await create_job(
1224+
session=session,
1225+
run=run,
1226+
fleet=fleet,
1227+
instance_assigned=True,
1228+
)
1229+
with patch("dstack._internal.server.services.backends.get_project_backends") as m:
1230+
backend_mock = Mock()
1231+
m.return_value = [backend_mock]
1232+
backend_mock.TYPE = BackendType.AWS
1233+
offer1 = get_instance_offer_with_availability(region="eu-west-2")
1234+
offer2 = get_instance_offer_with_availability(region="eu-west-1")
1235+
backend_mock.compute.return_value.get_offers.return_value = [offer1, offer2]
1236+
backend_mock.compute.return_value.run_job.return_value = get_job_provisioning_data()
1237+
await process_submitted_jobs()
1238+
m.assert_called_once()
1239+
backend_mock.compute.return_value.get_offers.assert_called_once()
1240+
backend_mock.compute.return_value.run_job.assert_called_once()
1241+
selected_offer = backend_mock.compute.return_value.run_job.call_args[0][2]
1242+
assert selected_offer.region == "eu-west-1"
1243+
await session.refresh(job)
1244+
assert job.status == JobStatus.PROVISIONING
1245+
11921246

11931247
@pytest.mark.parametrize(
11941248
["job_network_mode", "blocks", "multinode", "network_mode", "constraints_are_set"],

0 commit comments

Comments
 (0)