@@ -2,13 +2,18 @@ package api
22
33import (
44 "context"
5+ "errors"
56 "net/http"
67 "time"
78
89 "github.com/dstackai/dstack/runner/internal/log"
910 "github.com/gorilla/websocket"
1011)
1112
13+ type logsWsRequestParams struct {
14+ startTimestamp int64
15+ }
16+
1217var upgrader = websocket.Upgrader {
1318 CheckOrigin : func (r * http.Request ) bool {
1419 return true
@@ -20,18 +25,54 @@ func (s *Server) logsWsGetHandler(w http.ResponseWriter, r *http.Request) (inter
2025 if err != nil {
2126 return nil , err
2227 }
28+ requestParams , err := parseRequestParams (r )
29+ if err != nil {
30+ _ = conn .WriteMessage (
31+ websocket .CloseMessage ,
32+ websocket .FormatCloseMessage (websocket .CloseUnsupportedData , err .Error ()),
33+ )
34+ _ = conn .Close ()
35+ return nil , nil
36+ }
2337 // todo memorize clientId?
24- go s .streamJobLogs (conn )
38+ go s .streamJobLogs (r . Context (), conn , requestParams )
2539 return nil , nil
2640}
2741
28- func (s * Server ) streamJobLogs (conn * websocket.Conn ) {
29- currentPos := 0
42+ func parseRequestParams (r * http.Request ) (logsWsRequestParams , error ) {
43+ query := r .URL .Query ()
44+ startTimeStr := query .Get ("start_time" )
45+ var startTimestamp int64
46+ if startTimeStr != "" {
47+ t , err := time .Parse (time .RFC3339 , startTimeStr )
48+ if err != nil {
49+ return logsWsRequestParams {}, errors .New ("Failed to parse start_time value" )
50+ }
51+ startTimestamp = t .Unix ()
52+ }
53+ return logsWsRequestParams {startTimestamp : startTimestamp }, nil
54+ }
55+
56+ func (s * Server ) streamJobLogs (ctx context.Context , conn * websocket.Conn , params logsWsRequestParams ) {
3057 defer func () {
3158 _ = conn .WriteMessage (websocket .CloseMessage , nil )
3259 _ = conn .Close ()
3360 }()
34-
61+ currentPos := 0
62+ startTimestampMs := params .startTimestamp * 1000
63+ if startTimestampMs != 0 {
64+ // TODO: Replace currentPos linear search with binary search
65+ s .executor .RLock ()
66+ jobLogsWsHistory := s .executor .GetJobWsLogsHistory ()
67+ for _ , logEntry := range jobLogsWsHistory {
68+ if logEntry .Timestamp < startTimestampMs {
69+ currentPos += 1
70+ } else {
71+ break
72+ }
73+ }
74+ s .executor .RUnlock ()
75+ }
3576 for {
3677 s .executor .RLock ()
3778 jobLogsWsHistory := s .executor .GetJobWsLogsHistory ()
@@ -52,7 +93,7 @@ func (s *Server) streamJobLogs(conn *websocket.Conn) {
5293 for currentPos < len (jobLogsWsHistory ) {
5394 if err := conn .WriteMessage (websocket .BinaryMessage , jobLogsWsHistory [currentPos ].Message ); err != nil {
5495 s .executor .RUnlock ()
55- log .Error (context . TODO () , "Failed to write message" , "err" , err )
96+ log .Error (ctx , "Failed to write message" , "err" , err )
5697 return
5798 }
5899 currentPos ++
0 commit comments