From df1eb4314323e13d612339ab7c9e24aad6e34413 Mon Sep 17 00:00:00 2001 From: mirkoCrobu Date: Wed, 12 Nov 2025 15:08:36 +0100 Subject: [PATCH] add stop upgrade endpoint --- cmd/gendoc/docs.go | 15 +++ internal/api/api.go | 1 + internal/api/docs/openapi.yaml | 12 ++ internal/api/handlers/update.go | 10 ++ internal/e2e/client/client.gen.go | 102 +++++++++++++++ internal/update/apt/service.go | 197 ++++++++++++++++------------- internal/update/arduino/arduino.go | 103 ++++++++++----- internal/update/event.go | 3 + internal/update/update.go | 67 +++++++++- 9 files changed, 387 insertions(+), 123 deletions(-) diff --git a/cmd/gendoc/docs.go b/cmd/gendoc/docs.go index 8e85e2a8..7adaf8a6 100644 --- a/cmd/gendoc/docs.go +++ b/cmd/gendoc/docs.go @@ -824,6 +824,21 @@ Contains a JSON object with the details of an error. {StatusCode: http.StatusNoContent, Reference: "#/components/responses/NoContent"}, }, }, + { + OperationId: "stopUpdate", + Method: http.MethodPut, + Path: "/v1/system/update/stop", + CustomSuccessResponse: &CustomResponseDef{ + Description: "Successful response", + StatusCode: http.StatusOK, + }, + Description: "Stop the upgrade process.", + Summary: "Stop the upgrade process in background", + Tags: []Tag{SystemTag}, + PossibleErrors: []ErrorResponse{ + {StatusCode: http.StatusConflict, Reference: "#/components/responses/Conflict"}, + }, + }, { OperationId: "applyUpdate", Method: http.MethodPut, diff --git a/internal/api/api.go b/internal/api/api.go index 08d31d84..1a346df8 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -66,6 +66,7 @@ func NewHTTPRouter( mux.Handle("GET /v1/system/update/check", handlers.HandleCheckUpgradable(updater)) mux.Handle("GET /v1/system/update/events", handlers.HandleUpdateEvents(updater)) mux.Handle("PUT /v1/system/update/apply", handlers.HandleUpdateApply(updater)) + mux.Handle("PUT /v1/system/update/stop", handlers.HandlerUpdateStop(updater)) mux.Handle("GET /v1/system/resources", handlers.HandleSystemResources()) mux.Handle("GET /v1/models", handlers.HandleModelsList(modelsIndex)) diff --git a/internal/api/docs/openapi.yaml b/internal/api/docs/openapi.yaml index f2b3a999..dd4c82f4 100644 --- a/internal/api/docs/openapi.yaml +++ b/internal/api/docs/openapi.yaml @@ -1080,6 +1080,18 @@ paths: summary: SSE stream of the update process tags: - System + /v1/system/update/stop: + put: + description: Stop the upgrade process. + operationId: stopUpdate + responses: + "200": + description: Successful response + "409": + $ref: '#/components/responses/Conflict' + summary: Stop the upgrade process in background + tags: + - System /v1/version: get: description: returns the application current version diff --git a/internal/api/handlers/update.go b/internal/api/handlers/update.go index 41ac992b..53b5ce36 100644 --- a/internal/api/handlers/update.go +++ b/internal/api/handlers/update.go @@ -107,6 +107,16 @@ func HandleUpdateApply(updater *update.Manager) http.HandlerFunc { } } +func HandlerUpdateStop(updater *update.Manager) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if updater.StopUpgrade() { + render.EncodeResponse(w, http.StatusOK, "Upgrade operation cancellation requested") + } else { + render.EncodeResponse(w, http.StatusConflict, models.ErrorResponse{Details: "No upgrade operation in progress"}) + } + } +} + func HandleUpdateEvents(updater *update.Manager) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { sseStream, err := render.NewSSEStream(r.Context(), w) diff --git a/internal/e2e/client/client.gen.go b/internal/e2e/client/client.gen.go index f6094430..bad855af 100644 --- a/internal/e2e/client/client.gen.go +++ b/internal/e2e/client/client.gen.go @@ -666,6 +666,9 @@ type ClientInterface interface { // EventsUpdate request EventsUpdate(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) + // StopUpdate request + StopUpdate(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) + // GetVersions request GetVersions(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) } @@ -1150,6 +1153,18 @@ func (c *Client) EventsUpdate(ctx context.Context, reqEditors ...RequestEditorFn return c.Client.Do(req) } +func (c *Client) StopUpdate(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewStopUpdateRequest(c.Server) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) GetVersions(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewGetVersionsRequest(c.Server) if err != nil { @@ -2674,6 +2689,33 @@ func NewEventsUpdateRequest(server string) (*http.Request, error) { return req, nil } +// NewStopUpdateRequest generates requests for StopUpdate +func NewStopUpdateRequest(server string) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/v1/system/update/stop") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("PUT", queryURL.String(), nil) + if err != nil { + return nil, err + } + + return req, nil +} + // NewGetVersionsRequest generates requests for GetVersions func NewGetVersionsRequest(server string) (*http.Request, error) { var err error @@ -2858,6 +2900,9 @@ type ClientWithResponsesInterface interface { // EventsUpdateWithResponse request EventsUpdateWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*EventsUpdateResp, error) + // StopUpdateWithResponse request + StopUpdateWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*StopUpdateResp, error) + // GetVersionsWithResponse request GetVersionsWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetVersionsResp, error) } @@ -3674,6 +3719,28 @@ func (r EventsUpdateResp) StatusCode() int { return 0 } +type StopUpdateResp struct { + Body []byte + HTTPResponse *http.Response + JSON409 *Conflict +} + +// Status returns HTTPResponse.Status +func (r StopUpdateResp) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r StopUpdateResp) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type GetVersionsResp struct { Body []byte HTTPResponse *http.Response @@ -4051,6 +4118,15 @@ func (c *ClientWithResponses) EventsUpdateWithResponse(ctx context.Context, reqE return ParseEventsUpdateResp(rsp) } +// StopUpdateWithResponse request returning *StopUpdateResp +func (c *ClientWithResponses) StopUpdateWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*StopUpdateResp, error) { + rsp, err := c.StopUpdate(ctx, reqEditors...) + if err != nil { + return nil, err + } + return ParseStopUpdateResp(rsp) +} + // GetVersionsWithResponse request returning *GetVersionsResp func (c *ClientWithResponses) GetVersionsWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*GetVersionsResp, error) { rsp, err := c.GetVersions(ctx, reqEditors...) @@ -5392,6 +5468,32 @@ func ParseEventsUpdateResp(rsp *http.Response) (*EventsUpdateResp, error) { return response, nil } +// ParseStopUpdateResp parses an HTTP response from a StopUpdateWithResponse call +func ParseStopUpdateResp(rsp *http.Response) (*StopUpdateResp, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &StopUpdateResp{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 409: + var dest Conflict + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON409 = &dest + + } + + return response, nil +} + // ParseGetVersionsResp parses an HTTP response from a GetVersionsWithResponse call func ParseGetVersionsResp(rsp *http.Response) (*GetVersionsResp, error) { bodyBytes, err := io.ReadAll(rsp.Body) diff --git a/internal/update/apt/service.go b/internal/update/apt/service.go index f3d3984e..f351e162 100644 --- a/internal/update/apt/service.go +++ b/internal/update/apt/service.go @@ -18,6 +18,7 @@ package apt import ( "bufio" "context" + "errors" "fmt" "io" "iter" @@ -25,6 +26,7 @@ import ( "regexp" "strings" "sync" + "syscall" "time" "github.com/arduino/go-paths-helper" @@ -87,16 +89,21 @@ func (s *Service) UpgradePackages(ctx context.Context, names []string) (<-chan u ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) defer cancel() - eventsCh <- update.Event{Type: update.StartEvent, Data: "Upgrade is starting"} + eventsCh <- update.Event{Type: update.StartEvent, Data: "deb packages upgrade is starting"} stream := runUpgradeCommand(ctx, names) for line, err := range stream { if err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error running upgrade command", + if errors.Is(err, context.Canceled) { + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Run upgrade operation canceled"} + slog.Info("Upgrade operation canceled by user") + } else { + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error running upgrade command", + } + slog.Error("error processing upgrade command output", "error", err) } - slog.Error("error processing upgrade command output", "error", err) return } eventsCh <- update.Event{Type: update.UpgradeLineEvent, Data: line} @@ -121,13 +128,23 @@ func (s *Service) UpgradePackages(ctx context.Context, names []string) (<-chan u streamCleanup := cleanupDockerContainers(ctx) for line, err := range streamCleanup { if err != nil { - // TODO: maybe we should retun an error or a better feedback to the user? - // currently, we just log the error and continue considenring not blocking - slog.Error("Error stopping and destroying docker containers", "error", err) + if errors.Is(err, context.Canceled) { + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Stop and destroy docker containers and images operation canceled"} + slog.Info("Stop and destroy docker containers and images canceled by user") + return + } else { + // TODO: maybe we should retun an error or a better feedback to the user? + // currently, we just log the error and continue considenring not blocking + slog.Error("Error stopping and destroying docker containers", "error", err) + } } eventsCh <- update.Event{Type: update.UpgradeLineEvent, Data: line} } - + if ctx.Err() != nil { + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Pulling the latest docker images operation canceled"} + slog.Info("Pulling the latest docker images operation canceled by user") + return + } // TEMPORARY PATCH: Install the latest docker images and show the logs to the users. // TODO: Remove this workaround once docker image versions are no longer hardcoded in arduino-app-cli. // Tracking issue: https://github.com/arduino/arduino-app-cli/issues/600 @@ -149,14 +166,25 @@ func (s *Service) UpgradePackages(ctx context.Context, names []string) (<-chan u } eventsCh <- update.Event{Type: update.RestartEvent, Data: "Upgrade completed. Restarting ..."} + if ctx.Err() != nil { + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Upgrade operation canceled"} + slog.Info("Upgrade operation canceled by user") + return + } + err := restartServices(ctx) if err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error restart services after upgrade", + if errors.Is(err, context.Canceled) { + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Restarting operation canceled"} + slog.Info("Upgrade operation canceled by user") + } else { + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error restart services after upgrade", + } + slog.Error("failed to restart services", "error", err) } - slog.Error("failed to restart services", "error", err) return } }() @@ -190,39 +218,15 @@ func runUpdateCommand(ctx context.Context) error { func runUpgradeCommand(ctx context.Context, names []string) iter.Seq2[string, error] { env := []string{"NEEDRESTART_MODE=l"} + args := append([]string{"sudo", "apt-get", "install", "--only-upgrade", "-y"}, names...) - aptOptions := []string{ - "-o", "Acquire::Retries=3", - "-o", "Acquire::http::Timeout=30", - "-o", "Acquire::https::Timeout=30", - } - args := []string{"sudo", "apt-get", "install", "--only-upgrade", "-y"} - args = append(args, aptOptions...) - args = append(args, names...) - - return func(yield func(string, error) bool) { - cmd, err := paths.NewProcess(env, args...) - if err != nil { - _ = yield("", err) - return - } - - stdout := orchestrator.NewCallbackWriter(func(line string) { - if !yield(line, nil) { - if err := cmd.Kill(); err != nil { - slog.Error("Failed to kill upgrade command", slog.String("error", err.Error())) - } - } - }) - cmd.RedirectStderrTo(stdout) - cmd.RedirectStdoutTo(stdout) - - if err := cmd.RunWithinContext(ctx); err != nil { - _ = yield("", err) - return + upgradeCmd, err := paths.NewProcess(env, args...) + if err != nil { + return func(yield func(string, error) bool) { + yield("", err) } } - + return runWithLogStream(ctx, upgradeCmd) } func runAptCleanCommand(ctx context.Context) iter.Seq2[string, error] { @@ -251,54 +255,24 @@ func runAptCleanCommand(ctx context.Context) iter.Seq2[string, error] { } func pullDockerImages(ctx context.Context) iter.Seq2[string, error] { - return func(yield func(string, error) bool) { - cmd, err := paths.NewProcess(nil, "arduino-app-cli", "system", "init") - if err != nil { - _ = yield("", err) - return - } - - stdout := orchestrator.NewCallbackWriter(func(line string) { - if !yield(line, nil) { - if err := cmd.Kill(); err != nil { - slog.Error("Failed to kill 'arduino-app-cli system init' command", slog.String("error", err.Error())) - } - } - }) - cmd.RedirectStderrTo(stdout) - cmd.RedirectStdoutTo(stdout) - - if err = cmd.RunWithinContext(ctx); err != nil { - _ = yield("", err) - return + cmd, err := paths.NewProcess(nil, "arduino-app-cli", "system", "init") + if err != nil { + return func(yield func(string, error) bool) { + yield("", err) } } + return runWithLogStream(ctx, cmd) } // Remove all stopped containers func cleanupDockerContainers(ctx context.Context) iter.Seq2[string, error] { - return func(yield func(string, error) bool) { - cmd, err := paths.NewProcess(nil, "arduino-app-cli", "system", "cleanup") - if err != nil { - _ = yield("", err) - return - } - - stdout := orchestrator.NewCallbackWriter(func(line string) { - if !yield(line, nil) { - if err := cmd.Kill(); err != nil { - slog.Error("Failed to kill 'arduino-app-cli system cleanup' command", slog.String("error", err.Error())) - } - } - }) - cmd.RedirectStderrTo(stdout) - cmd.RedirectStdoutTo(stdout) - - if err = cmd.RunWithinContext(ctx); err != nil { - _ = yield("", err) - return + cmd, err := paths.NewProcess(nil, "arduino-app-cli", "system", "cleanup") + if err != nil { + return func(yield func(string, error) bool) { + yield("", err) } } + return runWithLogStream(ctx, cmd) } // RestartServices restarts services that need to be restarted after an upgrade. @@ -312,7 +286,7 @@ func restartServices(ctx context.Context) error { if err != nil { return err } - return needRestartCmd.RunWithinContext(ctx) + return runWithSigterm(ctx, needRestartCmd) } func listUpgradablePackages(ctx context.Context, matcher func(update.UpgradablePackage) bool) ([]update.UpgradablePackage, error) { @@ -371,3 +345,54 @@ func parseListUpgradableOutput(r io.Reader) []update.UpgradablePackage { } return res } +func runWithLogStream(ctx context.Context, cmd *paths.Process) iter.Seq2[string, error] { + return func(yield func(string, error) bool) { + outputWriter := orchestrator.NewCallbackWriter(func(line string) { + if !yield(line, nil) { + err := cmd.Kill() + if err != nil { + slog.Error("Failed to kill command after yield failed", "command", strings.Join(cmd.GetArgs(), " "), "error", err.Error()) + } + return + } + }) + + cmd.RedirectStderrTo(outputWriter) + cmd.RedirectStdoutTo(outputWriter) + + go func() { + <-ctx.Done() + slog.Debug("Context canceled, sending SIGTERM to process", "command", strings.Join(cmd.GetArgs(), " ")) + err := cmd.Signal(syscall.SIGTERM) + if err != nil { + slog.Warn("Failed to send SIGTERM to process", "command", strings.Join(cmd.GetArgs(), " "), "error", err.Error()) + } + }() + if err := runWithSigterm(ctx, cmd); err != nil { + if ctx.Err() != nil { + _ = yield("", ctx.Err()) + } else { + _ = yield("", err) + } + } + } +} + +func runWithSigterm(ctx context.Context, cmd *paths.Process) error { + go func() { + <-ctx.Done() + slog.Debug("Context canceled, sending SIGTERM to process ", "command", strings.Join(cmd.GetArgs(), " ")) + err := cmd.Signal(syscall.SIGTERM) + if err != nil { + slog.Warn("Failed to send SIGTERM to process", "command", strings.Join(cmd.GetArgs(), " "), "error", err.Error()) + } + }() + if err := cmd.Run(); err != nil { + if ctx.Err() != nil { + return ctx.Err() + } else { + return err + } + } + return nil +} diff --git a/internal/update/arduino/arduino.go b/internal/update/arduino/arduino.go index 01076dff..94b96764 100644 --- a/internal/update/arduino/arduino.go +++ b/internal/update/arduino/arduino.go @@ -149,7 +149,7 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st ctx, cancel := context.WithTimeout(ctx, 10*time.Minute) defer cancel() - eventsCh <- update.Event{Type: update.StartEvent, Data: "Upgrade is starting"} + eventsCh <- update.Event{Type: update.StartEvent, Data: "arduino core upgrade is starting"} logrus.SetLevel(logrus.ErrorLevel) // Reduce the log level of arduino-cli srv := commands.NewArduinoCoreServer() @@ -165,10 +165,17 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st var inst *rpc.Instance if resp, err := srv.Create(ctx, &rpc.CreateRequest{}); err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error creating Arduino instance", + if ctx.Err() != nil { + slog.Info("Arduino instance creation canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Arduino instance creation canceled"} + + } else { + slog.Error("Error creating Arduino instance", "error", err) + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error creating Arduino instance", + } } return } else { @@ -183,20 +190,35 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st }() { + stream, _ := commands.UpdateIndexStreamResponseToCallbackFunction(ctx, downloadProgressCB) if err := srv.UpdateIndex(&rpc.UpdateIndexRequest{Instance: inst}, stream); err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error updating index", + slog.Info("err.string(): " + err.Error()) + if ctx.Err() != nil { + slog.Info("Update index canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Update index canceled"} + } else { + slog.Error("Error updating index", "error", err) + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error updating index", + } } return } if err := srv.Init(&rpc.InitRequest{Instance: inst}, commands.InitStreamResponseToCallbackFunction(ctx, nil)); err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error initializing Arduino instance", + if ctx.Err() != nil { + slog.Info("Init Streaming Response canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Init Streaming Response canceled"} + + } else { + slog.Error("Error initializing Arduino instance", "error", err) + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error initializing Arduino instance", + } } return } @@ -217,20 +239,27 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st }, stream, ); err != nil { - var alreadyPresent *cmderrors.PlatformAlreadyAtTheLatestVersionError - if errors.As(err, &alreadyPresent) { - eventsCh <- update.Event{Type: update.UpgradeLineEvent, Data: alreadyPresent.Error()} + + if ctx.Err() != nil { + slog.Info("Platform upgrade canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Platform upgrade canceled"} return - } + } else { + var alreadyPresent *cmderrors.PlatformAlreadyAtTheLatestVersionError + if errors.As(err, &alreadyPresent) { + eventsCh <- update.Event{Type: update.UpgradeLineEvent, Data: alreadyPresent.Error()} + return + } - var notFound *cmderrors.PlatformNotFoundError - if !errors.As(err, ¬Found) { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error upgrading platform", + var notFound *cmderrors.PlatformNotFoundError + if !errors.As(err, ¬Found) { + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error upgrading platform", + } + return } - return } // If the platform is not found, we will try to install it err := srv.PlatformInstall( @@ -246,10 +275,15 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st ), ) if err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error installing platform", + if ctx.Err() != nil { + slog.Info("Platform Install stream canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "Platform Install stream canceled"} + } else { + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error installing platform", + } } return } @@ -274,10 +308,15 @@ func (a *ArduinoPlatformUpdater) UpgradePackages(ctx context.Context, names []st commands.BurnBootloaderToServerStreams(ctx, cbw, cbw), ) if err != nil { - eventsCh <- update.Event{ - Type: update.ErrorEvent, - Err: err, - Data: "Error burning bootloader", + if ctx.Err() != nil { + slog.Info("burning bootloader operation canceled by user.") + eventsCh <- update.Event{Type: update.CanceledEvent, Data: "burning bootloader operation canceled"} + } else { + eventsCh <- update.Event{ + Type: update.ErrorEvent, + Err: err, + Data: "Error burning bootloader", + } } return } diff --git a/internal/update/event.go b/internal/update/event.go index 0f6c1a51..2cbac519 100644 --- a/internal/update/event.go +++ b/internal/update/event.go @@ -24,6 +24,7 @@ const ( RestartEvent DoneEvent ErrorEvent + CanceledEvent ) // Event represents a single event in the upgrade process. @@ -45,6 +46,8 @@ func (t EventType) String() string { return "done" case ErrorEvent: return "error" + case CanceledEvent: + return "canceled" default: panic("unreachable") } diff --git a/internal/update/update.go b/internal/update/update.go index 7a254478..76f9d50f 100644 --- a/internal/update/update.go +++ b/internal/update/update.go @@ -23,6 +23,7 @@ import ( "net/http" "strings" "sync" + "sync/atomic" "time" "golang.org/x/sync/errgroup" @@ -57,8 +58,10 @@ type Manager struct { debUpdateService ServiceUpdater arduinoPlatformUpdateService ServiceUpdater - mu sync.RWMutex - subs map[chan Event]struct{} + mu sync.RWMutex + subs map[chan Event]struct{} + currentUpgradeCancel atomic.Pointer[context.CancelFunc] + currentCheckCancel atomic.Pointer[context.CancelFunc] } func NewManager(debUpdateService ServiceUpdater, arduinoPlatformUpdateService ServiceUpdater) *Manager { @@ -75,6 +78,12 @@ func (m *Manager) ListUpgradablePackages(ctx context.Context, matcher func(Upgra } defer m.lock.Unlock() + opCtx, opCancel := context.WithCancel(ctx) + m.setCurrentCheckCancel(opCancel) + defer func() { + opCancel() + m.setCurrentCheckCancel(nil) + }() // Make sure to be connected to the internet, before checking for updates. // This is needed because the checks below work also when offline (using cached data). if !isConnected() { @@ -82,7 +91,7 @@ func (m *Manager) ListUpgradablePackages(ctx context.Context, matcher func(Upgra } // Get the list of upgradable packages from two sources (deb and platform) in parallel. - g, ctx := errgroup.WithContext(ctx) + g, ctx := errgroup.WithContext(opCtx) var ( debPkgs []UpgradablePackage arduinoPkgs []UpgradablePackage @@ -114,11 +123,11 @@ func (m *Manager) ListUpgradablePackages(ctx context.Context, matcher func(Upgra return append(arduinoPkgs, debPkgs...), nil } -func (m *Manager) UpgradePackages(ctx context.Context, pkgs []UpgradablePackage) error { +func (m *Manager) UpgradePackages(_ context.Context, pkgs []UpgradablePackage) error { if !m.lock.TryLock() { return ErrOperationAlreadyInProgress } - ctx = context.WithoutCancel(ctx) + var debPkgs []string var arduinoPlatform []string for _, v := range pkgs { @@ -134,6 +143,14 @@ func (m *Manager) UpgradePackages(ctx context.Context, pkgs []UpgradablePackage) go func() { defer m.lock.Unlock() + + ctx, cancel := context.WithCancel(context.Background()) + m.setCurrentUpgradeCancel(cancel) + defer func() { + cancel() + m.setCurrentUpgradeCancel(nil) + }() + // We are launching on purpose the update sequentially. The reason is that // the deb pkgs restart the orchestrator, and if we run in parallel the // update of the cores we will end up with inconsistent state, or @@ -152,6 +169,10 @@ func (m *Manager) UpgradePackages(ctx context.Context, pkgs []UpgradablePackage) for e := range arduinoEvents { m.broadcast(e) } + if ctx.Err() != nil { + slog.Info("Update workflow stopped due to cancellation.") + return + } aptEvents, err := m.debUpdateService.UpgradePackages(ctx, debPkgs) if err != nil { @@ -166,11 +187,47 @@ func (m *Manager) UpgradePackages(ctx context.Context, pkgs []UpgradablePackage) for e := range aptEvents { m.broadcast(e) } + if ctx.Err() != nil { + slog.Info("Update workflow stopped due to cancellation.") + return + } m.broadcast(Event{Type: DoneEvent, Data: "Upgrade completed successfully"}) }() return nil } +func (m *Manager) StopUpgrade() bool { + stopped := false + + if cancelFuncPtr := m.currentUpgradeCancel.Swap(nil); cancelFuncPtr != nil { + (*cancelFuncPtr)() + stopped = true + } + + if cancelFuncPtr := m.currentCheckCancel.Swap(nil); cancelFuncPtr != nil { + (*cancelFuncPtr)() + stopped = true + } + + return stopped +} + +func (m *Manager) setCurrentUpgradeCancel(cancel context.CancelFunc) { + if cancel == nil { + m.currentUpgradeCancel.Store(nil) + } else { + m.currentUpgradeCancel.Store(&cancel) + } +} + +func (m *Manager) setCurrentCheckCancel(cancel context.CancelFunc) { + if cancel == nil { + m.currentCheckCancel.Store(nil) + } else { + m.currentCheckCancel.Store(&cancel) + } +} + // Subscribe creates a new channel for receiving APT events. func (b *Manager) Subscribe() chan Event { eventCh := make(chan Event, 100)