From 3954cd738d54a4c491cdb5b865704cd5b41b759b Mon Sep 17 00:00:00 2001 From: Henry Barreto Date: Mon, 9 Jun 2025 18:02:40 -0300 Subject: [PATCH 1/3] feat(api): add sessions events listing to store --- api/routes/routes.go | 1 + api/routes/session.go | 45 +++++- api/routes/session_test.go | 198 ++++++++++++++++++++++++- api/services/mocks/services.go | 73 ++++++--- api/services/session.go | 14 +- api/services/session_test.go | 264 +++++++++++++++++++++++++++++++-- api/store/mocks/store.go | 22 +-- api/store/mongo/session.go | 86 +++++++++-- api/store/session.go | 2 +- pkg/api/requests/session.go | 13 +- 10 files changed, 646 insertions(+), 72 deletions(-) diff --git a/api/routes/routes.go b/api/routes/routes.go index b35c75de349..228c7b785d1 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -132,6 +132,7 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo { publicAPI.GET(GetSessionsURL, routesmiddleware.Authorize(gateway.Handler(handler.GetSessionList))) publicAPI.GET(GetSessionURL, routesmiddleware.Authorize(gateway.Handler(handler.GetSession))) + publicAPI.GET(ListEventsSessionsURL, routesmiddleware.Authorize(gateway.Handler(handler.ListEventsSession))) publicAPI.GET(GetStatsURL, routesmiddleware.Authorize(gateway.Handler(handler.GetStats))) publicAPI.GET(GetSystemInfoURL, gateway.Handler(handler.GetSystemInfo)) diff --git a/api/routes/session.go b/api/routes/session.go index 82734a3ba51..a5d71247180 100644 --- a/api/routes/session.go +++ b/api/routes/session.go @@ -13,13 +13,14 @@ import ( ) const ( - GetSessionsURL = "/sessions" - GetSessionURL = "/sessions/:uid" - UpdateSessionURL = "/sessions/:uid" - CreateSessionURL = "/sessions" - FinishSessionURL = "/sessions/:uid/finish" - KeepAliveSessionURL = "/sessions/:uid/keepalive" - EventsSessionsURL = "/sessions/:uid/events" + GetSessionsURL = "/sessions" + GetSessionURL = "/sessions/:uid" + UpdateSessionURL = "/sessions/:uid" + CreateSessionURL = "/sessions" + FinishSessionURL = "/sessions/:uid/finish" + KeepAliveSessionURL = "/sessions/:uid/keepalive" + EventsSessionsURL = "/sessions/:uid/events" + ListEventsSessionsURL = "/sessions/:uid/events" ) const ( @@ -164,7 +165,7 @@ func (h *Handler) EventSession(c gateway.Context) error { return err } - if err := h.service.EventSession(c.Ctx(), models.UID(req.UID), &models.SessionEvent{ + if err := h.service.SaveEventSession(c.Ctx(), models.UID(req.UID), &models.SessionEvent{ Session: req.UID, Type: models.SessionEventType(r.Type), Timestamp: r.Timestamp, @@ -175,3 +176,31 @@ func (h *Handler) EventSession(c gateway.Context) error { } } } + +func (h *Handler) ListEventsSession(c gateway.Context) error { + req := new(requests.SessionListEvents) + + if err := c.Bind(req); err != nil { + return err + } + + req.Paginator.Normalize() + req.Sorter.Normalize() + + if err := req.Filters.Unmarshal(); err != nil { + return err + } + + if err := c.Validate(req); err != nil { + return err + } + + events, counter, err := h.service.ListEventsSession(c.Ctx(), models.UID(req.UID), req.Paginator, req.Filters, req.Sorter) + if err != nil { + return err + } + + c.Response().Header().Set("X-Total-Count", strconv.Itoa(counter)) + + return c.JSON(http.StatusOK, events) +} diff --git a/api/routes/session_test.go b/api/routes/session_test.go index 7b8dd0a9419..205f1323050 100644 --- a/api/routes/session_test.go +++ b/api/routes/session_test.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" + "strconv" "strings" "testing" @@ -399,7 +401,7 @@ func TestEventSession(t *testing.T) { webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once() - mock.On("EventSession", gomock.Anything, models.UID(uid), gomock.Anything). + mock.On("SaveEventSession", gomock.Anything, models.UID(uid), gomock.Anything). Return(errors.New("not able record")).Once() }, expected: http.StatusInternalServerError, @@ -422,7 +424,7 @@ func TestEventSession(t *testing.T) { webSocketUpgraderMock.On("Upgrade", gomock.Anything, gomock.Anything).Return(conn, nil).Once() - mock.On("EventSession", gomock.Anything, models.UID(uid), + mock.On("SaveEventSession", gomock.Anything, models.UID(uid), gomock.Anything).Return(nil).Once() conn.On("ReadJSON", gomock.Anything).Return(&websocket.CloseError{ @@ -468,3 +470,195 @@ func TestEventSession(t *testing.T) { }) } } + +func TestListEventsSession(t *testing.T) { + mock := new(mocks.Service) + + cases := []struct { + title string + req *requests.SessionListEvents + requiredMocks func() + expectedStatus int + expectedCounter string + expectedBody string + }{ + { + title: "fails to list session's events when input data is invalid", + req: &requests.SessionListEvents{ + UID: "", + Paginator: query.Paginator{}, + Sorter: query.Sorter{}, + Filters: query.Filters{}, + }, + requiredMocks: func() {}, + expectedStatus: http.StatusBadRequest, + expectedCounter: "", + expectedBody: "", + }, + { + title: "fails to list session's events when cannot validate input params", + req: &requests.SessionListEvents{ + UID: "", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() {}, + expectedStatus: http.StatusBadRequest, + expectedCounter: "", + expectedBody: "", + }, + { + title: "fails to list session's events when service fails because session doesn't exist", + req: &requests.SessionListEvents{ + UID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() { + mock. + On("ListEventsSession", + gomock.Anything, + models.UID("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + gomock.Anything, + gomock.Anything, + gomock.Anything, + ). + Return(nil, 0, svc.ErrSessionNotFound). + Once() + }, + expectedStatus: http.StatusNotFound, + expectedCounter: "", + expectedBody: "", + }, + { + title: "fails to list session's events when service fails", + req: &requests.SessionListEvents{ + UID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() { + mock. + On("ListEventsSession", + gomock.Anything, + models.UID("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + gomock.Anything, + gomock.Anything, + gomock.Anything, + ). + Return(nil, 0, errors.New("")). + Once() + }, + expectedStatus: http.StatusInternalServerError, + expectedCounter: "", + expectedBody: "", + }, + { + title: "success to list session's events when it is empty", + req: &requests.SessionListEvents{ + UID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() { + mock. + On("ListEventsSession", + gomock.Anything, + models.UID("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + gomock.Anything, + gomock.Anything, + gomock.Anything, + ). + Return([]models.SessionEvent{}, 0, nil). + Once() + }, + expectedStatus: http.StatusOK, + expectedCounter: "0", + expectedBody: `[]` + "\n", + }, + { + title: "success to list session's events with one item", + req: &requests.SessionListEvents{ + UID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() { + mock. + On("ListEventsSession", + gomock.Anything, + models.UID("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + gomock.Anything, + gomock.Anything, + gomock.Anything, + ). + Return([]models.SessionEvent{ + {}, + }, 1, nil). + Once() + }, + expectedStatus: http.StatusOK, + expectedCounter: "1", + expectedBody: `[{"session":"","type":"","timestamp":"0001-01-01T00:00:00Z","data":null,"seat":0}]` + "\n", + }, + { + title: "success to list session's events with more than one item", + req: &requests.SessionListEvents{ + UID: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + Paginator: query.Paginator{Page: 1, PerPage: 10}, + Sorter: query.Sorter{By: "name", Order: "asc"}, + Filters: query.Filters{}, + }, + requiredMocks: func() { + mock. + On("ListEventsSession", + gomock.Anything, + models.UID("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), + gomock.Anything, + gomock.Anything, + gomock.Anything, + ). + Return([]models.SessionEvent{ + {}, + {}, + }, 2, nil). + Once() + }, + expectedStatus: http.StatusOK, + expectedCounter: "2", + expectedBody: `[{"session":"","type":"","timestamp":"0001-01-01T00:00:00Z","data":null,"seat":0},{"session":"","type":"","timestamp":"0001-01-01T00:00:00Z","data":null,"seat":0}]` + "\n", + }, + } + + for _, tc := range cases { + t.Run(tc.title, func(t *testing.T) { + tc.requiredMocks() + + urlVal := &url.Values{} + urlVal.Set("page", strconv.Itoa(tc.req.Page)) + urlVal.Set("per_page", strconv.Itoa(tc.req.PerPage)) + urlVal.Set("sort_by", tc.req.By) + urlVal.Set("order_by", tc.req.Order) + + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/sessions/%s/events?"+urlVal.Encode(), tc.req.UID), nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Role", authorizer.RoleOwner.String()) + + rec := httptest.NewRecorder() + + e := NewRouter(mock) + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectedStatus, rec.Result().StatusCode) + assert.Equal(t, tc.expectedCounter, rec.Header().Get("X-Total-Count")) + assert.Equal(t, tc.expectedBody, rec.Body.String()) + }) + } + + mock.AssertExpectations(t) +} diff --git a/api/services/mocks/services.go b/api/services/mocks/services.go index c796dba66b1..d7eb046715f 100644 --- a/api/services/mocks/services.go +++ b/api/services/mocks/services.go @@ -728,24 +728,6 @@ func (_m *Service) EvaluateKeyUsername(ctx context.Context, key *models.PublicKe return r0, r1 } -// EventSession provides a mock function with given fields: ctx, uid, event -func (_m *Service) EventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error { - ret := _m.Called(ctx, uid, event) - - if len(ret) == 0 { - panic("no return value specified for EventSession") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.SessionEvent) error); ok { - r0 = rf(ctx, uid, event) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // GetDevice provides a mock function with given fields: ctx, uid func (_m *Service) GetDevice(ctx context.Context, uid models.UID) (*models.Device, error) { ret := _m.Called(ctx, uid) @@ -1141,6 +1123,43 @@ func (_m *Service) ListDevices(ctx context.Context, req *requests.DeviceList) ([ return r0, r1, r2 } +// ListEventsSession provides a mock function with given fields: ctx, uid, paginator, filters, sorter +func (_m *Service) ListEventsSession(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) { + ret := _m.Called(ctx, uid, paginator, filters, sorter) + + if len(ret) == 0 { + panic("no return value specified for ListEventsSession") + } + + var r0 []models.SessionEvent + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) ([]models.SessionEvent, int, error)); ok { + return rf(ctx, uid, paginator, filters, sorter) + } + if rf, ok := ret.Get(0).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) []models.SessionEvent); ok { + r0 = rf(ctx, uid, paginator, filters, sorter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]models.SessionEvent) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) int); ok { + r1 = rf(ctx, uid, paginator, filters, sorter) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) error); ok { + r2 = rf(ctx, uid, paginator, filters, sorter) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + // ListNamespaces provides a mock function with given fields: ctx, req func (_m *Service) ListNamespaces(ctx context.Context, req *requests.NamespaceList) ([]models.Namespace, int, error) { ret := _m.Called(ctx, req) @@ -1452,6 +1471,24 @@ func (_m *Service) ResolveDevice(ctx context.Context, req *requests.ResolveDevic return r0, r1 } +// SaveEventSession provides a mock function with given fields: ctx, uid, event +func (_m *Service) SaveEventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error { + ret := _m.Called(ctx, uid, event) + + if len(ret) == 0 { + panic("no return value specified for SaveEventSession") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID, *models.SessionEvent) error); ok { + r0 = rf(ctx, uid, event) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Setup provides a mock function with given fields: ctx, req func (_m *Service) Setup(ctx context.Context, req requests.Setup) error { ret := _m.Called(ctx, req) diff --git a/api/services/session.go b/api/services/session.go index 7620e2d51a7..fc498cf2d4a 100644 --- a/api/services/session.go +++ b/api/services/session.go @@ -16,7 +16,8 @@ type SessionService interface { DeactivateSession(ctx context.Context, uid models.UID) error KeepAliveSession(ctx context.Context, uid models.UID) error UpdateSession(ctx context.Context, uid models.UID, model models.SessionUpdate) error - EventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error + SaveEventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error + ListEventsSession(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) } func (s *service) ListSessions(ctx context.Context, paginator query.Paginator) ([]models.Session, int, error) { @@ -71,7 +72,7 @@ func (s *service) UpdateSession(ctx context.Context, uid models.UID, model model return s.store.SessionUpdate(ctx, uid, sess, &model) } -func (s *service) EventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error { +func (s *service) SaveEventSession(ctx context.Context, uid models.UID, event *models.SessionEvent) error { sess, err := s.store.SessionGet(ctx, uid) if err != nil { return NewErrSessionNotFound(uid, err) @@ -79,3 +80,12 @@ func (s *service) EventSession(ctx context.Context, uid models.UID, event *model return s.store.SessionEvent(ctx, models.UID(sess.UID), event) } + +func (s *service) ListEventsSession(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) { + sess, err := s.store.SessionGet(ctx, uid) + if err != nil { + return nil, 0, NewErrSessionNotFound(uid, err) + } + + return s.store.SessionListEvents(ctx, models.UID(sess.UID), paginator, filters, sorter) +} diff --git a/api/services/session_test.go b/api/services/session_test.go index 307f65b4173..c4a5cd5c68d 100644 --- a/api/services/session_test.go +++ b/api/services/session_test.go @@ -2,11 +2,10 @@ package services import ( "context" + "errors" "net" "testing" - goerrors "errors" - "github.com/shellhub-io/shellhub/api/store" "github.com/shellhub-io/shellhub/api/store/mocks" "github.com/shellhub-io/shellhub/pkg/api/query" @@ -16,6 +15,7 @@ import ( mocksGeoIp "github.com/shellhub-io/shellhub/pkg/geoip/mocks" "github.com/shellhub-io/shellhub/pkg/models" "github.com/stretchr/testify/assert" + mocker "github.com/stretchr/testify/mock" ) func TestListSessions(t *testing.T) { @@ -40,12 +40,12 @@ func TestListSessions(t *testing.T) { paginator: query.Paginator{Page: 1, PerPage: 10}, requiredMocks: func(paginator query.Paginator) { mock.On("SessionList", ctx, paginator). - Return(nil, 0, goerrors.New("error")).Once() + Return(nil, 0, errors.New("error")).Once() }, expected: Expected{ sessions: nil, count: 0, - err: goerrors.New("error"), + err: errors.New("error"), }, }, { @@ -111,11 +111,11 @@ func TestGetSession(t *testing.T) { uid: models.UID("_uid"), requiredMocks: func() { mock.On("SessionGet", ctx, models.UID("_uid")). - Return(nil, goerrors.New("error")).Once() + Return(nil, errors.New("error")).Once() }, expected: Expected{ session: nil, - err: NewErrSessionNotFound(models.UID("_uid"), goerrors.New("error")), + err: NewErrSessionNotFound(models.UID("_uid"), errors.New("error")), }, }, { @@ -165,7 +165,7 @@ func TestCreateSession(t *testing.T) { Longitude: 0, }} - Err := goerrors.New("error") + Err := errors.New("error") cases := []struct { name string @@ -246,9 +246,9 @@ func TestDeactivateSession(t *testing.T) { }, nil).Once() mock.On("SessionDeleteActives", ctx, models.UID("_uid")). - Return(goerrors.New("error")).Once() + Return(errors.New("error")).Once() }, - expected: goerrors.New("error"), + expected: errors.New("error"), }, { name: "succeeds", @@ -298,9 +298,9 @@ func TestUpdateSession(t *testing.T) { description: "fails when SessionGet returns error", requiredMocks: func() { mockStore.On("SessionGet", ctx, uid). - Return(nil, goerrors.New("get error")).Once() + Return(nil, errors.New("get error")).Once() }, - expectedErr: NewErrSessionNotFound(uid, goerrors.New("get error")), + expectedErr: NewErrSessionNotFound(uid, errors.New("get error")), }, { description: "fails when SessionUpdate returns error", @@ -308,9 +308,9 @@ func TestUpdateSession(t *testing.T) { mockStore.On("SessionGet", ctx, uid). Return(sess, nil).Once() mockStore.On("SessionUpdate", ctx, uid, sess, &updateModel). - Return(goerrors.New("update error")).Once() + Return(errors.New("update error")).Once() }, - expectedErr: goerrors.New("update error"), + expectedErr: errors.New("update error"), }, { description: "succeeds when no errors", @@ -335,3 +335,241 @@ func TestUpdateSession(t *testing.T) { mockStore.AssertExpectations(t) } + +func TestListEvents(t *testing.T) { + type Expected struct { + events []models.SessionEvent + counter int + err error + } + + mock := new(mocks.Store) + + tests := []struct { + description string + uid string + paginator query.Paginator + sorter query.Sorter + filters query.Filters + requiredMocks func() + expected Expected + }{ + { + description: "failed to get the session", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(nil, errors.New("error")). + Once() + }, + expected: Expected{ + nil, 0, NewErrSessionNotFound(models.UID("uid"), errors.New("error")), + }, + }, + { + description: "failed to list the events", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(&models.Session{ + UID: "uid", + }, nil). + Once() + + mock.On( + "SessionListEvents", + mocker.Anything, + models.UID("uid"), + query.Paginator{Page: 1, PerPage: 10}, + query.Filters{}, + query.Sorter{By: "timestamp", Order: "asc"}, + ). + Return(nil, 0, errors.New("error")). + Once() + }, + expected: Expected{ + nil, 0, errors.New("error"), + }, + }, + { + description: "success when session has no events", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(&models.Session{ + UID: "uid", + }, nil). + Once() + + mock.On( + "SessionListEvents", + mocker.Anything, + models.UID("uid"), + query.Paginator{Page: 1, PerPage: 10}, + query.Filters{}, + query.Sorter{By: "timestamp", Order: "asc"}, + ). + Return([]models.SessionEvent{}, 0, nil). + Once() + }, + expected: Expected{ + []models.SessionEvent{}, 0, nil, + }, + }, + { + description: "success when session has one event", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(&models.Session{ + UID: "uid", + }, nil). + Once() + + mock.On( + "SessionListEvents", + mocker.Anything, + models.UID("uid"), + query.Paginator{Page: 1, PerPage: 10}, + query.Filters{}, + query.Sorter{By: "timestamp", Order: "asc"}, + ). + Return([]models.SessionEvent{ + {}, + }, 1, nil). + Once() + }, + expected: Expected{ + []models.SessionEvent{ + {}, + }, 1, nil, + }, + }, + { + description: "success when session has many events", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(&models.Session{ + UID: "uid", + }, nil). + Once() + + mock.On( + "SessionListEvents", + mocker.Anything, + models.UID("uid"), + query.Paginator{Page: 1, PerPage: 10}, + query.Filters{}, + query.Sorter{By: "timestamp", Order: "asc"}, + ). + Return([]models.SessionEvent{ + {}, + {}, + {}, + {}, + }, 4, nil). + Once() + }, + expected: Expected{ + []models.SessionEvent{ + {}, + {}, + {}, + {}, + }, 4, nil, + }, + }, + { + description: "success when session has many events and is paged", + uid: "uid", + paginator: query.Paginator{Page: 1, PerPage: 2}, + sorter: query.Sorter{By: "timestamp", Order: "asc"}, + filters: query.Filters{}, + requiredMocks: func() { + mock.On( + "SessionGet", + mocker.Anything, + models.UID("uid"), + ). + Return(&models.Session{ + UID: "uid", + }, nil). + Once() + + mock.On( + "SessionListEvents", + mocker.Anything, + models.UID("uid"), + query.Paginator{Page: 1, PerPage: 2}, + query.Filters{}, + query.Sorter{By: "timestamp", Order: "asc"}, + ). + Return([]models.SessionEvent{ + {}, + {}, + }, 4, nil). + Once() + }, + expected: Expected{ + []models.SessionEvent{ + {}, + {}, + }, 4, nil, + }, + }, + } + + service := NewService(mock, privateKey, publicKey, storecache.NewNullCache(), clientMock) + + for _, test := range tests { + t.Run(test.description, func(t *testing.T) { + ctx := context.Background() + test.requiredMocks() + + events, counter, err := service.ListEventsSession(ctx, models.UID(test.uid), test.paginator, test.filters, test.sorter) + assert.Equal(t, test.expected, Expected{ + events: events, + counter: counter, + err: err, + }) + }) + } + + mock.AssertExpectations(t) +} diff --git a/api/store/mocks/store.go b/api/store/mocks/store.go index be4e8237ca9..fb6fcbc55e1 100644 --- a/api/store/mocks/store.go +++ b/api/store/mocks/store.go @@ -1555,9 +1555,9 @@ func (_m *Store) SessionList(ctx context.Context, paginator query.Paginator) ([] return r0, r1, r2 } -// SessionListEvents provides a mock function with given fields: ctx, uid, seat, event, paginator -func (_m *Store) SessionListEvents(ctx context.Context, uid models.UID, seat int, event models.SessionEventType, paginator query.Paginator) ([]models.SessionEvent, int, error) { - ret := _m.Called(ctx, uid, seat, event, paginator) +// SessionListEvents provides a mock function with given fields: ctx, uid, paginator, filters, sorter +func (_m *Store) SessionListEvents(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) { + ret := _m.Called(ctx, uid, paginator, filters, sorter) if len(ret) == 0 { panic("no return value specified for SessionListEvents") @@ -1566,25 +1566,25 @@ func (_m *Store) SessionListEvents(ctx context.Context, uid models.UID, seat int var r0 []models.SessionEvent var r1 int var r2 error - if rf, ok := ret.Get(0).(func(context.Context, models.UID, int, models.SessionEventType, query.Paginator) ([]models.SessionEvent, int, error)); ok { - return rf(ctx, uid, seat, event, paginator) + if rf, ok := ret.Get(0).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) ([]models.SessionEvent, int, error)); ok { + return rf(ctx, uid, paginator, filters, sorter) } - if rf, ok := ret.Get(0).(func(context.Context, models.UID, int, models.SessionEventType, query.Paginator) []models.SessionEvent); ok { - r0 = rf(ctx, uid, seat, event, paginator) + if rf, ok := ret.Get(0).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) []models.SessionEvent); ok { + r0 = rf(ctx, uid, paginator, filters, sorter) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]models.SessionEvent) } } - if rf, ok := ret.Get(1).(func(context.Context, models.UID, int, models.SessionEventType, query.Paginator) int); ok { - r1 = rf(ctx, uid, seat, event, paginator) + if rf, ok := ret.Get(1).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) int); ok { + r1 = rf(ctx, uid, paginator, filters, sorter) } else { r1 = ret.Get(1).(int) } - if rf, ok := ret.Get(2).(func(context.Context, models.UID, int, models.SessionEventType, query.Paginator) error); ok { - r2 = rf(ctx, uid, seat, event, paginator) + if rf, ok := ret.Get(2).(func(context.Context, models.UID, query.Paginator, query.Filters, query.Sorter) error); ok { + r2 = rf(ctx, uid, paginator, filters, sorter) } else { r2 = ret.Error(2) } diff --git a/api/store/mongo/session.go b/api/store/mongo/session.go index 882a8d83057..6f65cb0c0cf 100644 --- a/api/store/mongo/session.go +++ b/api/store/mongo/session.go @@ -9,7 +9,9 @@ import ( "github.com/shellhub-io/shellhub/pkg/api/query" "github.com/shellhub-io/shellhub/pkg/clock" "github.com/shellhub-io/shellhub/pkg/models" + log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" @@ -350,41 +352,93 @@ func (s *Store) SessionEvent(ctx context.Context, uid models.UID, event *models. return nil } -func (s *Store) SessionListEvents(ctx context.Context, uid models.UID, seat int, event models.SessionEventType, paginator query.Paginator) ([]models.SessionEvent, int, error) { - query := []bson.M{ +func (s *Store) SessionListEvents(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) { + pipeline := []bson.M{ { "$match": bson.M{ "session": uid, - "seat": seat, - "type": event, - }, - }, - { - "$sort": bson.M{ - "timestamp": 1, }, }, } - queryCount := query - queryCount = append(queryCount, bson.M{"$count": "count"}) - count, err := AggregateCount(ctx, s.db.Collection("sessions_events"), queryCount) + queryMatch, err := queries.FromFilters(&filters) if err != nil { + log.WithError(err).Error("failed to create filters") + return nil, 0, FromMongoError(err) } - query = append(query, queries.FromPaginator(&paginator)...) + pipeline = append(pipeline, queryMatch...) - cursosr, err := s.db.Collection("sessions_events").Aggregate(ctx, query) + countPipeline := append(pipeline, bson.M{"$count": "count"}) + count, err := AggregateCount(ctx, s.db.Collection("sessions_events"), countPipeline) if err != nil { + log.WithError(err).Error("failed to count sessions_events") + return nil, 0, FromMongoError(err) } - events := make([]models.SessionEvent, 0) - if err := cursosr.All(ctx, events); err != nil { + if sorter.By == "" { + sorter.By = "timestamp" + } + + pipeline = append(pipeline, queries.FromSorter(&sorter)...) + pipeline = append(pipeline, queries.FromPaginator(&paginator)...) + + opts := options.Aggregate().SetAllowDiskUse(true) + cursor, err := s.db.Collection("sessions_events").Aggregate(ctx, pipeline, opts) + if err != nil { + log.WithError(err).Error("failed to run aggregation against sessions_events collection") + return nil, 0, FromMongoError(err) } + defer cursor.Close(ctx) + + events := make([]models.SessionEvent, 0) + for cursor.Next(ctx) { + var event models.SessionEvent + if err := cursor.Decode(&event); err != nil { + log.WithError(err).Error("failed to decode the event from the cursor") + + return nil, 0, err + } + + switch event.Type { + case models.SessionEventTypeWindowChange: + prim := event.Data.(primitive.D) + + data, err := bson.Marshal(prim) + if err != nil { + return nil, 0, err + } + + model := models.SSHWindowChange{} + if err := bson.Unmarshal(data, &model); err != nil { + return nil, 0, err + } + + event.Data = model + case models.SessionEventTypePtyRequest: + // NOTE: We're converting the data returned by MongoDB when the field is a [any] to out structure. + prim := event.Data.(primitive.D) + + data, err := bson.Marshal(prim) + if err != nil { + return nil, 0, err + } + + model := models.SSHPty{} + if err := bson.Unmarshal(data, &model); err != nil { + return nil, 0, err + } + + event.Data = model + } + + events = append(events, event) + } + return events, count, nil } diff --git a/api/store/session.go b/api/store/session.go index 9edb9fab1e0..978f45fca71 100644 --- a/api/store/session.go +++ b/api/store/session.go @@ -19,6 +19,6 @@ type SessionStore interface { SessionSetType(ctx context.Context, uid models.UID, kind string) error SessionCreateActive(ctx context.Context, uid models.UID, session *models.Session) error SessionEvent(ctx context.Context, uid models.UID, event *models.SessionEvent) error - SessionListEvents(ctx context.Context, uid models.UID, seat int, event models.SessionEventType, paginator query.Paginator) ([]models.SessionEvent, int, error) + SessionListEvents(ctx context.Context, uid models.UID, paginator query.Paginator, filters query.Filters, sorter query.Sorter) ([]models.SessionEvent, int, error) SessionDeleteEvents(ctx context.Context, uid models.UID, seat int, event models.SessionEventType) error } diff --git a/pkg/api/requests/session.go b/pkg/api/requests/session.go index f69a14cb87c..edf76005401 100644 --- a/pkg/api/requests/session.go +++ b/pkg/api/requests/session.go @@ -1,6 +1,10 @@ package requests -import "time" +import ( + "time" + + "github.com/shellhub-io/shellhub/pkg/api/query" +) // SessionIDParam is a structure to represent and validate a session UID as path param. type SessionIDParam struct { @@ -57,3 +61,10 @@ type SessionSeat struct { SessionIDParam ID int `json:"id" bson:"id,omitempty"` } + +type SessionListEvents struct { + UID string `param:"uid" validate:"required"` + query.Paginator + query.Filters + query.Sorter +} From f1c11dc929c377f217e69719a582379f467bb3eb Mon Sep 17 00:00:00 2001 From: Henry Barreto Date: Fri, 4 Jul 2025 11:11:23 -0300 Subject: [PATCH 2/3] tests(api): increase coverage of tests to sessions events on store --- api/services/session_test.go | 4 +- api/store/mongo/fixtures/sessions_events.json | 35 ++ api/store/mongo/session_test.go | 484 ++++++++++++++++++ api/store/mongo/store_test.go | 1 + 4 files changed, 522 insertions(+), 2 deletions(-) create mode 100644 api/store/mongo/fixtures/sessions_events.json diff --git a/api/services/session_test.go b/api/services/session_test.go index c4a5cd5c68d..a03c818f731 100644 --- a/api/services/session_test.go +++ b/api/services/session_test.go @@ -232,9 +232,9 @@ func TestDeactivateSession(t *testing.T) { uid: models.UID("_uid"), requiredMocks: func() { mock.On("SessionGet", ctx, models.UID("_uid")). - Return(nil, goerrors.New("get error")).Once() + Return(nil, errors.New("get error")).Once() }, - expected: NewErrSessionNotFound("_uid", goerrors.New("get error")), + expected: NewErrSessionNotFound("_uid", errors.New("get error")), }, { name: "fails", diff --git a/api/store/mongo/fixtures/sessions_events.json b/api/store/mongo/fixtures/sessions_events.json new file mode 100644 index 00000000000..7c0e9536cd5 --- /dev/null +++ b/api/store/mongo/fixtures/sessions_events.json @@ -0,0 +1,35 @@ +{ + "sessions_events": { + "6862c1d617cc3c27e6c77995": { + "_id": "6862c1d617cc3c27e6c77995", + "session": "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + "type": "pty-req", + "timestamp": "2023-01-02T12:00:00.000Z", + "data": { + "term": "screen-256color", + "columns": 211, + "rows": 47, + "width": 1899, + "height": 940, + "modelist": "" + }, + "seat": 0 + }, + "6862c1d617cc3c27e6c77996": { + "_id": "6862c1d617cc3c27e6c77996", + "session": "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + "type": "shell", + "timestamp": "2023-01-02T12:01:00.000Z", + "data": "", + "seat": 0 + }, + "6862c1db17cc3c27e6c779d1": { + "_id": "6862c1db17cc3c27e6c779d1", + "session": "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + "type": "exit-status", + "timestamp": "2023-01-02T12:02:00.000Z", + "data": "AAAAAA==", + "seat": 0 + } + } +} diff --git a/api/store/mongo/session_test.go b/api/store/mongo/session_test.go index 15627f02553..f1376c325aa 100644 --- a/api/store/mongo/session_test.go +++ b/api/store/mongo/session_test.go @@ -7,9 +7,11 @@ import ( "time" "github.com/shellhub-io/shellhub/api/store" + "github.com/shellhub-io/shellhub/api/store/mongo" "github.com/shellhub-io/shellhub/pkg/api/query" "github.com/shellhub-io/shellhub/pkg/models" "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson" ) func TestSessionList(t *testing.T) { @@ -554,3 +556,485 @@ func TestSessionDeleteActives(t *testing.T) { }) } } + +func TestSessionListEvents(t *testing.T) { + type Expected struct { + events []models.SessionEvent + count int + err error + } + + cases := []struct { + description string + uid string + paginator query.Paginator + sorter query.Sorter + filters query.Filters + fixtures []string + expected Expected + }{ + { + description: "succeeds when sessions are not found", + uid: "nonexistent", + paginator: query.Paginator{Page: -1, PerPage: -1}, + sorter: query.Sorter{By: "timestamp", Order: query.OrderAsc}, + filters: query.Filters{}, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureActiveSessions, + fixtureSessionsEvents, + }, + expected: Expected{ + events: []models.SessionEvent{}, + count: 0, + err: nil, + }, + }, + { + description: "succeeds when sessions are found", + uid: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + paginator: query.Paginator{Page: -1, PerPage: -1}, + sorter: query.Sorter{By: "timestamp", Order: query.OrderAsc}, + filters: query.Filters{}, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureActiveSessions, + fixtureSessionsEvents, + }, + expected: Expected{ + events: []models.SessionEvent{ + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "pty-req", + Timestamp: time.Date(2023, 1, 2, 12, 0, 0, 0, time.UTC), + Data: models.SSHPty{ + Term: "screen-256color", + Columns: 211, + Rows: 47, + Width: 1899, + Height: 940, + Modelist: []byte{}, + }, + Seat: 0, + }, + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "shell", + Timestamp: time.Date(2023, 1, 2, 12, 1, 0, 0, time.UTC), + Data: "", + Seat: 0, + }, + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "exit-status", + Timestamp: time.Date(2023, 1, 2, 12, 2, 0, 0, time.UTC), + Data: "AAAAAA==", + Seat: 0, + }, + }, + count: 3, + err: nil, + }, + }, + { + description: "succeeds when sessions are found by page are limited", + uid: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + paginator: query.Paginator{Page: 1, PerPage: 2}, + sorter: query.Sorter{By: "timestamp", Order: query.OrderAsc}, + filters: query.Filters{}, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureActiveSessions, + fixtureSessionsEvents, + }, + expected: Expected{ + events: []models.SessionEvent{ + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "pty-req", + Timestamp: time.Date(2023, 1, 2, 12, 0, 0, 0, time.UTC), + Data: models.SSHPty{ + Term: "screen-256color", + Columns: 211, + Rows: 47, + Width: 1899, + Height: 940, + Modelist: []byte{}, + }, + Seat: 0, + }, + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "shell", + Timestamp: time.Date(2023, 1, 2, 12, 1, 0, 0, time.UTC), + Data: "", + Seat: 0, + }, + }, + count: 3, + err: nil, + }, + }, + { + description: "succeeds when filtering by event type", + uid: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: query.OrderAsc}, + filters: query.Filters{ + Data: []query.Filter{ + { + Type: "property", + Params: &query.FilterProperty{ + Name: "type", + Operator: "eq", + Value: "pty-req", + }, + }, + }, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureActiveSessions, + fixtureSessionsEvents, + }, + expected: Expected{ + events: []models.SessionEvent{ + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "pty-req", + Timestamp: time.Date(2023, 1, 2, 12, 0, 0, 0, time.UTC), + Data: models.SSHPty{ + Term: "screen-256color", + Columns: 211, + Rows: 47, + Width: 1899, + Height: 940, + Modelist: []byte{}, + }, + Seat: 0, + }, + }, + count: 1, + err: nil, + }, + }, + { + description: "succeeds when filtering by seat", + uid: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + paginator: query.Paginator{Page: 1, PerPage: 10}, + sorter: query.Sorter{By: "timestamp", Order: query.OrderAsc}, + filters: query.Filters{ + Data: []query.Filter{ + { + Type: "property", + Params: &query.FilterProperty{ + Name: "seat", + Operator: "eq", + Value: 0, // Use integer instead of string + }, + }, + }, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureActiveSessions, + fixtureSessionsEvents, + }, + expected: Expected{ + events: []models.SessionEvent{ + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "pty-req", + Timestamp: time.Date(2023, 1, 2, 12, 0, 0, 0, time.UTC), + Data: models.SSHPty{ + Term: "screen-256color", + Columns: 211, + Rows: 47, + Width: 1899, + Height: 940, + Modelist: []byte{}, + }, + Seat: 0, + }, + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "shell", + Timestamp: time.Date(2023, 1, 2, 12, 1, 0, 0, time.UTC), + Data: "", + Seat: 0, + }, + { + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: "exit-status", + Timestamp: time.Date(2023, 1, 2, 12, 2, 0, 0, time.UTC), + Data: "AAAAAA==", + Seat: 0, + }, + }, + count: 3, + err: nil, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + events, count, err := s.SessionListEvents(ctx, models.UID(tc.uid), tc.paginator, tc.filters, tc.sorter) + + assert.Equal(t, tc.expected, Expected{events: events, count: count, err: err}) + }) + } +} + +func TestSessionEvent(t *testing.T) { + cases := []struct { + description string + uid models.UID + event *models.SessionEvent + fixtures []string + expected error + }{ + { + description: "succeeds when creating a new session event", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + event: &models.SessionEvent{ + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: models.SessionEventTypePtyRequest, + Timestamp: time.Date(2023, 1, 2, 12, 3, 0, 0, time.UTC), + Data: models.SSHPty{ + Term: "xterm-256color", + Columns: 80, + Rows: 24, + Width: 640, + Height: 480, + Modelist: []byte{}, + }, + Seat: 0, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + }, + expected: nil, + }, + { + description: "succeeds when creating a window change event", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + event: &models.SessionEvent{ + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: models.SessionEventTypeWindowChange, + Timestamp: time.Date(2023, 1, 2, 12, 4, 0, 0, time.UTC), + Data: models.SSHWindowChange{ + Columns: 120, + Rows: 30, + Width: 960, + Height: 720, + }, + Seat: 0, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + }, + expected: nil, + }, + { + description: "succeeds when creating an exit status event", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + event: &models.SessionEvent{ + Session: "a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68", + Type: models.SessionEventTypeExitStatus, + Timestamp: time.Date(2023, 1, 2, 12, 5, 0, 0, time.UTC), + Data: "0", + Seat: 0, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + }, + expected: nil, + }, + { + description: "succeeds when session does not exist", + uid: models.UID("nonexistent"), + event: &models.SessionEvent{ + Session: "nonexistent", + Type: models.SessionEventTypePtyRequest, + Timestamp: time.Date(2023, 1, 2, 12, 3, 0, 0, time.UTC), + Data: "", + Seat: 0, + }, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + }, + expected: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + err := s.SessionEvent(ctx, tc.uid, tc.event) + assert.Equal(t, tc.expected, err) + + // Verify the event was created in sessions_events collection + if tc.expected == nil { + var event models.SessionEvent + store := s.(*mongo.Store) + err := store.GetDB().Collection("sessions_events").FindOne(ctx, bson.M{ + "session": tc.event.Session, + "type": tc.event.Type, + "timestamp": tc.event.Timestamp, + "seat": tc.event.Seat, + }).Decode(&event) + assert.NoError(t, err) + assert.Equal(t, tc.event.Session, event.Session) + assert.Equal(t, tc.event.Type, event.Type) + assert.Equal(t, tc.event.Seat, event.Seat) + } + }) + } +} + +func TestSessionDeleteEvents(t *testing.T) { + cases := []struct { + description string + uid models.UID + seat int + eventType models.SessionEventType + fixtures []string + expected error + }{ + { + description: "succeeds when deleting existing events", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + seat: 0, + eventType: models.SessionEventTypePtyRequest, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureSessionsEvents, + }, + expected: nil, + }, + { + description: "succeeds when deleting shell events", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + seat: 0, + eventType: models.SessionEventTypeShell, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureSessionsEvents, + }, + expected: nil, + }, + { + description: "succeeds when deleting exit status events", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + seat: 0, + eventType: models.SessionEventTypeExitStatus, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureSessionsEvents, + }, + expected: nil, + }, + { + description: "succeeds when no events match criteria", + uid: models.UID("a3b0431f5df6a7827945d2e34872a5c781452bc36de42f8b1297fd9ecb012f68"), + seat: 1, + eventType: models.SessionEventTypePtyRequest, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureSessionsEvents, + }, + expected: nil, + }, + { + description: "succeeds when session does not exist", + uid: models.UID("nonexistent"), + seat: 0, + eventType: models.SessionEventTypePtyRequest, + fixtures: []string{ + fixtureNamespaces, + fixtureDevices, + fixtureSessions, + fixtureSessionsEvents, + }, + expected: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + ctx := context.Background() + + assert.NoError(t, srv.Apply(tc.fixtures...)) + t.Cleanup(func() { + assert.NoError(t, srv.Reset()) + }) + + store := s.(*mongo.Store) + countBefore, err := store.GetDB().Collection("sessions_events").CountDocuments(ctx, bson.M{ + "session": tc.uid, + "seat": tc.seat, + "type": tc.eventType, + }) + assert.NoError(t, err) + + totalBefore, err := store.GetDB().Collection("sessions_events").CountDocuments(ctx, bson.M{}) + assert.NoError(t, err) + + err = s.SessionDeleteEvents(ctx, tc.uid, tc.seat, tc.eventType) + assert.Equal(t, tc.expected, err) + + countAfter, err := store.GetDB().Collection("sessions_events").CountDocuments(ctx, bson.M{ + "session": tc.uid, + "seat": tc.seat, + "type": tc.eventType, + }) + assert.NoError(t, err) + assert.Equal(t, int64(0), countAfter) + + totalAfter, err := store.GetDB().Collection("sessions_events").CountDocuments(ctx, bson.M{}) + assert.NoError(t, err) + + assert.Equal(t, totalBefore-countBefore, totalAfter) + }) + } +} diff --git a/api/store/mongo/store_test.go b/api/store/mongo/store_test.go index 41f2bc570ea..d5d9a10bc49 100644 --- a/api/store/mongo/store_test.go +++ b/api/store/mongo/store_test.go @@ -33,6 +33,7 @@ const ( fixtureUsers = "users" // Check "store.mongo.fixtures.users" for fixture iefo fixtureNamespaces = "namespaces" // Check "store.mongo.fixtures.namespaces" for fixture info fixtureRecoveryTokens = "recovery_tokens" // Check "store.mongo.fixtures.recovery_tokens" for fixture info + fixtureSessionsEvents = "sessions_events" // Check "store.mongo.fixtures.sessions_events" for fixture info ) func TestMain(m *testing.M) { From 7a6309f9eef99d459b153cb392e6793921ff38cc Mon Sep 17 00:00:00 2001 From: Henry Barreto Date: Tue, 5 Aug 2025 11:27:02 -0300 Subject: [PATCH 3/3] refactor(api,pkg): save events summary into a session associating with the seat --- api/store/mongo/migrations/main.go | 1 + api/store/mongo/migrations/migration_108.go | 139 ++++++++++++ .../mongo/migrations/migration_108_test.go | 204 ++++++++++++++++++ api/store/mongo/session.go | 22 +- pkg/models/session.go | 5 +- 5 files changed, 366 insertions(+), 5 deletions(-) create mode 100644 api/store/mongo/migrations/migration_108.go create mode 100644 api/store/mongo/migrations/migration_108_test.go diff --git a/api/store/mongo/migrations/main.go b/api/store/mongo/migrations/main.go index dfba0420c5b..6a952f3b28e 100644 --- a/api/store/mongo/migrations/main.go +++ b/api/store/mongo/migrations/main.go @@ -117,6 +117,7 @@ func GenerateMigrations() []migrate.Migration { migration105, migration106, migration107, + migration108, } } diff --git a/api/store/mongo/migrations/migration_108.go b/api/store/mongo/migrations/migration_108.go new file mode 100644 index 00000000000..af4a45e2c34 --- /dev/null +++ b/api/store/mongo/migrations/migration_108.go @@ -0,0 +1,139 @@ +package migrations + +import ( + "context" + + "github.com/shellhub-io/shellhub/pkg/models" + log "github.com/sirupsen/logrus" + migrate "github.com/xakep666/mongo-migrate" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +var migration108 = migrate.Migration{ + Version: 108, + Description: "Migrate session events to session seats structure", + Up: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error { + log.WithFields(log.Fields{"component": "migration", "version": 108, "action": "Up"}).Info("Applying migration") + + cursor, err := db.Collection("sessions").Find(ctx, bson.M{ + "seats": nil, + "events": bson.M{"$exists": true}, + }) + if err != nil { + log.WithError(err).Error("Failed to fetch sessions") + + return err + } + + defer cursor.Close(ctx) + + for cursor.Next(ctx) { + var session struct { + UID string `bson:"uid"` + Seats []models.SessionSeat `bson:"seats"` + } + + if err := cursor.Decode(&session); err != nil { + log.WithError(err).Error("Failed to decode session") + + return err + } + + eventsCursor, err := db.Collection("sessions_events").Find(ctx, bson.M{"session": session.UID}) + if err != nil { + log.WithError(err).WithField("session", session.UID).Error("Failed to fetch events for session") + + return err + } + + eventsBySeat := make(map[int][]string) + + for eventsCursor.Next(ctx) { + var event struct { + Type string `bson:"type"` + Seat int `bson:"seat"` + } + + if err := eventsCursor.Decode(&event); err != nil { + log.WithError(err).Error("Failed to decode event") + + return err + } + + if _, ok := eventsBySeat[event.Seat]; !ok { + eventsBySeat[event.Seat] = []string{} + } + + eventsBySeat[event.Seat] = append(eventsBySeat[event.Seat], event.Type) + } + + eventsCursor.Close(ctx) + + var seats []models.SessionSeat + var seatIDs []int + + eventTypes := make(map[string]bool) + + for seatID, events := range eventsBySeat { + seats = append(seats, models.SessionSeat{ + ID: seatID, + Events: events, + }) + + seatIDs = append(seatIDs, seatID) + + for _, eventType := range events { + eventTypes[eventType] = true + } + } + + var types []string + for eventType := range eventTypes { + types = append(types, eventType) + } + + _, err = db.Collection("sessions").UpdateOne(ctx, + bson.M{"uid": session.UID}, + bson.M{ + "$set": bson.M{ + "seats": seats, + "events": models.SessionEvents{ + Types: types, + Seats: seatIDs, + }, + }, + }) + if err != nil { + log.WithError(err).WithField("session", session.UID).Error("Failed to update session") + + return err + } + } + + log.WithFields(log.Fields{"component": "migration", "version": 108, "action": "Up"}).Info("Migration completed successfully") + + return nil + }), + Down: migrate.MigrationFunc(func(ctx context.Context, db *mongo.Database) error { + log.WithFields(log.Fields{"component": "migration", "version": 108, "action": "Down"}).Info("Reverting migration") + + if _, err := db.Collection("sessions").UpdateMany( + ctx, + bson.M{}, + bson.M{ + "$unset": bson.M{ + "seats": "", + }, + }, + ); err != nil { + log.WithError(err).Error("Failed to revert events migration") + + return err + } + + log.WithFields(log.Fields{"component": "migration", "version": 108, "action": "Down"}).Info("Migration reverted successfully") + + return nil + }), +} diff --git a/api/store/mongo/migrations/migration_108_test.go b/api/store/mongo/migrations/migration_108_test.go new file mode 100644 index 00000000000..eb369b2252b --- /dev/null +++ b/api/store/mongo/migrations/migration_108_test.go @@ -0,0 +1,204 @@ +package migrations + +import ( + "context" + "testing" + "time" + + "github.com/shellhub-io/shellhub/pkg/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + migrate "github.com/xakep666/mongo-migrate" + "go.mongodb.org/mongo-driver/bson" +) + +func TestMigration108Up(t *testing.T) { + ctx := context.Background() + + cases := []struct { + description string + setup func() error + verify func(t *testing.T) + }{ + { + description: "succeeds migrating session events to seats", + setup: func() error { + sessions := []bson.M{ + { + "uid": "session-1", + "events": bson.M{ + "types": []string{"pty-output", "window-change"}, + "seats": []int{0, 1}, + }, + }, + { + "uid": "session-2", + "events": bson.M{ + "types": []string{"pty-output"}, + "seats": []int{0}, + }, + }, + } + + events := []bson.M{ + { + "session": "session-1", + "type": "pty-output", + "seat": 0, + "timestamp": time.Now(), + }, + { + "session": "session-1", + "type": "window-change", + "seat": 0, + "timestamp": time.Now(), + }, + { + "session": "session-1", + "type": "pty-output", + "seat": 1, + "timestamp": time.Now(), + }, + { + "session": "session-2", + "type": "pty-output", + "seat": 0, + "timestamp": time.Now(), + }, + } + + if _, err := c.Database("test").Collection("sessions").InsertMany(ctx, []any{sessions[0], sessions[1]}); err != nil { + return err + } + + if _, err := c.Database("test").Collection("sessions_events").InsertMany(ctx, []any{events[0], events[1], events[2], events[3]}); err != nil { + return err + } + + return nil + }, + verify: func(t *testing.T) { + var session1 struct { + Events models.SessionEvents `bson:"events"` + Seats []models.SessionSeat `bson:"seats"` + } + + err := c.Database("test").Collection("sessions").FindOne(ctx, bson.M{"uid": "session-1"}).Decode(&session1) + assert.NoError(t, err) + + assert.ElementsMatch(t, []string{"pty-output", "window-change"}, session1.Events.Types) + assert.ElementsMatch(t, []int{0, 1}, session1.Events.Seats) + + assert.Len(t, session1.Seats, 2) + for _, seat := range session1.Seats { + switch seat.ID { + case 0: + assert.ElementsMatch(t, []string{"pty-output", "window-change"}, seat.Events) + case 1: + assert.ElementsMatch(t, []string{"pty-output"}, seat.Events) + } + } + + var session2 struct { + Events models.SessionEvents `bson:"events"` + Seats []models.SessionSeat `bson:"seats"` + } + + err = c.Database("test").Collection("sessions").FindOne(ctx, bson.M{"uid": "session-2"}).Decode(&session2) + assert.NoError(t, err) + + assert.ElementsMatch(t, []string{"pty-output"}, session2.Events.Types) + assert.ElementsMatch(t, []int{0}, session2.Events.Seats) + + assert.Len(t, session2.Seats, 1) + for _, seat := range session2.Seats { + switch seat.ID { + case 0: + assert.ElementsMatch(t, []string{"pty-output"}, seat.Events) + } + } + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + t.Cleanup(func() { assert.NoError(t, srv.Reset()) }) + + require.NoError(t, tc.setup()) + migrates := migrate.NewMigrate(c.Database("test"), GenerateMigrations()[107]) + require.NoError(t, migrates.Up(ctx, migrate.AllAvailable)) + + tc.verify(t) + }) + } +} + +func TestMigration108Down(t *testing.T) { + ctx := context.Background() + + cases := []struct { + description string + setup func() error + verify func(t *testing.T) + }{ + { + description: "succeeds removing events field while keeping seats structure", + setup: func() error { + sessions := []bson.M{ + { + "uid": "session-1", + "events": bson.M{ + "types": []string{"pty-output", "window-change"}, + "seats": []int{0, 1}, + }, + "seats": []bson.M{ + { + "id": 0, + "events": []string{"pty-output", "window-change"}, + }, + { + "id": 1, + "events": []string{"pty-output"}, + }, + }, + }, + } + + if _, err := c.Database("test").Collection("sessions").InsertMany(ctx, []any{sessions[0]}); err != nil { + return err + } + + return nil + }, + verify: func(t *testing.T) { + var session struct { + UID string `bson:"uid"` + Events models.SessionEvents `bson:"events"` + Seats []models.SessionSeat `bson:"seats"` + } + + err := c.Database("test").Collection("sessions").FindOne(ctx, bson.M{"uid": "session-1"}).Decode(&session) + assert.NoError(t, err) + + assert.ElementsMatch(t, []string{"pty-output", "window-change"}, session.Events.Types) + assert.ElementsMatch(t, []int{0, 1}, session.Events.Seats) + + assert.Len(t, session.Seats, 0, "Seats should be removed") + }, + }, + } + + for _, tc := range cases { + t.Run(tc.description, func(t *testing.T) { + t.Cleanup(func() { assert.NoError(t, srv.Reset()) }) + + require.NoError(t, tc.setup()) + migrates := migrate.NewMigrate(c.Database("test"), GenerateMigrations()[107]) + require.NoError(t, migrates.Up(ctx, migrate.AllAvailable)) + require.NoError(t, migrates.Down(ctx, migrate.AllAvailable)) + + tc.verify(t) + }) + } +} diff --git a/api/store/mongo/session.go b/api/store/mongo/session.go index 6f65cb0c0cf..71d110e4912 100644 --- a/api/store/mongo/session.go +++ b/api/store/mongo/session.go @@ -217,6 +217,7 @@ func (s *Store) SessionCreate(ctx context.Context, session models.Session) (*mod session.StartedAt = clock.Now() session.LastSeen = session.StartedAt session.Recorded = false + session.Seats = []models.SessionSeat{} device, err := s.DeviceResolve(ctx, store.DeviceUIDResolver, string(session.DeviceUID)) if err != nil { @@ -329,15 +330,28 @@ func (s *Store) SessionEvent(ctx context.Context, uid models.UID, event *models. if _, err := session.WithTransaction(ctx, func(ctx mongo.SessionContext) (any, error) { if _, err := s.db.Collection("sessions").UpdateOne(ctx, - bson.M{"uid": uid}, + bson.M{"uid": uid, "seats.id": bson.M{"$ne": event.Seat}}, + bson.M{ + "$push": bson.M{ + "seats": bson.M{ + "id": event.Seat, + "events": bson.A{}, + }, + }, + }, + ); err != nil { + return nil, FromMongoError(err) + } + + if _, err := s.db.Collection("sessions").UpdateOne(ctx, + bson.M{"uid": uid, "seats.id": event.Seat}, bson.M{ "$addToSet": bson.M{ - "events.types": event.Type, - "events.seats": event.Seat, + "seats.$.events": event.Type, }, }, ); err != nil { - return nil, err + return nil, FromMongoError(err) } if _, err := s.db.Collection("sessions_events").InsertOne(ctx, event); err != nil { diff --git a/pkg/models/session.go b/pkg/models/session.go index 910bba29220..aab6386cb64 100644 --- a/pkg/models/session.go +++ b/pkg/models/session.go @@ -26,6 +26,7 @@ type Session struct { Term string `json:"term" bson:"term"` Position SessionPosition `json:"position" bson:"position"` Events SessionEvents `json:"events" bson:"events"` + Seats []SessionSeat `json:"seats" bson:"seats"` } type ActiveSession struct { @@ -109,5 +110,7 @@ type SessionEvents struct { // SessionSeat stores a session's seat. type SessionSeat struct { // ID is the identifier of session's seat. - ID int `json:"id" bson:"id,omitempty"` + ID int `json:"id" bson:"id"` + // Events is a list of events registered in the session's seat. + Events []string `json:"events" bson:"events"` }