Skip to content

Commit d2a8870

Browse files
authored
Support --since arg for dstack attach --logs command (#3268)
1 parent 17cc2bd commit d2a8870

File tree

7 files changed

+85
-50
lines changed

7 files changed

+85
-50
lines changed

runner/internal/runner/api/ws.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@ package api
22

33
import (
44
"context"
5+
"errors"
56
"net/http"
67
"time"
78

89
"github.com/dstackai/dstack/runner/internal/log"
910
"github.com/gorilla/websocket"
1011
)
1112

13+
type logsWsRequestParams struct {
14+
startTimestamp int64
15+
}
16+
1217
var upgrader = websocket.Upgrader{
1318
CheckOrigin: func(r *http.Request) bool {
1419
return true
@@ -20,18 +25,54 @@ func (s *Server) logsWsGetHandler(w http.ResponseWriter, r *http.Request) (inter
2025
if err != nil {
2126
return nil, err
2227
}
28+
requestParams, err := parseRequestParams(r)
29+
if err != nil {
30+
_ = conn.WriteMessage(
31+
websocket.CloseMessage,
32+
websocket.FormatCloseMessage(websocket.CloseUnsupportedData, err.Error()),
33+
)
34+
_ = conn.Close()
35+
return nil, nil
36+
}
2337
// todo memorize clientId?
24-
go s.streamJobLogs(conn)
38+
go s.streamJobLogs(r.Context(), conn, requestParams)
2539
return nil, nil
2640
}
2741

28-
func (s *Server) streamJobLogs(conn *websocket.Conn) {
29-
currentPos := 0
42+
func parseRequestParams(r *http.Request) (logsWsRequestParams, error) {
43+
query := r.URL.Query()
44+
startTimeStr := query.Get("start_time")
45+
var startTimestamp int64
46+
if startTimeStr != "" {
47+
t, err := time.Parse(time.RFC3339, startTimeStr)
48+
if err != nil {
49+
return logsWsRequestParams{}, errors.New("Failed to parse start_time value")
50+
}
51+
startTimestamp = t.Unix()
52+
}
53+
return logsWsRequestParams{startTimestamp: startTimestamp}, nil
54+
}
55+
56+
func (s *Server) streamJobLogs(ctx context.Context, conn *websocket.Conn, params logsWsRequestParams) {
3057
defer func() {
3158
_ = conn.WriteMessage(websocket.CloseMessage, nil)
3259
_ = conn.Close()
3360
}()
34-
61+
currentPos := 0
62+
startTimestampMs := params.startTimestamp * 1000
63+
if startTimestampMs != 0 {
64+
// TODO: Replace currentPos linear search with binary search
65+
s.executor.RLock()
66+
jobLogsWsHistory := s.executor.GetJobWsLogsHistory()
67+
for _, logEntry := range jobLogsWsHistory {
68+
if logEntry.Timestamp < startTimestampMs {
69+
currentPos += 1
70+
} else {
71+
break
72+
}
73+
}
74+
s.executor.RUnlock()
75+
}
3576
for {
3677
s.executor.RLock()
3778
jobLogsWsHistory := s.executor.GetJobWsLogsHistory()
@@ -52,7 +93,7 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) {
5293
for currentPos < len(jobLogsWsHistory) {
5394
if err := conn.WriteMessage(websocket.BinaryMessage, jobLogsWsHistory[currentPos].Message); err != nil {
5495
s.executor.RUnlock()
55-
log.Error(context.TODO(), "Failed to write message", "err", err)
96+
log.Error(ctx, "Failed to write message", "err", err)
5697
return
5798
}
5899
currentPos++

runner/internal/schemas/schemas.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ type JobStateEvent struct {
1616

1717
type LogEvent struct {
1818
Message []byte `json:"message"`
19-
Timestamp int64 `json:"timestamp"`
19+
Timestamp int64 `json:"timestamp"` // milliseconds
2020
}
2121

2222
type SubmitBody struct {

src/dstack/_internal/cli/commands/attach.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
get_run_exit_code,
1212
print_finished_message,
1313
)
14-
from dstack._internal.cli.utils.common import console
14+
from dstack._internal.cli.utils.common import console, get_start_time
1515
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT
1616
from dstack._internal.core.errors import CLIError
1717
from dstack._internal.utils.common import get_or_error
@@ -61,6 +61,14 @@ def _register(self):
6161
type=int,
6262
default=0,
6363
)
64+
self._parser.add_argument(
65+
"--since",
66+
help=(
67+
"Show only logs newer than the specified date."
68+
" Can be a duration (e.g. 10s, 5m, 1d) or an RFC 3339 string (e.g. 2023-09-24T15:30:00Z)."
69+
),
70+
type=str,
71+
)
6472
self._parser.add_argument("run_name").completer = RunNameCompleter() # type: ignore[attr-defined]
6573

6674
def _command(self, args: argparse.Namespace):
@@ -86,7 +94,9 @@ def _command(self, args: argparse.Namespace):
8694
job_num=args.job,
8795
)
8896
if args.logs:
97+
start_time = get_start_time(args.since)
8998
logs = run.logs(
99+
start_time=start_time,
90100
replica_num=args.replica,
91101
job_num=args.job,
92102
)

src/dstack/_internal/cli/commands/logs.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import argparse
22
import sys
3-
from datetime import datetime
4-
from typing import Optional
53

64
from dstack._internal.cli.commands import APIBaseCommand
75
from dstack._internal.cli.services.completion import RunNameCompleter
6+
from dstack._internal.cli.utils.common import get_start_time
87
from dstack._internal.core.errors import CLIError
9-
from dstack._internal.utils.common import parse_since
108
from dstack._internal.utils.logging import get_logger
119

1210
logger = get_logger(__name__)
@@ -49,7 +47,7 @@ def _command(self, args: argparse.Namespace):
4947
if run is None:
5048
raise CLIError(f"Run {args.run_name} not found")
5149

52-
start_time = _get_start_time(args.since)
50+
start_time = get_start_time(args.since)
5351
logs = run.logs(
5452
start_time=start_time,
5553
diagnose=args.diagnose,
@@ -62,12 +60,3 @@ def _command(self, args: argparse.Namespace):
6260
sys.stdout.buffer.flush()
6361
except KeyboardInterrupt:
6462
pass
65-
66-
67-
def _get_start_time(since: Optional[str]) -> Optional[datetime]:
68-
if since is None:
69-
return None
70-
try:
71-
return parse_since(since)
72-
except ValueError as e:
73-
raise CLIError(e.args[0])

src/dstack/_internal/cli/utils/common.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from datetime import datetime, timezone
33
from pathlib import Path
4-
from typing import Any, Dict, Union
4+
from typing import Any, Dict, Optional, Union
55

66
from rich.console import Console
77
from rich.prompt import Confirm
@@ -11,7 +11,7 @@
1111
from dstack._internal import settings
1212
from dstack._internal.cli.utils.rich import DstackRichHandler
1313
from dstack._internal.core.errors import CLIError, DstackError
14-
from dstack._internal.utils.common import get_dstack_dir
14+
from dstack._internal.utils.common import get_dstack_dir, parse_since
1515

1616
_colors = {
1717
"secondary": "grey58",
@@ -110,3 +110,12 @@ def warn(message: str):
110110
# Additional blank line for better visibility if there are more than one warning
111111
message = f"{message}\n"
112112
console.print(f"[warning][bold]{message}[/]")
113+
114+
115+
def get_start_time(since: Optional[str]) -> Optional[datetime]:
116+
if since is None:
117+
return None
118+
try:
119+
return parse_since(since)
120+
except ValueError as e:
121+
raise CLIError(e.args[0])

src/dstack/_internal/utils/common.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
from typing_extensions import ParamSpec
1414

15+
from dstack._internal.core.models.common import Duration
16+
1517
P = ParamSpec("P")
1618
R = TypeVar("R")
1719

@@ -150,20 +152,16 @@ def parse_since(value: str) -> datetime:
150152
or a duration (e.g. 10s, 5m, 1d) between the timestamp and now.
151153
"""
152154
try:
153-
seconds = parse_pretty_duration(value)
155+
seconds = Duration.parse(value)
154156
return get_current_datetime() - timedelta(seconds=seconds)
155157
except ValueError:
156158
pass
157159
try:
158160
res = datetime.fromisoformat(value)
159161
except ValueError:
160-
pass
162+
raise ValueError("Invalid datetime format")
161163
else:
162164
return check_time_offset_aware(res)
163-
try:
164-
return datetime.fromtimestamp(int(value), tz=timezone.utc)
165-
except Exception:
166-
raise ValueError("Invalid datetime format")
167165

168166

169167
def check_time_offset_aware(time: datetime) -> datetime:
@@ -172,22 +170,6 @@ def check_time_offset_aware(time: datetime) -> datetime:
172170
return time
173171

174172

175-
def parse_pretty_duration(duration: str) -> int:
176-
regex = re.compile(r"(?P<amount>\d+)(?P<unit>s|m|h|d|w)$")
177-
re_match = regex.match(duration)
178-
if not re_match:
179-
raise ValueError(f"Cannot parse the duration {duration}")
180-
amount, unit = int(re_match.group("amount")), re_match.group("unit")
181-
multiplier = {
182-
"s": 1,
183-
"m": 60,
184-
"h": 3600,
185-
"d": 24 * 3600,
186-
"w": 7 * 24 * 3600,
187-
}[unit]
188-
return amount * multiplier
189-
190-
191173
DURATION_UNITS_DESC = [
192174
("w", 7 * 24 * 3600),
193175
("d", 24 * 3600),

src/dstack/api/_public/runs.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from datetime import datetime
1111
from pathlib import Path
1212
from typing import BinaryIO, Dict, Iterable, List, Optional, Union
13-
from urllib.parse import urlparse
13+
from urllib.parse import urlencode, urlparse
1414

1515
from websocket import WebSocketApp
1616

@@ -136,9 +136,7 @@ def service_model(self) -> Optional["ServiceModel"]:
136136
),
137137
)
138138

139-
def _attached_logs(
140-
self,
141-
) -> Iterable[bytes]:
139+
def _attached_logs(self, start_time: Optional[datetime] = None) -> Iterable[bytes]:
142140
q = queue.Queue()
143141
_done = object()
144142

@@ -150,8 +148,14 @@ def ws_thread():
150148
logger.debug("WebSocket logs are done for %s", self.name)
151149
q.put(_done)
152150

151+
url = f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws"
152+
query_params = {}
153+
if start_time is not None:
154+
query_params["start_time"] = start_time.isoformat()
155+
if query_params:
156+
url = f"{url}?{urlencode(query_params)}"
153157
ws = WebSocketApp(
154-
f"ws://localhost:{self.ports[DSTACK_RUNNER_HTTP_PORT]}/logs_ws",
158+
url=url,
155159
on_open=lambda _: logger.debug("WebSocket logs are connected to %s", self.name),
156160
on_close=lambda _, status_code, msg: logger.debug(
157161
"WebSocket logs are disconnected. status_code: %s; message: %s",
@@ -215,7 +219,7 @@ def logs(
215219
Log messages.
216220
"""
217221
if diagnose is False and self._ssh_attach is not None:
218-
yield from self._attached_logs()
222+
yield from self._attached_logs(start_time=start_time)
219223
else:
220224
job = self._find_job(replica_num=replica_num, job_num=job_num)
221225
if job is None:

0 commit comments

Comments
 (0)