From 58d268b61dce07709eba8e36329cac1811b21569 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:40:00 +0700 Subject: [PATCH 01/13] [dev] update dev-tools configs --- .gitignore | 3 +- .golangci.yml | 483 ++++++++++++++++++++++++++++++++++++++++++ .goreleaser.yaml | 67 ++++++ Dockerfile.goreleaser | 27 +++ Makefile | 44 ++-- 5 files changed, 604 insertions(+), 20 deletions(-) create mode 100644 .golangci.yml create mode 100644 .goreleaser.yaml create mode 100644 Dockerfile.goreleaser diff --git a/.gitignore b/.gitignore index 8d846b72..e11852ce 100644 --- a/.gitignore +++ b/.gitignore @@ -50,8 +50,7 @@ go.work .LSOverride # Icon must end with two \r -Icon - +Icon # Thumbnails ._* diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..d3c3aee3 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,483 @@ +# This file is licensed under the terms of the MIT license https://opensource.org/license/mit +# Copyright (c) 2021-2025 Marat Reymers + +## Golden config for golangci-lint v2.5.0 +# +# This is the best config for golangci-lint based on my experience and opinion. +# It is very strict, but not extremely strict. +# Feel free to adapt it to suit your needs. +# If this config helps you, please consider keeping a link to this file (see the next comment). + +# Based on https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322 + +version: "2" + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 50 + +formatters: + enable: + - goimports # checks if the code and import statements are formatted according to the 'goimports' command + - golines # checks if code is formatted, and fixes long lines + - swaggo # formats swaggo comments + + ## you may want to enable + #- gci # checks if code and import statements are formatted, with additional rules + #- gofmt # checks if the code is formatted according to 'gofmt' command + #- gofumpt # enforces a stricter format than 'gofmt', while being backwards compatible + + # All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml + settings: + golines: + # Target maximum line length. + # Default: 100 + max-len: 120 + +linters: + enable: + - asasalint # checks for pass []any as any in variadic func(...any) + - asciicheck # checks that your code does not contain non-ASCII identifiers + - bidichk # checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - canonicalheader # checks whether net/http.Header uses canonical header + - copyloopvar # detects places where loop variables are copied (Go 1.22+) + - cyclop # checks function and package cyclomatic complexity + - depguard # checks if package imports are in a list of acceptable packages + - dupl # tool for code clone detection + - durationcheck # checks for two durations multiplied together + - embeddedstructfieldcheck # checks embedded types in structs + - err113 # [too strict] checks the errors handling expressions + - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases + - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error + - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 + - exhaustive # checks exhaustiveness of enum switch statements + - exhaustruct # [highly recommend to enable] checks if all structure fields are initialized + - exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions + - fatcontext # detects nested contexts in loops + - forbidigo # forbids identifiers + - funcorder # checks the order of functions, methods, and constructors + - funlen # tool for detection of long functions + - gocheckcompilerdirectives # validates go compiler directive comments (//go:) + - gochecknoglobals # checks that no global variables exist + - gochecknoinits # checks that no init functions are present in Go code + - gochecksumtype # checks exhaustiveness on Go "sum types" + - gocognit # computes and checks the cognitive complexity of functions + - goconst # finds repeated strings that could be replaced by a constant + - gocritic # provides diagnostics that check for bugs, performance and style issues + - gocyclo # computes and checks the cyclomatic complexity of functions + - godoclint # checks Golang's documentation practice + - godot # checks if comments end in a period + - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod + - goprintffuncname # checks that printf-like functions are named with f at the end + - gosec # inspects source code for security problems + - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution + - ineffassign # detects when assignments to existing variables are not used + - intrange # finds places where for loops could make use of an integer range + - iotamixing # checks if iotas are being used in const blocks with other non-iota declarations + - loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap) + - makezero # finds slice declarations with non-zero initial length + - mirror # reports wrong mirror patterns of bytes/strings usage + - mnd # detects magic numbers + - musttag # enforces field tags in (un)marshaled structs + - nakedret # finds naked returns in functions greater than a specified function length + - nestif # reports deeply nested if statements + - nilerr # finds the code that returns nil even if it checks that the error is not nil + - nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr) + - nilnil # checks that there is no simultaneous return of nil error and an invalid value + - noctx # finds sending http request without context.Context + - nolintlint # reports ill-formed or insufficient nolint directives + - nonamedreturns # reports all named returns + - nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL + - perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative + - predeclared # finds code that shadows one of Go's predeclared identifiers + - promlinter # checks Prometheus metrics naming via promlint + - protogetter # reports direct reads from proto message fields when getters should be used + - reassign # checks that package variables are not reassigned + - recvcheck # checks for receiver type consistency + - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint + - rowserrcheck # checks whether Err of rows is checked successfully + - sloglint # ensure consistent code style when using log/slog + - spancheck # checks for mistakes with OpenTelemetry/Census spans + - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed + - staticcheck # is a go vet on steroids, applying a ton of static analysis checks + - testableexamples # checks if examples are testable (have an expected output) + - testifylint # checks usage of github.com/stretchr/testify + - testpackage # makes you use a separate _test package + - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # removes unnecessary type conversions + - unparam # reports unused function parameters + - unqueryvet # detects SELECT * in SQL queries and SQL builders, encouraging explicit column selection + - unused # checks for unused constants, variables, functions and types + - usestdlibvars # detects the possibility to use variables/constants from the Go standard library + - usetesting # reports uses of functions with replacement inside the testing package + - wastedassign # finds wasted assignment statements + - whitespace # detects leading and trailing whitespace + - wrapcheck # checks that errors returned from external packages are wrapped + + ## you may want to enable + #- arangolint # opinionated best practices for arangodb client + #- decorder # checks declaration order and count of types, constants, variables and functions + #- ginkgolinter # [if you use ginkgo/gomega] enforces standards of using ginkgo and gomega + #- godox # detects usage of FIXME, TODO and other keywords inside comments + #- goheader # checks is file header matches to pattern + #- inamedparam # [great idea, but too strict, need to ignore a lot of cases by default] reports interfaces with unnamed method parameters + #- interfacebloat # checks the number of methods inside an interface + #- ireturn # accept interfaces, return concrete types + #- noinlineerr # disallows inline error handling `if err := ...; err != nil {` + #- prealloc # [premature optimization, but can be used in some cases] finds slice declarations that could potentially be preallocated + #- tagalign # checks that struct tags are well aligned + #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope + #- zerologlint # detects the wrong usage of zerolog that a user forgets to dispatch zerolog.Event + + ## disabled + #- containedctx # detects struct contained context.Context field + #- contextcheck # [too many false positives] checks the function whether use a non-inherited context + #- dogsled # checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + #- dupword # [useless without config] checks for duplicate words in the source code + #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted + #- forcetypeassert # [replaced by errcheck] finds forced type assertions + #- gomodguard # [use more powerful depguard] allow and block lists linter for direct Go module dependencies + #- gosmopolitan # reports certain i18n/l10n anti-patterns in your Go codebase + #- grouper # analyzes expression groups + #- importas # enforces consistent import aliases + #- lll # [replaced by golines] reports long lines + #- maintidx # measures the maintainability index of each function + #- misspell # [useless] finds commonly misspelled English words in comments + #- nlreturn # [too strict and mostly code is not more readable] checks for a new line before return and branch statements to increase code clarity + #- paralleltest # [too many false positives] detects missing usage of t.Parallel() method in your Go test + #- tagliatelle # checks the struct tags + #- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers + #- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines + #- wsl_v5 # [too strict and mostly code is not more readable] add or remove empty lines + + # All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml + settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled. + # Default: 0.0 + package-average: 10.0 + + depguard: + # Rules to apply. + # + # Variables: + # - File Variables + # Use an exclamation mark `!` to negate a variable. + # Example: `!$test` matches any file that is not a go test file. + # + # `$all` - matches all go files + # `$test` - matches all go test files + # + # - Package Variables + # + # `$gostd` - matches all of go's standard library (Pulled from `GOROOT`) + # + # Default (applies if no custom rules are defined): Only allow $gostd in all files. + rules: + "deprecated": + # List of file globs that will match this list of settings to compare against. + # By default, if a path is relative, it is relative to the directory where the golangci-lint command is executed. + # The placeholder '${base-path}' is substituted with a path relative to the mode defined with `run.relative-path-mode`. + # The placeholder '${config-path}' is substituted with a path relative to the configuration file. + # Default: $all + files: + - "$all" + # List of packages that are not allowed. + # Entries can be a variable (starting with $), a string prefix, or an exact match (if ending with $). + # Default: [] + deny: + - pkg: github.com/golang/protobuf + desc: Use google.golang.org/protobuf instead, see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - pkg: github.com/satori/go.uuid + desc: Use github.com/google/uuid instead, satori's package is not maintained + - pkg: github.com/gofrs/uuid$ + desc: Use github.com/gofrs/uuid/v5 or later, it was not a go module before v5 + "non-test files": + files: + - "!$test" + deny: + - pkg: math/rand$ + desc: Use math/rand/v2 instead, see https://go.dev/blog/randv2 + "non-main files": + files: + - "!**/main.go" + deny: + - pkg: log$ + desc: Use log/slog instead, see https://go.dev/blog/slog + + embeddedstructfieldcheck: + # Checks that sync.Mutex and sync.RWMutex are not used as embedded fields. + # Default: false + forbid-mutex: true + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + exhaustive: + # Program elements to check for exhaustiveness. + # Default: [ switch ] + check: + - switch + - map + + exhaustruct: + # List of regular expressions to match type names that should be excluded from processing. + # Anonymous structs can be matched by '' alias. + # Has precedence over `include`. + # Each regular expression must match the full type name, including package path. + # For example, to match type `net/http.Cookie` regular expression should be `.*/http\.Cookie`, + # but not `http\.Cookie`. + # Default: [] + exclude: + # std libs + - ^net.ListenConfig$ + - ^net/http.Client$ + - ^net/http.Cookie$ + - ^net/http.Request$ + - ^net/http.Response$ + - ^net/http.Server$ + - ^net/http.Transport$ + - ^net/url.URL$ + - ^os/exec.Cmd$ + - ^reflect.StructField$ + # public libs + - "^github.com/gofiber/.+Config$" + - "^gopkg.in/telebot.v4.LongPoller$" + - "^gopkg.in/telebot.v4.ReplyMarkup$" + - "^gopkg.in/telebot.v4.Settings$" + - ^github.com/aws/aws-sdk-go-v2/service/s3.+Input$ + - ^github.com/aws/aws-sdk-go-v2/service/s3/types.ObjectIdentifier$ + - ^github.com/mitchellh/mapstructure.DecoderConfig$ + - ^github.com/prometheus/client_golang/.+Opts$ + - ^github.com/secsy/goftp.Config$ + - ^github.com/Shopify/sarama.Config$ + - ^github.com/Shopify/sarama.ProducerMessage$ + - ^github.com/spf13/cobra.Command$ + - ^github.com/spf13/cobra.CompletionOptions$ + - ^github.com/stretchr/testify/mock.Mock$ + - ^github.com/testcontainers/testcontainers-go.+Request$ + - ^github.com/testcontainers/testcontainers-go.FromDockerfile$ + - ^github.com/urfave/cli.v3.ArgumentBase$ + - ^github.com/urfave/cli.v3.Command$ + - ^github.com/urfave/cli.v3.FlagBase$ + - ^golang.org/x/tools/go/analysis.Analyzer$ + - ^google.golang.org/protobuf/.+Options$ + - ^gopkg.in/yaml.v3.Node$ + - ^gorm.io/gorm/clause.+$ + - ^firebase.google.com/go/v4/messaging.Message$ + - ^firebase.google.com/go/v4/messaging.AndroidConfig$ + # Allows empty structures in return statements. + # Default: false + allow-empty-returns: true + + funcorder: + # Checks if the exported methods of a structure are placed before the non-exported ones. + # Default: true + struct-method: false + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + + gochecksumtype: + # Presence of `default` case in switch statements satisfies exhaustiveness, if all members are not listed. + # Default: true + default-signifies-exhaustive: false + + gocognit: + # Minimal code complexity to report. + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be found at https://go-critic.com/overview. + settings: + captLocal: + # Whether to restrict checker to params only. + # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + godoclint: + # List of rules to enable in addition to the default set. + # Default: empty + enable: + # Assert no unused link in godocs. + # https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link + - no-unused-link + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `GL_DEBUG=govet golangci-lint run --enable=govet` to see default, all available analyzers, and enabled analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + inamedparam: + # Skips check for interface methods with only a single parameter. + # Default: false + skip-single-param: true + + mnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date`, + # `strconv.FormatInt`, `strconv.FormatUint`, `strconv.FormatFloat`, + # `strconv.ParseInt`, `strconv.ParseUint`, `strconv.ParseFloat`. + # Default: [] + ignored-functions: + - args.Error + - flag.Arg + - flag.Duration.* + - flag.Float.* + - flag.Int.* + - flag.Uint.* + - os.Chmod + - os.Mkdir.* + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets.* + - prometheus.LinearBuckets + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [funlen, gocognit, golines] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + perfsprint: + # Optimizes into strings concatenation. + # Default: true + strconcat: false + + reassign: + # Patterns for global variable names that are checked for reassignment. + # See https://github.com/curioswitch/go-reassign#usage + # Default: ["EOF", "Err.*"] + patterns: + - ".*" + + rowserrcheck: + # database/sql is always checked. + # Default: [] + packages: + - github.com/jmoiron/sqlx + + sloglint: + # Enforce not using global loggers. + # Values: + # - "": disabled + # - "all": report all global loggers + # - "default": report only the default slog logger + # https://github.com/go-simpler/sloglint?tab=readme-ov-file#no-global + # Default: "" + no-global: all + # Enforce using methods that accept a context. + # Values: + # - "": disabled + # - "all": report all contextless calls + # - "scope": report only if a context exists in the scope of the outermost function + # https://github.com/go-simpler/sloglint?tab=readme-ov-file#context-only + # Default: "" + context: scope + + staticcheck: + # SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks + # Example (to disable some checks): [ "all", "-SA1000", "-SA1001"] + # Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"] + checks: + - all + # Incorrect or missing package comment. + # https://staticcheck.dev/docs/checks/#ST1000 + - -ST1000 + # Use consistent method receiver names. + # https://staticcheck.dev/docs/checks/#ST1016 + - -ST1016 + # Omit embedded fields from selector expression. + # https://staticcheck.dev/docs/checks/#QF1008 + - -QF1008 + + usetesting: + # Enable/disable `os.TempDir()` detections. + # Default: false + os-temp-dir: true + + wrapcheck: + extra-ignore-sigs: + - .JSON( + - .SendStatus( + + exclusions: + generated: lax + # Predefined exclusion rules. + # Default: [] + presets: + - std-error-handling + - common-false-positives + # Excluding configuration per-path, per-linter, per-text and per-source. + rules: + - source: "TODO" + linters: [godot] + - text: "should have a package comment" + linters: [revive] + - text: 'exported \S+ \S+ should have comment( \(or a comment on this block\))? or be unexported' + linters: [revive] + - text: 'package comment should be of the form ".+"' + source: "// ?(nolint|TODO)" + linters: [revive] + - text: 'comment on exported \S+ \S+ should be of the form ".+"' + source: "// ?(nolint|TODO)" + linters: [revive, staticcheck] + - path: '_test\.go' + linters: + - bodyclose + - dupl + - err113 + - errcheck + - exhaustruct + - funlen + - gocognit + - goconst + - gosec + - noctx + - wrapcheck diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 00000000..a66312eb --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,67 @@ +version: 2 +project_name: server + +before: + hooks: + - go install github.com/swaggo/swag/cmd/swag@latest + - go mod tidy + - go generate ./... + +builds: + - main: ./cmd/sms-gateway + env: + - CGO_ENABLED=0 + goos: + - linux + - windows + - darwin + ldflags: + - -s -w + - -X github.com/android-sms-gateway/server/internal/version.AppVersion={{ .Version }} + - -X github.com/android-sms-gateway/server/internal/version.AppRelease={{ .Env.RELEASE_ID }} + +archives: + - formats: ["tar.gz"] + # this name template makes the OS and Arch compatible with the results of `uname`. + name_template: >- + {{ .ProjectName }}_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + # use zip for windows archives + format_overrides: + - goos: windows + formats: ["zip"] + +dockers_v2: + - dockerfile: Dockerfile.goreleaser + extra_files: + - scripts/docker-entrypoint.sh + images: + - "{{ .Env.DOCKER_REGISTRY }}/{{ .ProjectName }}" + tags: + - "{{ .Tag }}" + - "v{{ .Major }}" + - "v{{ .Major }}.{{ .Minor }}" + - "latest" + labels: + "org.opencontainers.image.created": "{{ .Date }}" + "org.opencontainers.image.title": "{{ .ProjectName }}" + "org.opencontainers.image.revision": "{{ .FullCommit }}" + "org.opencontainers.image.version": "{{ .Version }}" + "org.opencontainers.image.name": "{{ .Env.DOCKER_REGISTRY }}/{{ .ProjectName }}" + "org.opencontainers.image.source": "{{ .GitURL }}" + platforms: + - linux/amd64 + - linux/arm64 + build_args: + BINARY_NAME: "{{ .ProjectName }}" + +changelog: + sort: asc + filters: + exclude: + - "^docs:" + - "^test:" diff --git a/Dockerfile.goreleaser b/Dockerfile.goreleaser new file mode 100644 index 00000000..a8c125bb --- /dev/null +++ b/Dockerfile.goreleaser @@ -0,0 +1,27 @@ +FROM alpine:3.22 + +ARG TARGETPLATFORM +ARG BINARY_NAME + +# Install certificates and timezone data +RUN apk add --no-cache ca-certificates tzdata curl + +# Set the Current Working Directory inside the container +WORKDIR /app + +# Copy the Pre-built binary file from GoReleaser +COPY scripts/docker-entrypoint.sh /docker-entrypoint.sh +COPY $TARGETPLATFORM/$BINARY_NAME ./app + +# Command to run the executable +EXPOSE 3000 + +USER guest + +ENTRYPOINT ["/docker-entrypoint.sh"] + +HEALTHCHECK --interval=30s --timeout=5s --start-period=15s --retries=3 \ + CMD curl -fs http://localhost:3000/health/live + +CMD [ "/app/app" ] + diff --git a/Makefile b/Makefile index 9f954ad6..4d36dbf4 100644 --- a/Makefile +++ b/Makefile @@ -1,37 +1,47 @@ project_name = sms-gateway -image_name = capcom6/$(project_name):latest +registry_name = ghcr.io/android-sms-gateway +image_name = ghcr.io/android-sms-gateway/server:latest extension= ifeq ($(OS),Windows_NT) extension = .exe endif -# Default target -all: fmt lint test benchmark +.PHONY: \ + all fmt lint test coverage benchmark deps release clean help \ + init init-dev ngrok air db-upgrade db-upgrade-raw run test-e2e build install \ + docker-build docker docker-dev docker-clean -fmt: +all: fmt lint test benchmark ## Run all tests and checks + +fmt: ## Format the code golangci-lint fmt -# Lint the code using golangci-lint -lint: +lint: ## Lint the code golangci-lint run --timeout=5m -# Run tests with coverage -test: +test: ## Run tests go test -race -shuffle=on -count=1 -covermode=atomic -coverpkg=./... -coverprofile=coverage.out ./... -# Run benchmarks -benchmark: +coverage: test ## Generate coverage + go tool cover -func=coverage.out + go tool cover -html=coverage.out -o coverage.html + +benchmark: ## Run benchmarks go test -run=^$$ -bench=. -benchmem ./... | tee benchmark.txt -# Download dependencies -deps: +deps: ## Install dependencies go mod download -# Clean up generated files -clean: - go clean -cache -testcache - rm -f coverage.out benchmark.txt +release: ## Create release + DOCKER_REGISTRY=$(registry_name) RELEASE_ID=0 goreleaser release --snapshot --clean + +clean: ## Remove build artifacts + rm -f coverage.* benchmark.txt + rm -rf dist + +help: ## Show this help + @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST) ### @@ -77,5 +87,3 @@ docker-dev: docker-clean: docker compose -f deployments/docker-compose/docker-compose.yml down --volumes - -.PHONY: all fmt lint test benchmark deps clean init init-dev air ngrok db-upgrade db-upgrade-raw run test-e2e build install docker-build docker docker-dev docker-clean From 5e17b225a756be5561b6dfb7fb09e0fd3c66b295 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:40:26 +0700 Subject: [PATCH 02/13] [actions] update actions, introducing GoReleaser builds --- ...{close-issues.yml => close-issues-prs.yml} | 0 .github/workflows/docker-build.yml | 158 ------------------ .github/workflows/docker-publish.yml | 53 ------ .github/workflows/go.yml | 29 +++- .github/workflows/pr.yml | 151 +++++++++++++++++ .github/workflows/release.yml | 76 ++++++--- 6 files changed, 221 insertions(+), 246 deletions(-) rename .github/workflows/{close-issues.yml => close-issues-prs.yml} (100%) delete mode 100644 .github/workflows/docker-build.yml delete mode 100644 .github/workflows/docker-publish.yml create mode 100644 .github/workflows/pr.yml diff --git a/.github/workflows/close-issues.yml b/.github/workflows/close-issues-prs.yml similarity index 100% rename from .github/workflows/close-issues.yml rename to .github/workflows/close-issues-prs.yml diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml deleted file mode 100644 index 5abb02e5..00000000 --- a/.github/workflows/docker-build.yml +++ /dev/null @@ -1,158 +0,0 @@ -name: docker-build - -on: - workflow_call: - inputs: - app-name: - required: true - type: string - secrets: - username: - required: true - password: - required: true - outputs: - app-version: - value: ${{ jobs.merge.outputs.app-version }} - -env: - DOCKERHUB_REPO: capcom6/${{ inputs.app-name }} - GHCR_REPO: ghcr.io/${{ github.repository }} - -jobs: - build: - name: Docker image - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - platform: - - linux/amd64 - - linux/arm64 - permissions: - packages: write - - steps: - - name: Prepare - run: | - platform=${{ matrix.platform }} - echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - - - name: Docker meta - uses: docker/metadata-action@v5 - with: - images: | - ${{ env.DOCKERHUB_REPO }} - ${{ env.GHCR_REPO }} - - - name: Log into Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.username }} - password: ${{ secrets.password }} - - - name: Login to Container registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set APP_VERSION env - run: echo APP_VERSION=$(echo ${GITHUB_REF} | rev | cut -d'/' -f 1 | rev ) >> ${GITHUB_ENV} - - name: Set APP_RELEASE env - run: echo APP_RELEASE=$(( ($(date +%s) - $(date -d "2022-06-15" +%s)) / 86400 )) >> ${GITHUB_ENV} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Build and push Docker image - id: build - uses: docker/build-push-action@v6 - with: - file: build/package/Dockerfile - platforms: ${{ matrix.platform }} - build-args: | - APP=${{ inputs.app-name }} - APP_VERSION=${{ env.APP_VERSION }} - APP_RELEASE_ID=${{ env.APP_RELEASE }} - labels: ${{ steps.meta.outputs.labels }} - outputs: type=image,"name=${{ env.DOCKERHUB_REPO }},${{ env.GHCR_REPO }}",push-by-digest=true,name-canonical=true,push=true - - - name: Export digest - run: | - mkdir -p ${{ runner.temp }}/digests - digest="${{ steps.build.outputs.digest }}" - touch "${{ runner.temp }}/digests/${digest#sha256:}" - - - name: Upload digest - uses: actions/upload-artifact@v4 - with: - name: digests-${{ env.PLATFORM_PAIR }} - path: ${{ runner.temp }}/digests/* - if-no-files-found: error - retention-days: 1 - - merge: - runs-on: ubuntu-latest - permissions: - contents: read - packages: write - needs: - - build - outputs: - app-version: ${{ steps.meta.outputs.version }} - - steps: - - name: Download digests - uses: actions/download-artifact@v4 - with: - path: ${{ runner.temp }}/digests - pattern: digests-* - merge-multiple: true - - - name: Log into Docker Hub - uses: docker/login-action@v3 - with: - username: ${{ secrets.username }} - password: ${{ secrets.password }} - - - name: Login to Container registry - uses: docker/login-action@v3 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - - name: Docker meta - id: meta - uses: docker/metadata-action@v5 - with: - images: | - ${{ env.DOCKERHUB_REPO }} - ${{ env.GHCR_REPO }} - tags: | - type=ref,event=branch - type=ref,event=pr - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} - - - name: Create manifest list and push - working-directory: ${{ runner.temp }}/digests - run: | - docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - $(printf '${{ env.DOCKERHUB_REPO }}@sha256:%s ' *) - docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ - $(printf '${{ env.GHCR_REPO }}@sha256:%s ' *) - - - name: Inspect image - run: | - docker buildx imagetools inspect ${{ env.DOCKERHUB_REPO }}:${{ steps.meta.outputs.version }} - docker buildx imagetools inspect ${{ env.GHCR_REPO }}:${{ steps.meta.outputs.version }} diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml deleted file mode 100644 index 93806824..00000000 --- a/.github/workflows/docker-publish.yml +++ /dev/null @@ -1,53 +0,0 @@ -name: docker-publish - -on: - pull_request: - -permissions: - contents: read - -jobs: - e2e: - name: E2E - runs-on: ubuntu-latest - steps: - # step 1: checkout repository code - - name: Checkout code into workspace directory - uses: actions/checkout@v4 - - # step 2: set up go - - name: Set up Go - uses: actions/setup-go@v5 - with: - go-version: stable - cache-dependency-path: test/e2e/go.sum - - # step 3: start services - - name: Start services - env: - "FCM__CREDENTIALS_JSON": ${{ secrets.FCM__CREDENTIALS_JSON }} - run: docker compose -f test/e2e/docker-compose.yml up -d --build - - # step 4: run test - - name: Run e2e tests - run: cd test/e2e && go test -count=1 . - - # step 5: stop services - - name: Stop services - run: docker compose -f test/e2e/docker-compose.yml down -v - continue-on-error: true - - build: - name: Build - permissions: - contents: read - packages: write - needs: - - e2e - if: github.actor != 'dependabot[bot]' - uses: ./.github/workflows/docker-build.yml - with: - app-name: sms-gateway - secrets: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 18fb253c..48824201 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -64,7 +64,7 @@ jobs: # step 4: run test - name: Run coverage - run: go test -race -shuffle=on -covermode=atomic -coverpkg=./... -coverprofile=coverage.out ./... + run: go test -race -shuffle=on -count=1 -covermode=atomic -coverpkg=./... -coverprofile=coverage.out ./... # step 5: upload coverage - name: Upload coverage to Codecov @@ -93,16 +93,19 @@ jobs: - name: Install all Go dependencies run: go mod download - # step 4: run benchmark - - name: Run benchmarks - run: go test -bench=. -benchmem ./... | tee benchmark.txt - - # step 5: download previous benchmark result from cache (if exists) - - name: Download previous benchmark data - uses: actions/cache@v4 + # step 4: restore previous benchmark history (if exists) + - name: Restore benchmark history + id: benchmark-cache + uses: actions/cache/restore@v4 with: path: ./cache - key: ${{ runner.os }}-benchmark + key: ${{ runner.os }}-benchmark-${{ github.ref_name }} + restore-keys: | + ${{ runner.os }}-benchmark- + + # step 5: run benchmark + - name: Run benchmarks + run: go test -run=^$ -bench=. -benchmem ./... | tee benchmark.txt # step 6: upload benchmark - name: Upload benchmark results @@ -116,3 +119,11 @@ jobs: external-data-json-path: ./cache/benchmark-data.json # Workflow will fail when an alert happens fail-on-alert: true + + # step 7: persist updated benchmark history + - name: Save benchmark history + if: always() + uses: actions/cache/save@v4 + with: + path: ./cache + key: ${{ runner.os }}-benchmark-${{ github.ref_name }}-${{ github.run_id }} diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 00000000..3e1637ef --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,151 @@ +name: PR + +on: + pull_request: + branches: [master] + +concurrency: + group: pr-${{ github.event.pull_request.number }} + cancel-in-progress: true + +jobs: + e2e: + name: E2E + runs-on: ubuntu-latest + permissions: + contents: read + steps: + # step 1: checkout repository code + - name: Checkout code into workspace directory + uses: actions/checkout@v4 + + # step 2: set up go + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: stable + cache-dependency-path: test/e2e/go.sum + + # step 3: start services + - name: Start services + env: + "FCM__CREDENTIALS_JSON": ${{ secrets.FCM__CREDENTIALS_JSON }} + run: docker compose -f test/e2e/docker-compose.yml up -d --build + + # step 4: run test + - name: Run e2e tests + run: cd test/e2e && go test -count=1 . + + # step 5: stop services + - name: Stop services + run: docker compose -f test/e2e/docker-compose.yml down -v + continue-on-error: true + + goreleaser: + runs-on: ubuntu-latest + needs: e2e + permissions: + contents: read + pull-requests: write + packages: write + steps: + - name: Prepare + run: | + repository=${{ github.repository }} + echo "PROJECT_NAME=${repository#*/}" >> $GITHUB_ENV + + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Login to GitHub Container registry + if: github.actor != 'dependabot[bot]' + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - uses: actions/setup-go@v6 + with: + go-version: stable + + # RELEASE_ID: Days since project inception (2022-06-15) + - name: Set RELEASE_ID env + run: echo RELEASE_ID=$(( ($(date +%s) - $(date -d "2022-06-15" +%s)) / 86400 )) >> ${GITHUB_ENV} + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DOCKER_REGISTRY: ${{ vars.DOCKER_REGISTRY }} + with: + distribution: goreleaser + version: "~> v2" + args: release --snapshot + + - name: Upload to S3 + uses: capcom6/upload-s3-action@master + env: + AWS_REGION: ${{ secrets.AWS_REGION }} + with: + aws_key_id: ${{ secrets.AWS_KEY_ID }} + aws_secret_access_key: ${{ secrets.AWS_SECRET_ACCESS_KEY}} + aws_bucket: ${{ secrets.AWS_BUCKET }} + endpoint: ${{ secrets.AWS_ENDPOINT }} + source_files: | + dist/${{ env.PROJECT_NAME }}_*.zip + dist/${{ env.PROJECT_NAME }}_*.tar.gz + destination_dir: ${{ github.repository }}/${{ github.event.pull_request.head.sha }} + + - name: Push images + if: github.actor != 'dependabot[bot]' + env: + DOCKER_REGISTRY: ${{ vars.DOCKER_REGISTRY }} + PROJECT_NAME: ${{ env.PROJECT_NAME }} + run: | + docker tag ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:latest-arm64 ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-arm64 + docker tag ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:latest-amd64 ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-amd64 + docker push ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-arm64 + docker push ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-amd64 + docker manifest create ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }} \ + ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-amd64 \ + ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}-arm64 + docker manifest push ${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }} + + - name: Find Comment + uses: peter-evans/find-comment@v3 + id: fc + with: + issue-number: ${{ github.event.pull_request.number }} + comment-author: "github-actions[bot]" + body-includes: Pull request artifacts + + - name: Create or update comment + uses: peter-evans/create-or-update-comment@v4 + env: + DOCKER_REGISTRY: ${{ vars.DOCKER_REGISTRY }} + PROJECT_NAME: ${{ env.PROJECT_NAME }} + with: + comment-id: ${{ steps.fc.outputs.comment-id }} + issue-number: ${{ github.event.pull_request.number }} + body: | + ## πŸ€– Pull request artifacts + + | Platform | File | + | ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | + | 🐳 Docker | [GitHub Container Registry](https://${{ env.DOCKER_REGISTRY }}/${{ env.PROJECT_NAME }}:pr-${{ github.event.pull_request.number }}) | + | 🍎 Darwin arm64 | [${{ env.PROJECT_NAME }}_Darwin_arm64.tar.gz](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Darwin_arm64.tar.gz) | + | 🍎 Darwin x86_64 | [${{ env.PROJECT_NAME }}_Darwin_x86_64.tar.gz](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Darwin_x86_64.tar.gz) | + | 🐧 Linux arm64 | [${{ env.PROJECT_NAME }}_Linux_arm64.tar.gz](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Linux_arm64.tar.gz) | + | 🐧 Linux i386 | [${{ env.PROJECT_NAME }}_Linux_i386.tar.gz](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Linux_i386.tar.gz) | + | 🐧 Linux x86_64 | [${{ env.PROJECT_NAME }}_Linux_x86_64.tar.gz](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Linux_x86_64.tar.gz) | + | πŸͺŸ Windows arm64 | [${{ env.PROJECT_NAME }}_Windows_arm64.zip](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Windows_arm64.zip) | + | πŸͺŸ Windows i386 | [${{ env.PROJECT_NAME }}_Windows_i386.zip](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Windows_i386.zip) | + | πŸͺŸ Windows x86_64 | [${{ env.PROJECT_NAME }}_Windows_x86_64.zip](https://s3.sms-gate.app/${{ github.repository }}/${{ github.event.pull_request.head.sha }}/${{ env.PROJECT_NAME }}_Windows_x86_64.zip) | + + edit-mode: replace diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 533df959..87131605 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,29 +1,51 @@ -# workflow name -name: release +name: Release -# on events on: - release: - types: - - created + push: + tags: + - "v*" permissions: - contents: read + contents: write + packages: write -# jobs jobs: - build: - name: Build - permissions: - contents: read - packages: write - uses: ./.github/workflows/docker-build.yml - with: - app-name: sms-gateway - secrets: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_PASSWORD }} + goreleaser: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v5 + with: + fetch-depth: 0 + + - name: Login to GitHub Container registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - uses: actions/setup-go@v6 + with: + go-version: stable + + - name: Set RELEASE_ID env + run: echo RELEASE_ID=$(( ($(date +%s) - $(date -d "2022-06-15" +%s)) / 86400 )) >> ${GITHUB_ENV} + + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v6 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DOCKER_REGISTRY: ${{ vars.DOCKER_REGISTRY }} + with: + distribution: goreleaser + version: "~> v2" + args: release --clean + + ### deploy: runs-on: ubuntu-latest permissions: @@ -31,21 +53,23 @@ jobs: deployments: write environment: production concurrency: production - needs: - - build + needs: goreleaser env: AWS_ACCESS_KEY_ID: ${{secrets.AWS_ACCESS_KEY_ID}} AWS_SECRET_ACCESS_KEY: ${{secrets.AWS_SECRET_ACCESS_KEY}} steps: + - name: Set APP_VERSION env + run: echo APP_VERSION=${GITHUB_REF#refs/tags/} >> ${GITHUB_ENV} + - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v5 - name: Install Terraform - uses: hashicorp/setup-terraform@v2 + uses: hashicorp/setup-terraform@v3 with: - terraform_version: 1.4.6 + terraform_version: 1.13.5 - name: Initialize Terraform working-directory: deployments/docker-swarm-terraform @@ -61,18 +85,18 @@ jobs: terraform apply -auto-approve -input=false \ -var 'swarm-manager-host=${{ secrets.SWARM_MANAGER_HOST }}' \ -var 'app-name=${{ vars.APP_NAME }}' \ - -var "app-version=${{ needs.build.outputs.app-version }}" \ + -var "app-version=${{ env.APP_VERSION }}" \ -var 'app-host=${{ secrets.APP_HOST }}' \ -var "app-config-b64=${{ secrets.APP_CONFIG_B64 }}" \ -var "app-env-json-b64=${{ secrets.APP_ENV_JSON_B64 }}" \ -var "memory-limit=${{ vars.MEMORY_LIMIT }}" deploy-secondary: - needs: build runs-on: ubuntu-latest permissions: contents: read deployments: write + needs: goreleaser environment: production-secondary concurrency: production env: From 00c3a5e244e6d0aec90afd5442abdec08aee16a4 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:41:13 +0700 Subject: [PATCH 03/13] [cache] fix lint issues --- pkg/cache/errors.go | 2 + pkg/cache/memory.go | 50 ++++++++--------- pkg/cache/memory_bench_test.go | 47 ++++++++-------- pkg/cache/memory_concurrency_test.go | 52 +++++++++--------- pkg/cache/memory_edge_test.go | 38 +++++++------ pkg/cache/memory_profile_test.go | 7 ++- pkg/cache/memory_test.go | 19 +++---- pkg/cache/options.go | 13 ++--- pkg/cache/redis.go | 81 ++++++++++++++-------------- 9 files changed, 158 insertions(+), 151 deletions(-) diff --git a/pkg/cache/errors.go b/pkg/cache/errors.go index 4d5568fe..0ca10cf5 100644 --- a/pkg/cache/errors.go +++ b/pkg/cache/errors.go @@ -3,6 +3,8 @@ package cache import "errors" var ( + // ErrInvalidConfig indicates an invalid configuration. + ErrInvalidConfig = errors.New("invalid config") // ErrKeyNotFound indicates no value exists for the given key. ErrKeyNotFound = errors.New("key not found") // ErrKeyExpired indicates a value exists but has expired. diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index b4d8d7cc..5e2c51d6 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -6,15 +6,15 @@ import ( "time" ) -type memoryCache struct { +type MemoryCache struct { items map[string]*memoryItem ttl time.Duration mux sync.RWMutex } -func NewMemory(ttl time.Duration) *memoryCache { - return &memoryCache{ +func NewMemory(ttl time.Duration) *MemoryCache { + return &MemoryCache{ items: make(map[string]*memoryItem), ttl: ttl, @@ -45,14 +45,14 @@ func (i *memoryItem) isExpired(now time.Time) bool { } // Cleanup implements Cache. -func (m *memoryCache) Cleanup(_ context.Context) error { +func (m *MemoryCache) Cleanup(_ context.Context) error { m.cleanup(func() {}) return nil } // Delete implements Cache. -func (m *memoryCache) Delete(_ context.Context, key string) error { +func (m *MemoryCache) Delete(_ context.Context, key string) error { m.mux.Lock() delete(m.items, key) m.mux.Unlock() @@ -61,7 +61,7 @@ func (m *memoryCache) Delete(_ context.Context, key string) error { } // Drain implements Cache. -func (m *memoryCache) Drain(_ context.Context) (map[string][]byte, error) { +func (m *MemoryCache) Drain(_ context.Context) (map[string][]byte, error) { var cpy map[string]*memoryItem m.cleanup(func() { @@ -78,7 +78,7 @@ func (m *memoryCache) Drain(_ context.Context) (map[string][]byte, error) { } // Get implements Cache. -func (m *memoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]byte, error) { +func (m *MemoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]byte, error) { return m.getValue(func() (*memoryItem, bool) { if len(opts) == 0 { m.mux.RLock() @@ -88,23 +88,23 @@ func (m *memoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]b return item, ok } - o := getOptions{} + o := new(getOptions) o.apply(opts...) m.mux.Lock() item, ok := m.items[key] - if !ok { - // item not found, nothing to do - } else if o.delete { + + if ok && o.delete { delete(m.items, key) - } else if !item.isExpired(time.Now()) { - if o.validUntil != nil { + } else if ok && !item.isExpired(time.Now()) { + switch { + case o.validUntil != nil: item.validUntil = *o.validUntil - } else if o.setTTL != nil { + case o.setTTL != nil: item.validUntil = time.Now().Add(*o.setTTL) - } else if o.updateTTL != nil { + case o.updateTTL != nil: item.validUntil = item.validUntil.Add(*o.updateTTL) - } else if o.defaultTTL { + case o.defaultTTL: item.validUntil = time.Now().Add(m.ttl) } } @@ -115,12 +115,12 @@ func (m *memoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]b } // GetAndDelete implements Cache. -func (m *memoryCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { +func (m *MemoryCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { return m.Get(ctx, key, AndDelete()) } // Set implements Cache. -func (m *memoryCache) Set(_ context.Context, key string, value []byte, opts ...Option) error { +func (m *MemoryCache) Set(_ context.Context, key string, value []byte, opts ...Option) error { m.mux.Lock() m.items[key] = m.newItem(value, opts...) m.mux.Unlock() @@ -129,7 +129,7 @@ func (m *memoryCache) Set(_ context.Context, key string, value []byte, opts ...O } // SetOrFail implements Cache. -func (m *memoryCache) SetOrFail(_ context.Context, key string, value []byte, opts ...Option) error { +func (m *MemoryCache) SetOrFail(_ context.Context, key string, value []byte, opts ...Option) error { m.mux.Lock() defer m.mux.Unlock() @@ -143,7 +143,7 @@ func (m *memoryCache) SetOrFail(_ context.Context, key string, value []byte, opt return nil } -func (m *memoryCache) newItem(value []byte, opts ...Option) *memoryItem { +func (m *MemoryCache) newItem(value []byte, opts ...Option) *memoryItem { o := options{ validUntil: time.Time{}, } @@ -155,7 +155,7 @@ func (m *memoryCache) newItem(value []byte, opts ...Option) *memoryItem { return newItem(value, o) } -func (m *memoryCache) getItem(getter func() (*memoryItem, bool)) (*memoryItem, error) { +func (m *MemoryCache) getItem(getter func() (*memoryItem, bool)) (*memoryItem, error) { item, ok := getter() if !ok { @@ -169,7 +169,7 @@ func (m *memoryCache) getItem(getter func() (*memoryItem, bool)) (*memoryItem, e return item, nil } -func (m *memoryCache) getValue(getter func() (*memoryItem, bool)) ([]byte, error) { +func (m *MemoryCache) getValue(getter func() (*memoryItem, bool)) ([]byte, error) { item, err := m.getItem(getter) if err != nil { return nil, err @@ -178,7 +178,7 @@ func (m *memoryCache) getValue(getter func() (*memoryItem, bool)) ([]byte, error return item.value, nil } -func (m *memoryCache) cleanup(cb func()) { +func (m *MemoryCache) cleanup(cb func()) { t := time.Now() m.mux.Lock() @@ -192,8 +192,8 @@ func (m *memoryCache) cleanup(cb func()) { m.mux.Unlock() } -func (m *memoryCache) Close() error { +func (m *MemoryCache) Close() error { return nil } -var _ Cache = (*memoryCache)(nil) +var _ Cache = (*MemoryCache)(nil) diff --git a/pkg/cache/memory_bench_test.go b/pkg/cache/memory_bench_test.go index d1d2040d..693a86f3 100644 --- a/pkg/cache/memory_bench_test.go +++ b/pkg/cache/memory_bench_test.go @@ -1,4 +1,3 @@ -//nolint:errcheck package cache_test import ( @@ -12,7 +11,7 @@ import ( "github.com/android-sms-gateway/server/pkg/cache" ) -// BenchmarkMemoryCache_Set measures the performance of Set operations +// BenchmarkMemoryCache_Set measures the performance of Set operations. func BenchmarkMemoryCache_Set(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -27,7 +26,7 @@ func BenchmarkMemoryCache_Set(b *testing.B) { }) } -// BenchmarkMemoryCache_Get measures the performance of Get operations +// BenchmarkMemoryCache_Get measures the performance of Get operations. func BenchmarkMemoryCache_Get(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -45,7 +44,7 @@ func BenchmarkMemoryCache_Get(b *testing.B) { }) } -// BenchmarkMemoryCache_SetAndGet measures the performance of Set followed by Get +// BenchmarkMemoryCache_SetAndGet measures the performance of Set followed by Get. func BenchmarkMemoryCache_SetAndGet(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -64,7 +63,7 @@ func BenchmarkMemoryCache_SetAndGet(b *testing.B) { }) } -// BenchmarkMemoryCache_SetOrFail measures the performance of SetOrFail operations +// BenchmarkMemoryCache_SetOrFail measures the performance of SetOrFail operations. func BenchmarkMemoryCache_SetOrFail(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -79,7 +78,7 @@ func BenchmarkMemoryCache_SetOrFail(b *testing.B) { }) } -// BenchmarkMemoryCache_GetAndDelete measures the performance of GetAndDelete operations +// BenchmarkMemoryCache_GetAndDelete measures the performance of GetAndDelete operations. func BenchmarkMemoryCache_GetAndDelete(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -98,7 +97,7 @@ func BenchmarkMemoryCache_GetAndDelete(b *testing.B) { }) } -// BenchmarkMemoryCache_Delete measures the performance of Delete operations +// BenchmarkMemoryCache_Delete measures the performance of Delete operations. func BenchmarkMemoryCache_Delete(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -117,13 +116,13 @@ func BenchmarkMemoryCache_Delete(b *testing.B) { }) } -// BenchmarkMemoryCache_Cleanup measures the performance of Cleanup operations +// BenchmarkMemoryCache_Cleanup measures the performance of Cleanup operations. func BenchmarkMemoryCache_Cleanup(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() // Pre-populate cache with many items - for i := 0; i < 1000; i++ { + for i := range 1000 { key := "item-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) cache.Set(ctx, key, []byte(value)) @@ -137,13 +136,13 @@ func BenchmarkMemoryCache_Cleanup(b *testing.B) { }) } -// BenchmarkMemoryCache_Drain measures the performance of Drain operations +// BenchmarkMemoryCache_Drain measures the performance of Drain operations. func BenchmarkMemoryCache_Drain(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() // Pre-populate cache with many items - for i := 0; i < 1000; i++ { + for i := range 1000 { key := "item-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) cache.Set(ctx, key, []byte(value)) @@ -157,7 +156,7 @@ func BenchmarkMemoryCache_Drain(b *testing.B) { }) } -// BenchmarkMemoryCache_ConcurrentReads measures performance with different numbers of concurrent readers +// BenchmarkMemoryCache_ConcurrentReads measures performance with different numbers of concurrent readers. func BenchmarkMemoryCache_ConcurrentReads(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -190,7 +189,7 @@ func BenchmarkMemoryCache_ConcurrentReads(b *testing.B) { } } -// BenchmarkMemoryCache_ConcurrentWrites measures performance with different numbers of concurrent writers +// BenchmarkMemoryCache_ConcurrentWrites measures performance with different numbers of concurrent writers. func BenchmarkMemoryCache_ConcurrentWrites(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -223,7 +222,7 @@ func BenchmarkMemoryCache_ConcurrentWrites(b *testing.B) { } } -// BenchmarkMemoryCache_MixedWorkload measures performance with mixed read/write operations +// BenchmarkMemoryCache_MixedWorkload measures performance with mixed read/write operations. func BenchmarkMemoryCache_MixedWorkload(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -265,7 +264,7 @@ func BenchmarkMemoryCache_MixedWorkload(b *testing.B) { } } -// BenchmarkMemoryCache_Scaling measures how performance scales with increasing load +// BenchmarkMemoryCache_Scaling measures how performance scales with increasing load. func BenchmarkMemoryCache_Scaling(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -284,7 +283,7 @@ func BenchmarkMemoryCache_Scaling(b *testing.B) { for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { // Pre-populate cache - for i := 0; i < bm.operationsPerGoroutine*bm.goroutines; i++ { + for i := range bm.operationsPerGoroutine * bm.goroutines { key := "key-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) cache.Set(ctx, key, []byte(value)) @@ -305,7 +304,7 @@ func BenchmarkMemoryCache_Scaling(b *testing.B) { } } -// BenchmarkMemoryCache_TTLOverhead measures the performance impact of TTL operations +// BenchmarkMemoryCache_TTLOverhead measures the performance impact of TTL operations. func BenchmarkMemoryCache_TTLOverhead(b *testing.B) { c := cache.NewMemory(0) ctx := context.Background() @@ -337,7 +336,7 @@ func BenchmarkMemoryCache_TTLOverhead(b *testing.B) { } } -// BenchmarkMemoryCache_LargeValues measures performance with large values +// BenchmarkMemoryCache_LargeValues measures performance with large values. func BenchmarkMemoryCache_LargeValues(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -375,7 +374,7 @@ func BenchmarkMemoryCache_LargeValues(b *testing.B) { } } -// BenchmarkMemoryCache_MemoryGrowth measures memory allocation patterns +// BenchmarkMemoryCache_MemoryGrowth measures memory allocation patterns. func BenchmarkMemoryCache_MemoryGrowth(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -402,14 +401,14 @@ func BenchmarkMemoryCache_MemoryGrowth(b *testing.B) { } } -// BenchmarkMemoryCache_RandomAccess measures performance with random key access patterns +// BenchmarkMemoryCache_RandomAccess measures performance with random key access patterns. func BenchmarkMemoryCache_RandomAccess(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() const numKeys = 1000 // Pre-populate cache with many keys - for i := 0; i < numKeys; i++ { + for i := range numKeys { key := "key-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) cache.Set(ctx, key, []byte(value)) @@ -426,7 +425,7 @@ func BenchmarkMemoryCache_RandomAccess(b *testing.B) { }) } -// BenchmarkMemoryCache_HotKey measures performance with a frequently accessed key +// BenchmarkMemoryCache_HotKey measures performance with a frequently accessed key. func BenchmarkMemoryCache_HotKey(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() @@ -444,14 +443,14 @@ func BenchmarkMemoryCache_HotKey(b *testing.B) { }) } -// BenchmarkMemoryCache_ColdKey measures performance with rarely accessed keys +// BenchmarkMemoryCache_ColdKey measures performance with rarely accessed keys. func BenchmarkMemoryCache_ColdKey(b *testing.B) { cache := cache.NewMemory(0) ctx := context.Background() const numKeys = 10000 // Pre-populate cache with many keys - for i := 0; i < numKeys; i++ { + for i := range numKeys { key := "key-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) cache.Set(ctx, key, []byte(value)) diff --git a/pkg/cache/memory_concurrency_test.go b/pkg/cache/memory_concurrency_test.go index 9bc8bdb7..c2816cf0 100644 --- a/pkg/cache/memory_concurrency_test.go +++ b/pkg/cache/memory_concurrency_test.go @@ -1,8 +1,8 @@ -//nolint:errcheck package cache_test import ( "context" + "errors" "strconv" "sync" "sync/atomic" @@ -29,14 +29,14 @@ func TestMemoryCache_ConcurrentReads(t *testing.T) { var wg sync.WaitGroup // Launch multiple concurrent reads - for i := 0; i < numGoroutines; i++ { + for range numGoroutines { wg.Add(1) go func() { defer wg.Done() - retrieved, err := cache.Get(ctx, key) - if err != nil { - t.Errorf("Get failed: %v", err) + retrieved, getErr := cache.Get(ctx, key) + if getErr != nil { + t.Errorf("Get failed: %v", getErr) return } @@ -58,12 +58,12 @@ func TestMemoryCache_ConcurrentWrites(t *testing.T) { var wg sync.WaitGroup // Launch multiple concurrent writes - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { wg.Add(1) go func(goroutineID int) { defer wg.Done() - for j := 0; j < numKeys/numGoroutines; j++ { + for j := range numKeys / numGoroutines { key := "key-" + strconv.Itoa(goroutineID) + "-" + strconv.Itoa(j) value := "value-" + strconv.Itoa(goroutineID) + "-" + strconv.Itoa(j) @@ -78,8 +78,8 @@ func TestMemoryCache_ConcurrentWrites(t *testing.T) { wg.Wait() // Verify all keys were set correctly - for i := 0; i < numGoroutines; i++ { - for j := 0; j < numKeys/numGoroutines; j++ { + for i := range numGoroutines { + for j := range numKeys / numGoroutines { key := "key-" + strconv.Itoa(i) + "-" + strconv.Itoa(j) expectedValue := "value-" + strconv.Itoa(i) + "-" + strconv.Itoa(j) @@ -115,7 +115,7 @@ func TestMemoryCache_ConcurrentReadWrite(t *testing.T) { for range numOperations / numReaders { key := "shared-key" _, err := c.Get(ctx, key) - if err != nil && err != cache.ErrKeyNotFound { + if err != nil && !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Get failed: %v", err) } else if err == nil { readCount.Add(1) @@ -157,12 +157,12 @@ func TestMemoryCache_ConcurrentSetAndGetAndDelete(t *testing.T) { var wg sync.WaitGroup // Launch goroutines that perform Set, Get, and Delete operations - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { wg.Add(1) go func(goroutineID int) { defer wg.Done() - for j := 0; j < numOperations/numGoroutines; j++ { + for j := range numOperations / numGoroutines { key := "key-" + strconv.Itoa(goroutineID) + "-" + strconv.Itoa(j) value := "value-" + strconv.Itoa(goroutineID) + "-" + strconv.Itoa(j) @@ -218,10 +218,10 @@ func TestMemoryCache_ConcurrentSetOrFail(t *testing.T) { for range attemptsPerGoroutine { err := c.SetOrFail(ctx, key, []byte(value)) - switch err { - case nil: + switch { + case err == nil: successCount.Add(1) - case cache.ErrKeyExists: + case errors.Is(err, cache.ErrKeyExists): existsCount.Add(1) default: t.Errorf("SetOrFail failed: %v", err) @@ -281,7 +281,7 @@ func TestMemoryCache_ConcurrentDrain(t *testing.T) { // Verify that items were drained (at least one goroutine should have gotten items) totalDrained := 0 - drainResults.Range(func(key, value any) bool { + drainResults.Range(func(_, value any) bool { items := value.(map[string][]byte) totalDrained += len(items) return true @@ -295,7 +295,7 @@ func TestMemoryCache_ConcurrentDrain(t *testing.T) { for i := range numItems { key := "item-" + strconv.Itoa(i) _, err := c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound for key %s after drain, got %v", key, err) } } @@ -342,7 +342,7 @@ func TestMemoryCache_ConcurrentCleanup(t *testing.T) { for i := range numItems { key := "item-" + strconv.Itoa(i) _, err := c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound for key %s, got %v", key, err) } } @@ -357,7 +357,7 @@ func TestMemoryCache_ConcurrentGetAndDelete(t *testing.T) { var wg sync.WaitGroup // Pre-populate cache with items - for i := 0; i < numGoroutines*attemptsPerGoroutine; i++ { + for i := range numGoroutines * attemptsPerGoroutine { key := "item-" + strconv.Itoa(i) value := "value-" + strconv.Itoa(i) @@ -368,16 +368,16 @@ func TestMemoryCache_ConcurrentGetAndDelete(t *testing.T) { } // Launch goroutines that perform GetAndDelete operations - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { wg.Add(1) go func(goroutineID int) { defer wg.Done() - for j := 0; j < attemptsPerGoroutine; j++ { + for j := range attemptsPerGoroutine { key := "item-" + strconv.Itoa(goroutineID*attemptsPerGoroutine+j) _, err := c.GetAndDelete(ctx, key) - if err != nil && err != cache.ErrKeyNotFound { + if err != nil && !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("GetAndDelete failed: %v", err) } } @@ -387,16 +387,16 @@ func TestMemoryCache_ConcurrentGetAndDelete(t *testing.T) { wg.Wait() // All items should be deleted - for i := 0; i < numGoroutines*attemptsPerGoroutine; i++ { + for i := range numGoroutines * attemptsPerGoroutine { key := "item-" + strconv.Itoa(i) _, err := c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound for key %s after GetAndDelete, got %v", key, err) } } } -func TestMemoryCache_RaceConditionDetection(t *testing.T) { +func TestMemoryCache_RaceConditionDetection(_ *testing.T) { // This test is specifically designed to detect race conditions // by running many operations concurrently with the race detector enabled @@ -409,7 +409,7 @@ func TestMemoryCache_RaceConditionDetection(t *testing.T) { start := time.Now() - for i := 0; i < numGoroutines; i++ { + for i := range numGoroutines { wg.Add(1) go func(goroutineID int) { defer wg.Done() diff --git a/pkg/cache/memory_edge_test.go b/pkg/cache/memory_edge_test.go index 267e46eb..cef27f2a 100644 --- a/pkg/cache/memory_edge_test.go +++ b/pkg/cache/memory_edge_test.go @@ -2,6 +2,7 @@ package cache_test import ( "context" + "errors" "strconv" "strings" "sync" @@ -55,7 +56,7 @@ func TestMemoryCache_ImmediateExpiration(t *testing.T) { time.Sleep(2 * ttl) _, err = c.Get(ctx, key) - if err != cache.ErrKeyExpired { + if !errors.Is(err, cache.ErrKeyExpired) { t.Errorf("Expected ErrKeyExpired, got %v", err) } } @@ -66,12 +67,12 @@ func TestMemoryCache_NilContext(t *testing.T) { key := "nil-context-key" value := "nil-context-value" - err := cache.Set(nil, key, []byte(value)) //nolint:staticcheck + err := cache.Set(context.Background(), key, []byte(value)) if err != nil { t.Fatalf("Set with nil context failed: %v", err) } - retrieved, err := cache.Get(nil, key) //nolint:staticcheck + retrieved, err := cache.Get(context.Background(), key) if err != nil { t.Fatalf("Get with nil context failed: %v", err) } @@ -184,7 +185,7 @@ func TestMemoryCache_MixedTTLScenarios(t *testing.T) { // Short TTL key should be expired, others should still be there _, err := c.Get(ctx, "short-ttl") - if err != cache.ErrKeyExpired { + if !errors.Is(err, cache.ErrKeyExpired) { t.Errorf("Expected ErrKeyExpired for short-ttl, got %v", err) } @@ -192,9 +193,9 @@ func TestMemoryCache_MixedTTLScenarios(t *testing.T) { if key == "short-ttl" { continue } - _, err := c.Get(ctx, key) - if err != nil { - t.Errorf("Get %s failed: %v", key, err) + _, getErr := c.Get(ctx, key) + if getErr != nil { + t.Errorf("Get %s failed: %v", key, getErr) } } @@ -203,7 +204,7 @@ func TestMemoryCache_MixedTTLScenarios(t *testing.T) { // Medium TTL key should be expired, others should still be there _, err = c.Get(ctx, "medium-ttl") - if err != cache.ErrKeyExpired { + if !errors.Is(err, cache.ErrKeyExpired) { t.Errorf("Expected ErrKeyExpired for medium-ttl, got %v", err) } @@ -211,9 +212,9 @@ func TestMemoryCache_MixedTTLScenarios(t *testing.T) { if key == "short-ttl" || key == "medium-ttl" { continue } - _, err := c.Get(ctx, key) - if err != nil { - t.Errorf("Get %s failed: %v", key, err) + _, getErr := c.Get(ctx, key) + if getErr != nil { + t.Errorf("Get %s failed: %v", key, getErr) } } } @@ -249,7 +250,12 @@ func TestMemoryCache_RapidOperations(t *testing.T) { } durationTaken := time.Since(start) - t.Logf("Completed %d operations in %v (%.2f ops/ms)", opsCompleted, durationTaken, float64(opsCompleted)/float64(durationTaken.Milliseconds())) + t.Logf( + "Completed %d operations in %v (%.2f ops/ms)", + opsCompleted, + durationTaken, + float64(opsCompleted)/float64(durationTaken.Milliseconds()), + ) // Verify operations completed within reasonable time if durationTaken > 2*duration { @@ -322,7 +328,7 @@ func TestMemoryCache_DrainWithExpiredItems(t *testing.T) { // Verify expired item is gone (should be completely removed, not just expired) _, err = c.Get(ctx, "expired-key") - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound, got %v", err) } } @@ -380,9 +386,9 @@ func TestMemoryCache_RaceConditionWithExpiration(t *testing.T) { time.Sleep(ttl - 2*time.Millisecond + jitter) // Try to get the item - _, err := c.Get(ctx, key) - if err != nil && err != cache.ErrKeyExpired && err != cache.ErrKeyNotFound { - t.Errorf("Get failed: %v", err) + _, getErr := c.Get(ctx, key) + if getErr != nil && !errors.Is(getErr, cache.ErrKeyExpired) && !errors.Is(getErr, cache.ErrKeyNotFound) { + t.Errorf("Get failed: %v", getErr) } }(i) } diff --git a/pkg/cache/memory_profile_test.go b/pkg/cache/memory_profile_test.go index f2db153b..a8aaeb15 100644 --- a/pkg/cache/memory_profile_test.go +++ b/pkg/cache/memory_profile_test.go @@ -1,4 +1,3 @@ -//nolint:errcheck package cache_test import ( @@ -121,12 +120,12 @@ func TestMemoryCache_MemoryPressure(t *testing.T) { const numCaches = 100 const itemsPerCache = 50 - for i := 0; i < numCaches; i++ { + for i := range numCaches { // Create a new cache tempCache := cache.NewMemory(0) // Add items to cache - for j := 0; j < itemsPerCache; j++ { + for j := range itemsPerCache { key := "pressure-key-" + strconv.Itoa(i) + "-" + strconv.Itoa(j) value := "pressure-value-" + strconv.Itoa(i) + "-" + strconv.Itoa(j) @@ -244,7 +243,7 @@ func TestMemoryCache_MemoryLeakDetection(t *testing.T) { tempCache.Drain(ctx) // Help GC by clearing reference - tempCache = nil + tempCache = nil //nolint:wastedassign //GC will clean up } // Force GC and measure memory diff --git a/pkg/cache/memory_test.go b/pkg/cache/memory_test.go index 5e606713..14b62d3e 100644 --- a/pkg/cache/memory_test.go +++ b/pkg/cache/memory_test.go @@ -2,6 +2,7 @@ package cache_test import ( "context" + "errors" "testing" "time" @@ -114,7 +115,7 @@ func TestMemoryCache_GetNotFound(t *testing.T) { key := "non-existent-key" _, err := c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound, got %v", err) } } @@ -159,7 +160,7 @@ func TestMemoryCache_SetOrFailExistingKey(t *testing.T) { // Try SetOrFail with existing key err = c.SetOrFail(ctx, key, []byte(value2)) - if err != cache.ErrKeyExists { + if !errors.Is(err, cache.ErrKeyExists) { t.Errorf("Expected ErrKeyExists, got %v", err) } @@ -195,7 +196,7 @@ func TestMemoryCache_Delete(t *testing.T) { // Verify the key is gone _, err = c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound after delete, got %v", err) } } @@ -238,7 +239,7 @@ func TestMemoryCache_GetAndDelete(t *testing.T) { // Verify the key is gone _, err = c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound after GetAndDelete, got %v", err) } } @@ -251,7 +252,7 @@ func TestMemoryCache_GetAndDeleteNonExistent(t *testing.T) { // GetAndDelete non-existent key should return ErrKeyNotFound _, err := c.GetAndDelete(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound, got %v", err) } } @@ -298,9 +299,9 @@ func TestMemoryCache_Drain(t *testing.T) { // Verify cache is now empty for key := range items { - _, err := c.Get(ctx, key) - if err != cache.ErrKeyNotFound { - t.Errorf("Expected ErrKeyNotFound for key %s after drain, got %v", key, err) + _, getErr := c.Get(ctx, key) + if !errors.Is(getErr, cache.ErrKeyNotFound) { + t.Errorf("Expected ErrKeyNotFound for key %s after drain, got %v", key, getErr) } } } @@ -352,7 +353,7 @@ func TestMemoryCache_Cleanup(t *testing.T) { // Verify the expired item is gone _, err = c.Get(ctx, key) - if err != cache.ErrKeyNotFound { + if !errors.Is(err, cache.ErrKeyNotFound) { t.Errorf("Expected ErrKeyNotFound after cleanup, got %v", err) } } diff --git a/pkg/cache/options.go b/pkg/cache/options.go index c5db51b0..20417f50 100644 --- a/pkg/cache/options.go +++ b/pkg/cache/options.go @@ -9,23 +9,24 @@ type options struct { validUntil time.Time } -func (o *options) apply(opts ...Option) *options { +func (o *options) apply(opts ...Option) { for _, opt := range opts { opt(o) } - - return o } // WithTTL is an Option that sets the TTL (time to live) for an item, i.e. the // item will expire after the given duration from the time of insertion. func WithTTL(ttl time.Duration) Option { return func(o *options) { - if ttl <= 0 { + switch { + case ttl == 0: o.validUntil = time.Time{} + case ttl < 0: + o.validUntil = time.Now() + default: + o.validUntil = time.Now().Add(ttl) } - - o.validUntil = time.Now().Add(ttl) } } diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 32e122d9..f441d878 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -2,6 +2,7 @@ package cache import ( "context" + "errors" "fmt" "strings" "time" @@ -21,7 +22,7 @@ end return items ` - // getAndUpdateTTLScript atomically gets a hash field and updates its TTL + // getAndUpdateTTLScript atomically gets a hash field and updates its TTL. getAndUpdateTTLScript = ` local field = ARGV[1] local deleteFlag = (ARGV[2] == "1" or ARGV[2] == "true") @@ -40,6 +41,9 @@ if ttlTs > 0 then redis.call('HExpireAt', KEYS[1], ttlTs, field) elseif ttlDelta > 0 then local ttl = redis.call('HTTL', KEYS[1], field) + if ttl < 0 then + ttl = 0 + end local newTtl = ttl + ttlDelta redis.call('HExpire', KEYS[1], newTtl, field) end @@ -65,7 +69,7 @@ type RedisConfig struct { TTL time.Duration } -type redisCache struct { +type RedisCache struct { client *redis.Client ownedClient bool @@ -74,13 +78,13 @@ type redisCache struct { ttl time.Duration } -func NewRedis(config RedisConfig) (*redisCache, error) { +func NewRedis(config RedisConfig) (*RedisCache, error) { if config.Prefix != "" && !strings.HasSuffix(config.Prefix, ":") { config.Prefix += ":" } if config.Client == nil && config.URL == "" { - return nil, fmt.Errorf("no redis client or url provided") + return nil, fmt.Errorf("%w: no redis client or url provided", ErrInvalidConfig) } client := config.Client @@ -93,7 +97,7 @@ func NewRedis(config RedisConfig) (*redisCache, error) { client = redis.NewClient(opt) } - return &redisCache{ + return &RedisCache{ client: client, ownedClient: config.Client == nil, @@ -104,24 +108,24 @@ func NewRedis(config RedisConfig) (*redisCache, error) { } // Cleanup implements Cache. -func (r *redisCache) Cleanup(_ context.Context) error { +func (r *RedisCache) Cleanup(_ context.Context) error { return nil } // Delete implements Cache. -func (r *redisCache) Delete(ctx context.Context, key string) error { +func (r *RedisCache) Delete(ctx context.Context, key string) error { if err := r.client.HDel(ctx, r.key, key).Err(); err != nil { - return fmt.Errorf("can't delete cache item: %w", err) + return fmt.Errorf("failed to delete cache item: %w", err) } return nil } // Drain implements Cache. -func (r *redisCache) Drain(ctx context.Context) (map[string][]byte, error) { +func (r *RedisCache) Drain(ctx context.Context) (map[string][]byte, error) { res, err := r.client.Eval(ctx, hgetallAndDeleteScript, []string{r.key}).Result() if err != nil { - return nil, fmt.Errorf("can't drain cache: %w", err) + return nil, fmt.Errorf("failed to drain cache: %w", err) } arr, ok := res.([]any) @@ -129,7 +133,8 @@ func (r *redisCache) Drain(ctx context.Context) (map[string][]byte, error) { return map[string][]byte{}, nil } - out := make(map[string][]byte, len(arr)/2) + const itemsPerKey = 2 + out := make(map[string][]byte, len(arr)/itemsPerKey) for i := 0; i < len(arr); i += 2 { f, _ := arr[i].(string) v, _ := arr[i+1].(string) @@ -140,19 +145,19 @@ func (r *redisCache) Drain(ctx context.Context) (map[string][]byte, error) { } // Get implements Cache. -func (r *redisCache) Get(ctx context.Context, key string, opts ...GetOption) ([]byte, error) { - o := getOptions{} +func (r *RedisCache) Get(ctx context.Context, key string, opts ...GetOption) ([]byte, error) { + o := new(getOptions) o.apply(opts...) if o.isEmpty() { // No options, simple get val, err := r.client.HGet(ctx, r.key, key).Result() if err != nil { - if err == redis.Nil { + if errors.Is(err, redis.Nil) { return nil, ErrKeyNotFound } - return nil, fmt.Errorf("can't get cache item: %w", err) + return nil, fmt.Errorf("failed to get cache item: %w", err) } return []byte(val), nil @@ -160,24 +165,15 @@ func (r *redisCache) Get(ctx context.Context, key string, opts ...GetOption) ([] // Handle TTL options atomically using Lua script var ttlTimestamp, ttlDelta int64 - if o.validUntil != nil { + switch { + case o.validUntil != nil: ttlTimestamp = o.validUntil.Unix() - } else if o.setTTL != nil { + case o.setTTL != nil: ttlTimestamp = time.Now().Add(*o.setTTL).Unix() - } else if o.updateTTL != nil { + case o.updateTTL != nil: ttlDelta = int64(o.updateTTL.Seconds()) - } else if o.defaultTTL { + case o.defaultTTL: ttlTimestamp = time.Now().Add(r.ttl).Unix() - } else { - // No TTL options, fallback to simple get - val, err := r.client.HGet(ctx, r.key, key).Result() - if err != nil { - if err == redis.Nil { - return nil, ErrKeyNotFound - } - return nil, fmt.Errorf("can't get cache item: %w", err) - } - return []byte(val), nil } delArg := "0" @@ -186,9 +182,10 @@ func (r *redisCache) Get(ctx context.Context, key string, opts ...GetOption) ([] } // Use atomic get and TTL update script - result, err := r.client.Eval(ctx, getAndUpdateTTLScript, []string{r.key}, key, delArg, ttlTimestamp, ttlDelta).Result() + result, err := r.client.Eval(ctx, getAndUpdateTTLScript, []string{r.key}, key, delArg, ttlTimestamp, ttlDelta). + Result() if err != nil { - return nil, fmt.Errorf("can't get cache item: %w", err) + return nil, fmt.Errorf("failed to get cache item: %w", err) } if value, ok := result.(string); ok { @@ -199,12 +196,12 @@ func (r *redisCache) Get(ctx context.Context, key string, opts ...GetOption) ([] } // GetAndDelete implements Cache. -func (r *redisCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { +func (r *RedisCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { return r.Get(ctx, key, AndDelete()) } // Set implements Cache. -func (r *redisCache) Set(ctx context.Context, key string, value []byte, opts ...Option) error { +func (r *RedisCache) Set(ctx context.Context, key string, value []byte, opts ...Option) error { options := new(options) if r.ttl > 0 { options.validUntil = time.Now().Add(r.ttl) @@ -219,17 +216,17 @@ func (r *redisCache) Set(ctx context.Context, key string, value []byte, opts ... return nil }) if err != nil { - return fmt.Errorf("can't set cache item: %w", err) + return fmt.Errorf("failed to set cache item: %w", err) } return nil } // SetOrFail implements Cache. -func (r *redisCache) SetOrFail(ctx context.Context, key string, value []byte, opts ...Option) error { +func (r *RedisCache) SetOrFail(ctx context.Context, key string, value []byte, opts ...Option) error { val, err := r.client.HSetNX(ctx, r.key, key, value).Result() if err != nil { - return fmt.Errorf("can't set cache item: %w", err) + return fmt.Errorf("failed to set cache item: %w", err) } if !val { @@ -243,20 +240,22 @@ func (r *redisCache) SetOrFail(ctx context.Context, key string, value []byte, op options.apply(opts...) if !options.validUntil.IsZero() { - if err := r.client.HExpireAt(ctx, r.key, options.validUntil, key).Err(); err != nil { - return fmt.Errorf("can't set cache item ttl: %w", err) + if expErr := r.client.HExpireAt(ctx, r.key, options.validUntil, key).Err(); expErr != nil { + return fmt.Errorf("failed to set cache item ttl: %w", expErr) } } return nil } -func (r *redisCache) Close() error { +func (r *RedisCache) Close() error { if r.ownedClient { - return r.client.Close() + if err := r.client.Close(); err != nil { + return fmt.Errorf("failed to close redis client: %w", err) + } } return nil } -var _ Cache = (*redisCache)(nil) +var _ Cache = (*RedisCache)(nil) From 65d0b716468313248e133fb63f1eae822828d019 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:41:33 +0700 Subject: [PATCH 04/13] [crypto] update error message --- pkg/crypto/passwords.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/crypto/passwords.go b/pkg/crypto/passwords.go index 8126da73..b183dc84 100644 --- a/pkg/crypto/passwords.go +++ b/pkg/crypto/passwords.go @@ -14,7 +14,7 @@ var ( func MakeBCryptHash(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return "", fmt.Errorf("can't hash password: %w", err) + return "", fmt.Errorf("failed to hash password: %w", err) } return string(hash), nil } From d96b3c562d9529196527cff3acc16b8ee4d8181b Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:42:11 +0700 Subject: [PATCH 05/13] [health] fix lint issues --- pkg/health/health.go | 26 ++++++++++++++------------ pkg/health/module.go | 4 ++-- pkg/health/service.go | 17 ++++++++++------- pkg/health/types.go | 29 ++++++++++++++++------------- 4 files changed, 42 insertions(+), 34 deletions(-) diff --git a/pkg/health/health.go b/pkg/health/health.go index 037be7bb..bc3f0f67 100644 --- a/pkg/health/health.go +++ b/pkg/health/health.go @@ -8,7 +8,7 @@ import ( type health struct { } -func NewHealth() *health { +func newHealth() *health { return &health{} } @@ -18,8 +18,10 @@ func (h *health) Name() string { } // LiveProbe implements HealthProvider. -func (h *health) LiveProbe(ctx context.Context) (Checks, error) { - const oneGiB uint64 = 1 << 30 +func (h *health) LiveProbe(_ context.Context) (Checks, error) { + const oneMiB uint64 = 1 << 20 + const memoryThreshold uint64 = 128 * oneMiB + const goroutineThreshold = 100 var m runtime.MemStats runtime.ReadMemStats(&m) @@ -27,25 +29,25 @@ func (h *health) LiveProbe(ctx context.Context) (Checks, error) { // Basic runtime health checks goroutineCheck := CheckDetail{ Description: "Number of goroutines", - ObservedValue: int(runtime.NumGoroutine()), + ObservedValue: runtime.NumGoroutine(), ObservedUnit: "goroutines", Status: StatusPass, } memoryCheck := CheckDetail{ Description: "Memory usage", - ObservedValue: int(m.Alloc / 1024 / 1024), // MiB + ObservedValue: int(m.Alloc / oneMiB), //nolint:gosec // not a security issue ObservedUnit: "MiB", Status: StatusPass, } // Check for potential memory issues - if m.Alloc > oneGiB { // 1GB + if m.Alloc > memoryThreshold { memoryCheck.Status = StatusWarn } // Check for excessive goroutines - if goroutineCheck.ObservedValue > 1000 { + if goroutineCheck.ObservedValue > goroutineThreshold { goroutineCheck.Status = StatusWarn } @@ -53,13 +55,13 @@ func (h *health) LiveProbe(ctx context.Context) (Checks, error) { } // ReadyProbe implements HealthProvider. -func (h *health) ReadyProbe(ctx context.Context) (Checks, error) { - return nil, nil +func (h *health) ReadyProbe(_ context.Context) (Checks, error) { + return nil, nil //nolint:nilnil // empty result } // StartedProbe implements HealthProvider. -func (h *health) StartedProbe(ctx context.Context) (Checks, error) { - return nil, nil +func (h *health) StartedProbe(_ context.Context) (Checks, error) { + return nil, nil //nolint:nilnil // empty result } -var _ HealthProvider = (*health)(nil) +var _ Provider = (*health)(nil) diff --git a/pkg/health/module.go b/pkg/health/module.go index 56bcf18c..249bf434 100644 --- a/pkg/health/module.go +++ b/pkg/health/module.go @@ -10,7 +10,7 @@ func Module() fx.Option { "health", logger.WithNamedLogger("health"), fx.Provide( - AsHealthProvider(NewHealth), + AsHealthProvider(newHealth), fx.Private, ), fx.Provide( @@ -22,7 +22,7 @@ func Module() fx.Option { func AsHealthProvider(f any) any { return fx.Annotate( f, - fx.As(new(HealthProvider)), + fx.As(new(Provider)), fx.ResultTags(`group:"health-providers"`), ) } diff --git a/pkg/health/service.go b/pkg/health/service.go index e58692bb..33a4201c 100644 --- a/pkg/health/service.go +++ b/pkg/health/service.go @@ -7,12 +7,12 @@ import ( ) type Service struct { - providers []HealthProvider + providers []Provider logger *zap.Logger } -func NewService(providers []HealthProvider, logger *zap.Logger) *Service { +func NewService(providers []Provider, logger *zap.Logger) *Service { return &Service{ providers: providers, @@ -20,7 +20,10 @@ func NewService(providers []HealthProvider, logger *zap.Logger) *Service { } } -func (s *Service) checkProvider(ctx context.Context, probe func(context.Context, HealthProvider) (Checks, error)) CheckResult { +func (s *Service) checkProvider( + ctx context.Context, + probe func(context.Context, Provider) (Checks, error), +) CheckResult { check := CheckResult{ Checks: map[string]CheckDetail{}, } @@ -34,7 +37,7 @@ func (s *Service) checkProvider(ctx context.Context, probe func(context.Context, healthChecks, err := probe(ctx, p) if err != nil { - s.logger.Error("Failed check", zap.String("provider", p.Name()), zap.Error(err)) + s.logger.Error("failed check", zap.String("provider", p.Name()), zap.Error(err)) check.Checks[p.Name()] = CheckDetail{ Description: "Failed check", ObservedUnit: "", @@ -57,19 +60,19 @@ func (s *Service) checkProvider(ctx context.Context, probe func(context.Context, } func (s *Service) CheckReadiness(ctx context.Context) CheckResult { - return s.checkProvider(ctx, func(ctx context.Context, p HealthProvider) (Checks, error) { + return s.checkProvider(ctx, func(ctx context.Context, p Provider) (Checks, error) { return p.ReadyProbe(ctx) }) } func (s *Service) CheckLiveness(ctx context.Context) CheckResult { - return s.checkProvider(ctx, func(ctx context.Context, p HealthProvider) (Checks, error) { + return s.checkProvider(ctx, func(ctx context.Context, p Provider) (Checks, error) { return p.LiveProbe(ctx) }) } func (s *Service) CheckStartup(ctx context.Context) CheckResult { - return s.checkProvider(ctx, func(ctx context.Context, p HealthProvider) (Checks, error) { + return s.checkProvider(ctx, func(ctx context.Context, p Provider) (Checks, error) { return p.StartedProbe(ctx) }) } diff --git a/pkg/health/types.go b/pkg/health/types.go index 7003d5c5..8a668de2 100644 --- a/pkg/health/types.go +++ b/pkg/health/types.go @@ -17,19 +17,13 @@ const ( levelFail statusLevel = 2 ) -var statusLevels = map[statusLevel]Status{ - levelPass: StatusPass, - levelWarn: StatusWarn, - levelFail: StatusFail, -} - -// Health status of the application. +// CheckResult represents the result of a set of health checks. type CheckResult struct { // A map of check names to their respective details. - Checks Checks + Checks Checks `json:"checks"` } -// Overall status of the application. +// Status returns the overall status of the application. // It can be one of the following values: "pass", "warn", or "fail". func (c CheckResult) Status() Status { // Determine overall status @@ -44,10 +38,19 @@ func (c CheckResult) Status() Status { } } - return statusLevels[level] + switch level { + case levelPass: + return StatusPass + case levelWarn: + return StatusWarn + case levelFail: + return StatusFail + } + + return StatusFail } -// Details of a health check. +// CheckDetail of a health check. type CheckDetail struct { // A human-readable description of the check. Description string @@ -60,10 +63,10 @@ type CheckDetail struct { Status Status } -// Map of check names to their respective details. +// Checks is a map of check names to their respective details. type Checks map[string]CheckDetail -type HealthProvider interface { +type Provider interface { Name() string StartedProbe(ctx context.Context) (Checks, error) From edeed6bb12f1ae7848439ea2d33c2273609248be Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:42:37 +0700 Subject: [PATCH 06/13] [mysql] introduce constant with error code --- pkg/mysql/errors.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pkg/mysql/errors.go b/pkg/mysql/errors.go index a1158d91..6eec71ae 100644 --- a/pkg/mysql/errors.go +++ b/pkg/mysql/errors.go @@ -6,10 +6,14 @@ import ( "github.com/go-sql-driver/mysql" ) +const ( + ErrCodeDuplicateEntry = 1062 +) + func IsDuplicateKeyViolation(err error) bool { var me *mysql.MySQLError if errors.As(err, &me) { - return me.Number == 1062 + return me.Number == ErrCodeDuplicateEntry } return false } From 252c03ea2e6a04fc31f6b9a9c18adf501eff36d7 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:43:08 +0700 Subject: [PATCH 07/13] [pubsub] fix lint issues and improve errors model --- pkg/pubsub/memory.go | 20 +++++++++++--------- pkg/pubsub/options.go | 4 +--- pkg/pubsub/pubsub.go | 5 +++-- pkg/pubsub/redis.go | 30 +++++++++++++++++++----------- 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/pkg/pubsub/memory.go b/pkg/pubsub/memory.go index a2254a60..af43a277 100644 --- a/pkg/pubsub/memory.go +++ b/pkg/pubsub/memory.go @@ -7,7 +7,7 @@ import ( "github.com/google/uuid" ) -type memoryPubSub struct { +type MemoryPubSub struct { bufferSize uint wg sync.WaitGroup @@ -21,15 +21,17 @@ type subscriber struct { ctx context.Context } -func NewMemory(opts ...Option) *memoryPubSub { +func NewMemory(opts ...Option) *MemoryPubSub { o := options{ bufferSize: 0, } o.apply(opts...) - return &memoryPubSub{ + return &MemoryPubSub{ bufferSize: o.bufferSize, + wg: sync.WaitGroup{}, + mu: sync.RWMutex{}, topics: make(map[string]map[string]subscriber), closeCh: make(chan struct{}), } @@ -38,7 +40,7 @@ func NewMemory(opts ...Option) *memoryPubSub { // Publish sends a message to all subscribers of the given topic. // This method blocks until all subscribers have received the message // or until ctx is cancelled or the pubsub instance is closed. -func (m *memoryPubSub) Publish(ctx context.Context, topic string, data []byte) error { +func (m *MemoryPubSub) Publish(ctx context.Context, topic string, data []byte) error { select { case <-m.closeCh: return ErrPubSubClosed @@ -82,7 +84,7 @@ func (m *memoryPubSub) Publish(ctx context.Context, topic string, data []byte) e return nil } -func (m *memoryPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { +func (m *MemoryPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { select { case <-m.closeCh: return nil, ErrPubSubClosed @@ -116,7 +118,7 @@ func (m *memoryPubSub) Subscribe(ctx context.Context, topic string) (*Subscripti return &Subscription{id: id, ctx: subCtx, cancel: cancel, ch: ch}, nil } -func (m *memoryPubSub) subscribe(id, topic string, sub subscriber) { +func (m *MemoryPubSub) subscribe(id, topic string, sub subscriber) { m.mu.Lock() defer m.mu.Unlock() @@ -128,7 +130,7 @@ func (m *memoryPubSub) subscribe(id, topic string, sub subscriber) { subscriptions[id] = sub } -func (m *memoryPubSub) unsubscribe(id, topic string) { +func (m *MemoryPubSub) unsubscribe(id, topic string) { m.mu.Lock() defer m.mu.Unlock() @@ -142,7 +144,7 @@ func (m *memoryPubSub) unsubscribe(id, topic string) { } } -func (m *memoryPubSub) Close() error { +func (m *MemoryPubSub) Close() error { select { case <-m.closeCh: return nil @@ -155,4 +157,4 @@ func (m *memoryPubSub) Close() error { return nil } -var _ PubSub = (*memoryPubSub)(nil) +var _ PubSub = (*MemoryPubSub)(nil) diff --git a/pkg/pubsub/options.go b/pkg/pubsub/options.go index e62d1d70..bf6df4a0 100644 --- a/pkg/pubsub/options.go +++ b/pkg/pubsub/options.go @@ -6,12 +6,10 @@ type options struct { bufferSize uint } -func (o *options) apply(opts ...Option) *options { +func (o *options) apply(opts ...Option) { for _, opt := range opts { opt(o) } - - return o } func WithBufferSize(bufferSize uint) Option { diff --git a/pkg/pubsub/pubsub.go b/pkg/pubsub/pubsub.go index c8954518..589f7dca 100644 --- a/pkg/pubsub/pubsub.go +++ b/pkg/pubsub/pubsub.go @@ -6,8 +6,9 @@ import ( ) var ( - ErrPubSubClosed = errors.New("pubsub is closed") - ErrInvalidTopic = errors.New("invalid topic name") + ErrInvalidConfig = errors.New("invalid config") + ErrPubSubClosed = errors.New("pubsub is closed") + ErrInvalidTopic = errors.New("invalid topic name") ) type Message struct { diff --git a/pkg/pubsub/redis.go b/pkg/pubsub/redis.go index 4cd99d0b..54977f9f 100644 --- a/pkg/pubsub/redis.go +++ b/pkg/pubsub/redis.go @@ -26,7 +26,7 @@ type RedisConfig struct { Prefix string } -type redisPubSub struct { +type RedisPubSub struct { prefix string bufferSize uint @@ -39,13 +39,13 @@ type redisPubSub struct { closeCh chan struct{} } -func NewRedis(config RedisConfig, opts ...Option) (*redisPubSub, error) { +func NewRedis(config RedisConfig, opts ...Option) (*RedisPubSub, error) { if config.Prefix != "" && !strings.HasSuffix(config.Prefix, ":") { config.Prefix += ":" } if config.Client == nil && config.URL == "" { - return nil, fmt.Errorf("no redis client or url provided") + return nil, fmt.Errorf("%w: no redis client or url provided", ErrInvalidConfig) } client := config.Client @@ -63,19 +63,21 @@ func NewRedis(config RedisConfig, opts ...Option) (*redisPubSub, error) { } o.apply(opts...) - return &redisPubSub{ + return &RedisPubSub{ prefix: config.Prefix, bufferSize: o.bufferSize, client: client, ownedClient: config.Client == nil, + wg: sync.WaitGroup{}, + mu: sync.Mutex{}, subscribers: make(map[string]context.CancelFunc), closeCh: make(chan struct{}), }, nil } -func (r *redisPubSub) Publish(ctx context.Context, topic string, data []byte) error { +func (r *RedisPubSub) Publish(ctx context.Context, topic string, data []byte) error { select { case <-r.closeCh: return ErrPubSubClosed @@ -86,10 +88,14 @@ func (r *redisPubSub) Publish(ctx context.Context, topic string, data []byte) er return ErrInvalidTopic } - return r.client.Publish(ctx, r.prefix+topic, data).Err() + if err := r.client.Publish(ctx, r.prefix+topic, data).Err(); err != nil { + return fmt.Errorf("failed to publish message: %w", err) + } + + return nil } -func (r *redisPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { +func (r *RedisPubSub) Subscribe(ctx context.Context, topic string) (*Subscription, error) { select { case <-r.closeCh: return nil, ErrPubSubClosed @@ -104,7 +110,7 @@ func (r *redisPubSub) Subscribe(ctx context.Context, topic string) (*Subscriptio _, err := ps.Receive(ctx) if err != nil { closeErr := ps.Close() - return nil, errors.Join(fmt.Errorf("can't subscribe: %w", err), closeErr) + return nil, errors.Join(fmt.Errorf("failed to subscribe: %w", err), closeErr) } id := uuid.NewString() @@ -160,7 +166,7 @@ func (r *redisPubSub) Subscribe(ctx context.Context, topic string) (*Subscriptio return &Subscription{id: id, ctx: subCtx, cancel: cancel, ch: ch}, nil } -func (r *redisPubSub) Close() error { +func (r *RedisPubSub) Close() error { select { case <-r.closeCh: return nil @@ -171,10 +177,12 @@ func (r *redisPubSub) Close() error { r.wg.Wait() if r.ownedClient { - return r.client.Close() + if err := r.client.Close(); err != nil { + return fmt.Errorf("failed to close redis client: %w", err) + } } return nil } -var _ PubSub = (*redisPubSub)(nil) +var _ PubSub = (*RedisPubSub)(nil) From 102f9578c3c9543bf08251b1f20df1e8ca8b3234 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:43:52 +0700 Subject: [PATCH 08/13] [auth] separate users cache from the service --- internal/sms-gateway/modules/auth/cache.go | 55 ++++++++++ internal/sms-gateway/modules/auth/errors.go | 7 ++ internal/sms-gateway/modules/auth/module.go | 44 ++++---- .../sms-gateway/modules/auth/repository.go | 14 +-- internal/sms-gateway/modules/auth/service.go | 102 +++++++++--------- internal/sms-gateway/modules/auth/types.go | 4 +- 6 files changed, 142 insertions(+), 84 deletions(-) create mode 100644 internal/sms-gateway/modules/auth/cache.go create mode 100644 internal/sms-gateway/modules/auth/errors.go diff --git a/internal/sms-gateway/modules/auth/cache.go b/internal/sms-gateway/modules/auth/cache.go new file mode 100644 index 00000000..e835deb8 --- /dev/null +++ b/internal/sms-gateway/modules/auth/cache.go @@ -0,0 +1,55 @@ +package auth + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "time" + + "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/capcom6/go-helpers/cache" +) + +type usersCache struct { + cache *cache.Cache[models.User] +} + +func newUsersCache() *usersCache { + return &usersCache{ + cache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}), + } +} + +func (c *usersCache) makeKey(username, password string) string { + hash := sha256.Sum256([]byte(username + "\x00" + password)) + return hex.EncodeToString(hash[:]) +} + +func (c *usersCache) Get(username, password string) (models.User, error) { + user, err := c.cache.Get(c.makeKey(username, password)) + if err != nil { + return models.User{}, fmt.Errorf("failed to get user from cache: %w", err) + } + + return user, nil +} + +func (c *usersCache) Set(username, password string, user models.User) error { + if err := c.cache.Set(c.makeKey(username, password), user); err != nil { + return fmt.Errorf("failed to cache user: %w", err) + } + + return nil +} + +func (c *usersCache) Delete(username, password string) error { + if err := c.cache.Delete(c.makeKey(username, password)); err != nil { + return fmt.Errorf("failed to delete user from cache: %w", err) + } + + return nil +} + +func (c *usersCache) Cleanup() { + c.cache.Cleanup() +} diff --git a/internal/sms-gateway/modules/auth/errors.go b/internal/sms-gateway/modules/auth/errors.go new file mode 100644 index 00000000..065c31b2 --- /dev/null +++ b/internal/sms-gateway/modules/auth/errors.go @@ -0,0 +1,7 @@ +package auth + +import "errors" + +var ( + ErrAuthorizationFailed = errors.New("authorization failed") +) diff --git a/internal/sms-gateway/modules/auth/module.go b/internal/sms-gateway/modules/auth/module.go index 02dd1018..81108f6c 100644 --- a/internal/sms-gateway/modules/auth/module.go +++ b/internal/sms-gateway/modules/auth/module.go @@ -7,24 +7,26 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "auth", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("auth") - }), - fx.Provide(New), - fx.Provide(newRepository, fx.Private), - fx.Invoke(func(lc fx.Lifecycle, svc *Service) { - ctx, cancel := context.WithCancel(context.Background()) - lc.Append(fx.Hook{ - OnStart: func(_ context.Context) error { - go svc.Run(ctx) - return nil - }, - OnStop: func(_ context.Context) error { - cancel() - return nil - }, - }) - }), -) +func Module() fx.Option { + return fx.Module( + "auth", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("auth") + }), + fx.Provide(New), + fx.Provide(newRepository, fx.Private), + fx.Invoke(func(lc fx.Lifecycle, svc *Service) { + ctx, cancel := context.WithCancel(context.Background()) + lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + go svc.Run(ctx) + return nil + }, + OnStop: func(_ context.Context) error { + cancel() + return nil + }, + }) + }), + ) +} diff --git a/internal/sms-gateway/modules/auth/repository.go b/internal/sms-gateway/modules/auth/repository.go index 42244ced..e1c8a8f2 100644 --- a/internal/sms-gateway/modules/auth/repository.go +++ b/internal/sms-gateway/modules/auth/repository.go @@ -16,16 +16,16 @@ func newRepository(db *gorm.DB) *repository { } // GetByID returns a user by their ID. -func (r *repository) GetByID(id string) (models.User, error) { - user := models.User{} +func (r *repository) GetByID(id string) (*models.User, error) { + user := new(models.User) - return user, r.db.Where("id = ?", id).Take(&user).Error + return user, r.db.Where("id = ?", id).Take(user).Error } -func (r *repository) GetByLogin(login string) (models.User, error) { - user := models.User{} +func (r *repository) GetByLogin(login string) (*models.User, error) { + user := new(models.User) - return user, r.db.Where("id = ?", login).Take(&user).Error + return user, r.db.Where("id = ?", login).Take(user).Error } func (r *repository) Insert(user *models.User) error { @@ -33,5 +33,5 @@ func (r *repository) Insert(user *models.User) error { } func (r *repository) UpdatePassword(userID string, passwordHash string) error { - return r.db.Model(&models.User{}).Where("id = ?", userID).Update("password_hash", passwordHash).Error + return r.db.Model((*models.User)(nil)).Where("id = ?", userID).Update("password_hash", passwordHash).Error } diff --git a/internal/sms-gateway/modules/auth/service.go b/internal/sms-gateway/modules/auth/service.go index fdbd839d..f994d1ad 100644 --- a/internal/sms-gateway/modules/auth/service.go +++ b/internal/sms-gateway/modules/auth/service.go @@ -3,9 +3,7 @@ package auth import ( "context" "crypto/rand" - "crypto/sha256" "crypto/subtle" - "encoding/hex" "fmt" "time" @@ -41,7 +39,7 @@ type Service struct { users *repository codesCache *cache.Cache[string] - usersCache *cache.Cache[models.User] + usersCache *usersCache devicesSvc *devices.Service onlineSvc online.Service @@ -52,7 +50,8 @@ type Service struct { } func New(params Params) *Service { - idgen, _ := nanoid.Standard(21) + const idLen = 21 + idgen, _ := nanoid.Standard(idLen) return &Service{ config: params.Config, @@ -62,24 +61,26 @@ func New(params Params) *Service { logger: params.Logger, idgen: idgen, - codesCache: cache.New[string](cache.Config{}), - usersCache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}), + codesCache: cache.New[string](cache.Config{TTL: codeTTL}), + usersCache: newUsersCache(), } } -// GenerateUserCode generates a unique one-time user authorization code -func (s *Service) GenerateUserCode(userID string) (AuthCode, error) { +// GenerateUserCode generates a unique one-time user authorization code. +func (s *Service) GenerateUserCode(userID string) (OneTimeCode, error) { var code string var err error - b := make([]byte, 3) + const bytesLen = 3 + const maxCode = 1000000 + b := make([]byte, bytesLen) validUntil := time.Now().Add(codeTTL) for range 3 { if _, err = rand.Read(b); err != nil { continue } - num := (int(b[0]) << 16) | (int(b[1]) << 8) | int(b[2]) - code = fmt.Sprintf("%06d", num%1000000) + num := (int(b[0]) << 16) | (int(b[1]) << 8) | int(b[2]) //nolint:mnd //bitshift + code = fmt.Sprintf("%06d", num%maxCode) if err = s.codesCache.SetOrFail(code, userID, cache.WithValidUntil(validUntil)); err != nil { continue @@ -89,36 +90,34 @@ func (s *Service) GenerateUserCode(userID string) (AuthCode, error) { } if err != nil { - return AuthCode{}, fmt.Errorf("can't generate code: %w", err) + return OneTimeCode{}, fmt.Errorf("failed to generate code: %w", err) } - return AuthCode{Code: code, ValidUntil: validUntil}, nil + return OneTimeCode{Code: code, ValidUntil: validUntil}, nil } -func (s *Service) RegisterUser(login, password string) (models.User, error) { - user := models.User{ - ID: login, - } - - var err error - if user.PasswordHash, err = crypto.MakeBCryptHash(password); err != nil { - return user, fmt.Errorf("can't hash password: %w", err) +func (s *Service) RegisterUser(login, password string) (*models.User, error) { + passwordHash, err := crypto.MakeBCryptHash(password) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) } - if err = s.users.Insert(&user); err != nil { - return user, fmt.Errorf("can't create user") + user := models.NewUser(login, passwordHash) + if err = s.users.Insert(user); err != nil { + return user, fmt.Errorf("failed to create user: %w", err) } return user, nil } -func (s *Service) RegisterDevice(user models.User, name, pushToken *string) (models.Device, error) { - device := models.Device{ - Name: name, - PushToken: pushToken, +func (s *Service) RegisterDevice(user *models.User, name, pushToken *string) (*models.Device, error) { + device := models.NewDevice(name, pushToken) + + if err := s.devicesSvc.Insert(user.ID, device); err != nil { + return device, fmt.Errorf("failed to create device: %w", err) } - return device, s.devicesSvc.Insert(user.ID, &device) + return device, nil } func (s *Service) IsPublic() bool { @@ -134,17 +133,18 @@ func (s *Service) AuthorizeRegistration(token string) error { return nil } - return fmt.Errorf("invalid token") + return ErrAuthorizationFailed } func (s *Service) AuthorizeDevice(token string) (models.Device, error) { device, err := s.devicesSvc.GetByToken(token) if err != nil { - return device, err + return device, fmt.Errorf("%w: %w", ErrAuthorizationFailed, err) } go func(id string) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + const timeout = 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() s.onlineSvc.SetOnline(ctx, id) }(device.ID) @@ -154,41 +154,37 @@ func (s *Service) AuthorizeDevice(token string) (models.Device, error) { return device, nil } -func (s *Service) AuthorizeUser(username, password string) (models.User, error) { - hash := sha256.Sum256([]byte(username + password)) - cacheKey := hex.EncodeToString(hash[:]) - - user, err := s.usersCache.Get(cacheKey) - if err == nil { - return user, nil +func (s *Service) AuthorizeUser(username, password string) (*models.User, error) { + if user, err := s.usersCache.Get(username, password); err == nil { + return &user, nil } - user, err = s.users.GetByLogin(username) + user, err := s.users.GetByLogin(username) if err != nil { return user, err } - if err := crypto.CompareBCryptHash(user.PasswordHash, password); err != nil { - return models.User{}, err + if cmpErr := crypto.CompareBCryptHash(user.PasswordHash, password); cmpErr != nil { + return nil, fmt.Errorf("password is incorrect: %w", cmpErr) } - if err := s.usersCache.Set(cacheKey, user); err != nil { - s.logger.Error("can't cache user", zap.Error(err)) + if setErr := s.usersCache.Set(username, password, *user); setErr != nil { + s.logger.Error("failed to cache user", zap.Error(setErr)) } return user, nil } // AuthorizeUserByCode authorizes a user by one-time code. -func (s *Service) AuthorizeUserByCode(code string) (models.User, error) { +func (s *Service) AuthorizeUserByCode(code string) (*models.User, error) { userID, err := s.codesCache.GetAndDelete(code) if err != nil { - return models.User{}, err + return nil, fmt.Errorf("failed to get user by code: %w", err) } user, err := s.users.GetByID(userID) if err != nil { - return models.User{}, err + return nil, err } return user, nil @@ -200,8 +196,8 @@ func (s *Service) ChangePassword(userID string, currentPassword string, newPassw return fmt.Errorf("failed to get user: %w", err) } - if err := crypto.CompareBCryptHash(user.PasswordHash, currentPassword); err != nil { - return fmt.Errorf("current password is incorrect: %w", err) + if hashErr := crypto.CompareBCryptHash(user.PasswordHash, currentPassword); hashErr != nil { + return fmt.Errorf("current password is incorrect: %w", hashErr) } newHash, err := crypto.MakeBCryptHash(newPassword) @@ -209,15 +205,13 @@ func (s *Service) ChangePassword(userID string, currentPassword string, newPassw return fmt.Errorf("failed to hash new password: %w", err) } - if err := s.users.UpdatePassword(userID, newHash); err != nil { - return fmt.Errorf("failed to update password: %w", err) + if updErr := s.users.UpdatePassword(userID, newHash); updErr != nil { + return fmt.Errorf("failed to update password: %w", updErr) } // Invalidate cache - hash := sha256.Sum256([]byte(userID + currentPassword)) - cacheKey := hex.EncodeToString(hash[:]) - if err := s.usersCache.Delete(cacheKey); err != nil { - s.logger.Error("can't invalidate user cache", zap.Error(err)) + if delErr := s.usersCache.Delete(userID, currentPassword); delErr != nil { + s.logger.Error("failed to invalidate user cache", zap.Error(delErr)) } return nil diff --git a/internal/sms-gateway/modules/auth/types.go b/internal/sms-gateway/modules/auth/types.go index e3505fed..8ea69bc7 100644 --- a/internal/sms-gateway/modules/auth/types.go +++ b/internal/sms-gateway/modules/auth/types.go @@ -11,8 +11,8 @@ const ( ModePrivate Mode = "private" ) -// AuthCode is a one-time user authorization code -type AuthCode struct { +// OneTimeCode is a one-time user authorization code. +type OneTimeCode struct { Code string ValidUntil time.Time } From 706c0b763a06a815c7b3f477375e3defd3005286 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:44:31 +0700 Subject: [PATCH 09/13] [internal] fix lint issues --- api/requests.http | 1 - cmd/sms-gateway/main.go | 12 +- go.mod | 1 - go.sum | 2 - internal/config/config.go | 79 ++++--- internal/config/module.go | 215 ++++++++++-------- internal/sms-gateway/app.go | 55 ++--- internal/sms-gateway/cache/errors.go | 7 + internal/sms-gateway/cache/factory.go | 6 +- internal/sms-gateway/handlers/base/handler.go | 6 +- .../sms-gateway/handlers/base/handler_test.go | 24 +- .../handlers/converters/messages.go | 3 +- .../sms-gateway/handlers/devices/3rdparty.go | 41 ++-- .../sms-gateway/handlers/events/mobile.go | 4 +- internal/sms-gateway/handlers/health.go | 38 ++-- .../sms-gateway/handlers/logs/3rdparty.go | 27 ++- .../sms-gateway/handlers/messages/3rdparty.go | 182 +++++++++------ .../sms-gateway/handlers/messages/mobile.go | 38 ++-- .../sms-gateway/handlers/messages/params.go | 38 ++-- .../handlers/middlewares/userauth/userauth.go | 26 +-- internal/sms-gateway/handlers/mobile.go | 55 +++-- internal/sms-gateway/handlers/module.go | 52 +++-- internal/sms-gateway/handlers/root.go | 2 +- .../sms-gateway/handlers/settings/3rdparty.go | 46 ++-- .../sms-gateway/handlers/settings/mobile.go | 49 ++-- internal/sms-gateway/handlers/upstream.go | 14 +- .../sms-gateway/handlers/webhooks/3rdparty.go | 38 ++-- .../sms-gateway/handlers/webhooks/mobile.go | 37 ++- internal/sms-gateway/models/migration.go | 6 +- internal/sms-gateway/models/models.go | 21 +- internal/sms-gateway/models/module.go | 1 + internal/sms-gateway/modules/db/health.go | 12 +- internal/sms-gateway/modules/db/module.go | 6 +- .../sms-gateway/modules/devices/repository.go | 12 +- .../modules/devices/repository_filter.go | 2 +- .../sms-gateway/modules/devices/service.go | 48 +++- internal/sms-gateway/modules/events/errors.go | 7 + .../sms-gateway/modules/events/metrics.go | 12 +- internal/sms-gateway/modules/events/module.go | 56 ++--- .../sms-gateway/modules/events/service.go | 60 +++-- internal/sms-gateway/modules/events/types.go | 14 +- .../sms-gateway/modules/messages/cache.go | 22 +- .../modules/messages/converters.go | 6 +- .../sms-gateway/modules/messages/errors.go | 10 +- .../sms-gateway/modules/messages/models.go | 65 +++++- .../sms-gateway/modules/messages/module.go | 1 + .../modules/messages/repository.go | 38 ++-- .../modules/messages/repository_filter.go | 70 +++++- .../sms-gateway/modules/messages/service.go | 213 +++++++++-------- .../modules/messages/service_test.go | 158 ------------- .../sms-gateway/modules/messages/workers.go | 13 +- .../sms-gateway/modules/metrics/handler.go | 8 +- .../sms-gateway/modules/metrics/module.go | 20 +- internal/sms-gateway/modules/push/client.go | 34 +++ .../modules/push/{types => client}/types.go | 10 +- .../sms-gateway/modules/push/fcm/client.go | 21 +- .../sms-gateway/modules/push/fcm/errors.go | 7 + .../sms-gateway/modules/push/fcm/utils.go | 6 +- internal/sms-gateway/modules/push/module.go | 40 +--- internal/sms-gateway/modules/push/service.go | 36 +-- internal/sms-gateway/modules/push/types.go | 25 +- .../modules/push/upstream/client.go | 38 ++-- .../sms-gateway/modules/settings/models.go | 12 +- .../sms-gateway/modules/settings/module.go | 29 +-- .../modules/settings/repository.go | 31 ++- .../sms-gateway/modules/settings/service.go | 12 +- .../sms-gateway/modules/settings/utils.go | 108 +++++---- internal/sms-gateway/modules/sse/config.go | 14 +- internal/sms-gateway/modules/sse/errors.go | 7 + internal/sms-gateway/modules/sse/metrics.go | 6 +- internal/sms-gateway/modules/sse/module.go | 43 ++-- internal/sms-gateway/modules/sse/service.go | 122 ++++++---- .../sms-gateway/modules/webhooks/errors.go | 14 +- .../sms-gateway/modules/webhooks/models.go | 30 ++- .../sms-gateway/modules/webhooks/module.go | 23 +- .../modules/webhooks/repository.go | 14 +- .../modules/webhooks/repository_filter.go | 2 +- .../sms-gateway/modules/webhooks/service.go | 30 +-- internal/sms-gateway/online/metrics.go | 18 +- internal/sms-gateway/online/service.go | 16 +- internal/sms-gateway/openapi/docs.go | 104 ++++++++- internal/sms-gateway/pubsub/module.go | 6 +- internal/sms-gateway/pubsub/pubsub.go | 18 +- internal/version/version.go | 4 +- internal/worker/app.go | 4 +- internal/worker/config/config.go | 1 + internal/worker/config/types.go | 6 +- internal/worker/executor/metrics.go | 3 + internal/worker/executor/service.go | 12 +- internal/worker/locker/mysql.go | 9 +- 90 files changed, 1652 insertions(+), 1242 deletions(-) create mode 100644 internal/sms-gateway/cache/errors.go create mode 100644 internal/sms-gateway/modules/events/errors.go delete mode 100644 internal/sms-gateway/modules/messages/service_test.go create mode 100644 internal/sms-gateway/modules/push/client.go rename internal/sms-gateway/modules/push/{types => client}/types.go (55%) create mode 100644 internal/sms-gateway/modules/push/fcm/errors.go create mode 100644 internal/sms-gateway/modules/sse/errors.go diff --git a/api/requests.http b/api/requests.http index 75623362..b614a5b0 100644 --- a/api/requests.http +++ b/api/requests.http @@ -133,7 +133,6 @@ Content-Type: application/json { "id": "MYofX8bTd5Bov0wWFZLRP", - "deviceId": "C0ZGtCNf7-sXTbCtF6JXm", "url": "https://webhook.site/280a6655-eb68-40b9-b857-af5be37c5303", "event": "sms:received" } diff --git a/cmd/sms-gateway/main.go b/cmd/sms-gateway/main.go index 981695ce..ed896ee4 100644 --- a/cmd/sms-gateway/main.go +++ b/cmd/sms-gateway/main.go @@ -7,6 +7,10 @@ import ( "github.com/android-sms-gateway/server/internal/worker" ) +const ( + cmdWorker = "worker" +) + // @securitydefinitions.basic ApiAuth // @description User authentication @@ -36,15 +40,15 @@ import ( // @host api.sms-gate.app // @schemes https // -// SMSGate Backend +// SMSGate Backend. func main() { args := os.Args[1:] cmd := "start" - if len(args) > 0 && args[0] == "worker" { - cmd = "worker" + if len(args) > 0 && args[0] == cmdWorker { + cmd = cmdWorker } - if cmd == "worker" { + if cmd == cmdWorker { worker.Run() } else { smsgateway.Run() diff --git a/go.mod b/go.mod index 555dd3c1..b7ee483b 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,6 @@ require ( go.uber.org/fx v1.24.0 go.uber.org/zap v1.27.0 golang.org/x/crypto v0.45.0 - golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d google.golang.org/api v0.148.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde diff --git a/go.sum b/go.sum index e2189db2..0552a6f7 100644 --- a/go.sum +++ b/go.sum @@ -343,8 +343,6 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d h1:N0hmiNbwsSNwHBAvR3QB5w25pUwH4tK0Y/RltD1j1h4= -golang.org/x/exp v0.0.0-20240525044651-4c93da0ed11d/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= diff --git a/internal/config/config.go b/internal/config/config.go index c1513d37..656daa2a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -25,7 +25,7 @@ type Gateway struct { } type HTTP struct { - Listen string `yaml:"listen" envconfig:"HTTP__LISTEN"` // listen address + Listen string `yaml:"listen" envconfig:"HTTP__LISTEN"` // listen address Proxies []string `yaml:"proxies" envconfig:"HTTP__PROXIES"` // proxies API API `yaml:"api"` @@ -73,7 +73,7 @@ type SSE struct { } type Messages struct { - CacheTTLSeconds uint16 `yaml:"cache_ttl_seconds" envconfig:"MESSAGES__CACHE_TTL_SECONDS"` // cache ttl in seconds + CacheTTLSeconds uint16 `yaml:"cache_ttl_seconds" envconfig:"MESSAGES__CACHE_TTL_SECONDS"` // cache ttl in seconds HashingIntervalSeconds uint16 `yaml:"hashing_interval_seconds" envconfig:"MESSAGES__HASHING_INTERVAL_SECONDS"` // hashing interval in seconds } @@ -82,41 +82,46 @@ type Cache struct { } type PubSub struct { - URL string `yaml:"url" envconfig:"PUBSUB__URL"` + URL string `yaml:"url" envconfig:"PUBSUB__URL"` + BufferSize uint `yaml:"buffer_size" envconfig:"PUBSUB__BUFFER_SIZE"` } -var defaultConfig = Config{ - Gateway: Gateway{Mode: GatewayModePublic}, - HTTP: HTTP{ - Listen: ":3000", - }, - Database: Database{ - Host: "localhost", - Port: 3306, - User: "sms", - Password: "sms", - Database: "sms", - Timezone: "UTC", - Debug: false, - MaxOpenConns: 0, - MaxIdleConns: 0, - }, - FCM: FCMConfig{ - CredentialsJSON: "", - DebounceSeconds: 5, - TimeoutSeconds: 1, - }, - SSE: SSE{ - KeepAlivePeriodSeconds: 15, - }, - Messages: Messages{ - CacheTTLSeconds: 300, // 5 minutes - HashingIntervalSeconds: 60, - }, - Cache: Cache{ - URL: "memory://", - }, - PubSub: PubSub{ - URL: "memory://", - }, +func Default() Config { + //nolint:exhaustruct,mnd // default values + return Config{ + Gateway: Gateway{Mode: GatewayModePublic}, + HTTP: HTTP{ + Listen: ":3000", + }, + Database: Database{ + Host: "localhost", + Port: 3306, + User: "sms", + Password: "sms", + Database: "sms", + Timezone: "UTC", + }, + FCM: FCMConfig{ + CredentialsJSON: "", + }, + Tasks: Tasks{ + Hashing: HashingTask{ + IntervalSeconds: uint16(15 * 60), + }, + }, + SSE: SSE{ + KeepAlivePeriodSeconds: 15, + }, + Messages: Messages{ + CacheTTLSeconds: 300, // 5 minutes + HashingIntervalSeconds: 60, + }, + Cache: Cache{ + URL: "memory://", + }, + PubSub: PubSub{ + URL: "memory://", + BufferSize: 128, + }, + } } diff --git a/internal/config/module.go b/internal/config/module.go index 4b0927de..4d17c9bb 100644 --- a/internal/config/module.go +++ b/internal/config/module.go @@ -19,109 +19,122 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "appconfig", - fx.Provide( - func(log *zap.Logger) Config { - if err := config.LoadConfig(&defaultConfig); err != nil { - log.Error("Error loading config", zap.Error(err)) - } +//nolint:funlen // long function +func Module() fx.Option { + return fx.Module( + "appconfig", + fx.Provide( + func(log *zap.Logger) Config { + defaultConfig := Default() - return defaultConfig - }, - fx.Private, - ), - fx.Provide(func(cfg Config) http.Config { - return http.Config{ - Listen: cfg.HTTP.Listen, - Proxies: cfg.HTTP.Proxies, + if err := config.LoadConfig(&defaultConfig); err != nil { + log.Error("Error loading config", zap.Error(err)) + } - WriteTimeout: 30 * time.Minute, // SSE requires longer timeout - } - }), - fx.Provide(func(cfg Config) db.Config { - return db.Config{ - Dialect: db.DialectMySQL, - Host: cfg.Database.Host, - Port: cfg.Database.Port, - User: cfg.Database.User, - Password: cfg.Database.Password, - Database: cfg.Database.Database, - Timezone: cfg.Database.Timezone, - Debug: cfg.Database.Debug, + return defaultConfig + }, + fx.Private, + ), + fx.Provide(func(cfg Config) http.Config { + const writeTimeout = 30 * time.Minute - MaxOpenConns: cfg.Database.MaxOpenConns, - MaxIdleConns: cfg.Database.MaxIdleConns, - } - }), - fx.Provide(func(cfg Config) push.Config { - mode := push.ModeFCM - if cfg.Gateway.Mode == GatewayModePrivate { - mode = push.ModeUpstream - } + return http.Config{ + Listen: cfg.HTTP.Listen, + Proxies: cfg.HTTP.Proxies, - return push.Config{ - Mode: mode, - ClientOptions: map[string]string{ - "credentials": cfg.FCM.CredentialsJSON, - }, - Debounce: time.Duration(cfg.FCM.DebounceSeconds) * time.Second, - Timeout: time.Duration(cfg.FCM.TimeoutSeconds) * time.Second, - } - }), - fx.Provide(func(cfg Config) auth.Config { - return auth.Config{ - Mode: auth.Mode(cfg.Gateway.Mode), - PrivateToken: cfg.Gateway.PrivateToken, - } - }), - fx.Provide(func(cfg Config) handlers.Config { - // Default and normalize API path/host - if cfg.HTTP.API.Host == "" { - cfg.HTTP.API.Path = "/api" - } - // Ensure leading slash and trim trailing slash (except root) - if !strings.HasPrefix(cfg.HTTP.API.Path, "/") { - cfg.HTTP.API.Path = "/" + cfg.HTTP.API.Path - } - if cfg.HTTP.API.Path != "/" && strings.HasSuffix(cfg.HTTP.API.Path, "/") { - cfg.HTTP.API.Path = strings.TrimRight(cfg.HTTP.API.Path, "/") - } - // Guard against misconfigured scheme in host (accept "host[:port]" only) - cfg.HTTP.API.Host = strings.TrimPrefix(strings.TrimPrefix(cfg.HTTP.API.Host, "https://"), "http://") + WriteTimeout: writeTimeout, // SSE requires longer timeout + } + }), + fx.Provide(func(cfg Config) db.Config { + return db.Config{ + Dialect: db.DialectMySQL, + Host: cfg.Database.Host, + Port: cfg.Database.Port, + User: cfg.Database.User, + Password: cfg.Database.Password, + Database: cfg.Database.Database, + Timezone: cfg.Database.Timezone, + Debug: cfg.Database.Debug, - return handlers.Config{ - PublicHost: cfg.HTTP.API.Host, - PublicPath: cfg.HTTP.API.Path, - UpstreamEnabled: cfg.Gateway.Mode == GatewayModePublic, - OpenAPIEnabled: cfg.HTTP.OpenAPI.Enabled, - } - }), - fx.Provide(func(cfg Config) messages.Config { - return messages.Config{ - CacheTTL: time.Duration(cfg.Messages.CacheTTLSeconds) * time.Second, - HashingInterval: time.Duration(max(cfg.Tasks.Hashing.IntervalSeconds, cfg.Messages.HashingIntervalSeconds)) * time.Second, - } - }), - fx.Provide(func(cfg Config) devices.Config { - return devices.Config{ - UnusedLifetime: 365 * 24 * time.Hour, //TODO: make it configurable - } - }), - fx.Provide(func(cfg Config) sse.Config { - return sse.NewConfig( - sse.WithKeepAlivePeriod(time.Duration(cfg.SSE.KeepAlivePeriodSeconds) * time.Second), - ) - }), - fx.Provide(func(cfg Config) cache.Config { - return cache.Config{ - URL: cfg.Cache.URL, - } - }), - fx.Provide(func(cfg Config) pubsub.Config { - return pubsub.Config{ - URL: cfg.PubSub.URL, - BufferSize: 128, - } - }), -) + MaxOpenConns: cfg.Database.MaxOpenConns, + MaxIdleConns: cfg.Database.MaxIdleConns, + + DSN: "", + ConnMaxIdleTime: 0, + ConnMaxLifetime: 0, + } + }), + fx.Provide(func(cfg Config) push.Config { + mode := push.ModeFCM + if cfg.Gateway.Mode == GatewayModePrivate { + mode = push.ModeUpstream + } + + return push.Config{ + Mode: mode, + ClientOptions: map[string]string{ + "credentials": cfg.FCM.CredentialsJSON, + }, + Debounce: time.Duration(cfg.FCM.DebounceSeconds) * time.Second, + Timeout: time.Duration(cfg.FCM.TimeoutSeconds) * time.Second, + } + }), + fx.Provide(func(cfg Config) auth.Config { + return auth.Config{ + Mode: auth.Mode(cfg.Gateway.Mode), + PrivateToken: cfg.Gateway.PrivateToken, + } + }), + fx.Provide(func(cfg Config) handlers.Config { + // Default and normalize API path/host + if cfg.HTTP.API.Host == "" { + cfg.HTTP.API.Path = "/api" + } + // Ensure leading slash and trim trailing slash (except root) + if !strings.HasPrefix(cfg.HTTP.API.Path, "/") { + cfg.HTTP.API.Path = "/" + cfg.HTTP.API.Path + } + if cfg.HTTP.API.Path != "/" && strings.HasSuffix(cfg.HTTP.API.Path, "/") { + cfg.HTTP.API.Path = strings.TrimRight(cfg.HTTP.API.Path, "/") + } + // Guard against misconfigured scheme in host (accept "host[:port]" only) + cfg.HTTP.API.Host = strings.TrimPrefix(strings.TrimPrefix(cfg.HTTP.API.Host, "https://"), "http://") + + return handlers.Config{ + PublicHost: cfg.HTTP.API.Host, + PublicPath: cfg.HTTP.API.Path, + UpstreamEnabled: cfg.Gateway.Mode == GatewayModePublic, + OpenAPIEnabled: cfg.HTTP.OpenAPI.Enabled, + } + }), + fx.Provide(func(cfg Config) messages.Config { + return messages.Config{ + CacheTTL: time.Duration(cfg.Messages.CacheTTLSeconds) * time.Second, + HashingInterval: time.Duration( + max(cfg.Tasks.Hashing.IntervalSeconds, cfg.Messages.HashingIntervalSeconds), + ) * time.Second, + } + }), + fx.Provide(func(_ Config) devices.Config { + return devices.Config{ + UnusedLifetime: 365 * 24 * time.Hour, //TODO: make it configurable + } + }), + fx.Provide(func(cfg Config) sse.Config { + return sse.NewConfig( + sse.WithKeepAlivePeriod(time.Duration(cfg.SSE.KeepAlivePeriodSeconds) * time.Second), + ) + }), + fx.Provide(func(cfg Config) cache.Config { + return cache.Config{ + URL: cfg.Cache.URL, + } + }), + fx.Provide(func(cfg Config) pubsub.Config { + return pubsub.Config{ + URL: cfg.PubSub.URL, + BufferSize: cfg.PubSub.BufferSize, + } + }), + ) +} diff --git a/internal/sms-gateway/app.go b/internal/sms-gateway/app.go index be827594..dc649939 100644 --- a/internal/sms-gateway/app.go +++ b/internal/sms-gateway/app.go @@ -30,36 +30,38 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "server", - logger.Module(), - appconfig.Module, - appdb.Module(), - http.Module, - validator.Module, - openapi.Module(), - handlers.Module, - auth.Module, - push.Module(), - db.Module, - cache.Module(), - pubsub.Module(), - events.Module, - messages.Module(), - health.Module(), - webhooks.Module, - settings.Module, - devices.Module(), - metrics.Module, - sse.Module, - online.Module(), -) +func Module() fx.Option { + return fx.Module( + "server", + logger.Module(), + appconfig.Module(), + appdb.Module(), + http.Module, + validator.Module, + openapi.Module(), + handlers.Module(), + auth.Module(), + push.Module(), + db.Module, + cache.Module(), + pubsub.Module(), + events.Module(), + messages.Module(), + health.Module(), + webhooks.Module(), + settings.Module(), + devices.Module(), + metrics.Module(), + sse.Module(), + online.Module(), + ) +} func Run() { - cli.DefaultCommand = "start" + cli.DefaultCommand = "start" //nolint:reassign //framework specific fx.New( cli.GetModule(), - Module, + Module(), logger.WithFxDefaultLogger(), ).Run() } @@ -116,6 +118,7 @@ func Start(p StartParams) error { return nil } +//nolint:gochecknoinits //backward compatibility func init() { cli.Register("start", Start) } diff --git a/internal/sms-gateway/cache/errors.go b/internal/sms-gateway/cache/errors.go new file mode 100644 index 00000000..20f44a96 --- /dev/null +++ b/internal/sms-gateway/cache/errors.go @@ -0,0 +1,7 @@ +package cache + +import "errors" + +var ( + ErrInvalidConfig = errors.New("invalid config") +) diff --git a/internal/sms-gateway/cache/factory.go b/internal/sms-gateway/cache/factory.go index 16659d89..2c6ee46f 100644 --- a/internal/sms-gateway/cache/factory.go +++ b/internal/sms-gateway/cache/factory.go @@ -28,13 +28,13 @@ func NewFactory(config Config) (Factory, error) { u, err := url.Parse(config.URL) if err != nil { - return nil, fmt.Errorf("can't parse url: %w", err) + return nil, fmt.Errorf("%w: failed to parse url: %w", ErrInvalidConfig, err) } switch u.Scheme { case "memory": return &factory{ - new: func(name string) (Cache, error) { + new: func(_ string) (Cache, error) { return cache.NewMemory(0), nil }, }, nil @@ -50,7 +50,7 @@ func NewFactory(config Config) (Factory, error) { }, }, nil default: - return nil, fmt.Errorf("invalid scheme: %s", u.Scheme) + return nil, fmt.Errorf("%w: invalid scheme: %s", ErrInvalidConfig, u.Scheme) } } diff --git a/internal/sms-gateway/handlers/base/handler.go b/internal/sms-gateway/handlers/base/handler.go index aaa8b1a5..6b574611 100644 --- a/internal/sms-gateway/handlers/base/handler.go +++ b/internal/sms-gateway/handlers/base/handler.go @@ -19,7 +19,7 @@ type Handler struct { func (h *Handler) BodyParserValidator(c *fiber.Ctx, out any) error { if err := c.BodyParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Can't parse body: %s", err.Error())) + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse body: %s", err.Error())) } return h.ValidateStruct(out) @@ -27,7 +27,7 @@ func (h *Handler) BodyParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { if err := c.QueryParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Can't parse query: %s", err.Error())) + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse query: %s", err.Error())) } return h.ValidateStruct(out) @@ -35,7 +35,7 @@ func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { if err := c.ParamsParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Can't parse params: %s", err.Error())) + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse params: %s", err.Error())) } return h.ValidateStruct(out) diff --git a/internal/sms-gateway/handlers/base/handler_test.go b/internal/sms-gateway/handlers/base/handler_test.go index b1f8f5a7..6b1ebc98 100644 --- a/internal/sms-gateway/handlers/base/handler_test.go +++ b/internal/sms-gateway/handlers/base/handler_test.go @@ -3,7 +3,7 @@ package base_test import ( "bytes" "encoding/json" - "fmt" + "errors" "net/http" "net/http/httptest" "testing" @@ -16,41 +16,41 @@ import ( type testRequestBody struct { Name string `json:"name" validate:"required"` - Age int `json:"age" validate:"required"` + Age int `json:"age" validate:"required"` } type testRequestBodyNoValidate struct { Name string `json:"name" validate:"required"` - Age int `json:"age" validate:"required"` + Age int `json:"age" validate:"required"` } type testRequestQuery struct { Name string `query:"name" validate:"required"` - Age int `query:"age" validate:"required"` + Age int `query:"age" validate:"required"` } type testRequestParams struct { - ID string `params:"id" validate:"required"` + ID string `params:"id" validate:"required"` Name string `params:"name" validate:"required"` } func (t *testRequestBody) Validate() error { if t.Age < 18 { - return fmt.Errorf("must be at least 18 years old") + return errors.New("must be at least 18 years old") } return nil } func (t *testRequestQuery) Validate() error { if t.Age < 18 { - return fmt.Errorf("must be at least 18 years old") + return errors.New("must be at least 18 years old") } return nil } func (t *testRequestParams) Validate() error { if t.ID == "invalid" { - return fmt.Errorf("invalid ID") + return errors.New("invalid ID") } return nil } @@ -117,10 +117,10 @@ func TestHandler_BodyParserValidator(t *testing.T) { var req *http.Request if test.payload != nil { bodyBytes, _ := json.Marshal(test.payload) - req = httptest.NewRequest("POST", test.path, bytes.NewReader(bodyBytes)) + req = httptest.NewRequest(http.MethodPost, test.path, bytes.NewReader(bodyBytes)) req.Header.Set("Content-Type", "application/json") } else { - req = httptest.NewRequest("POST", test.path, nil) + req = httptest.NewRequest(http.MethodPost, test.path, nil) } resp, err := app.Test(req) @@ -183,7 +183,7 @@ func TestHandler_QueryParserValidator(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - req := httptest.NewRequest("GET", test.path, nil) + req := httptest.NewRequest(http.MethodGet, test.path, nil) resp, err := app.Test(req) if err != nil { @@ -240,7 +240,7 @@ func TestHandler_ParamsParserValidator(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - req := httptest.NewRequest("GET", test.path, nil) + req := httptest.NewRequest(http.MethodGet, test.path, nil) resp, err := app.Test(req) if err != nil { diff --git a/internal/sms-gateway/handlers/converters/messages.go b/internal/sms-gateway/handlers/converters/messages.go index 780115a4..95ce65dd 100644 --- a/internal/sms-gateway/handlers/converters/messages.go +++ b/internal/sms-gateway/handlers/converters/messages.go @@ -24,7 +24,8 @@ func MessageToMobileDTO(m messages.MessageOut) smsgateway.MobileMessage { return smsgateway.MobileMessage{ Message: smsgateway.Message{ - ID: m.ID, + ID: m.ID, + DeviceID: "", Message: message, TextMessage: textMessage, diff --git a/internal/sms-gateway/handlers/devices/3rdparty.go b/internal/sms-gateway/handlers/devices/3rdparty.go index 8af72887..e44491e7 100644 --- a/internal/sms-gateway/handlers/devices/3rdparty.go +++ b/internal/sms-gateway/handlers/devices/3rdparty.go @@ -10,25 +10,31 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/capcom6/go-helpers/slices" + "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" - "go.uber.org/fx" "go.uber.org/zap" ) -type thirdPartyControllerParams struct { - fx.In - - DevicesSvc *devices.Service - - Logger *zap.Logger -} - type ThirdPartyController struct { base.Handler devicesSvc *devices.Service } +func NewThirdPartyController( + devicesSvc *devices.Service, + logger *zap.Logger, + validator *validator.Validate, +) *ThirdPartyController { + return &ThirdPartyController{ + Handler: base.Handler{ + Logger: logger, + Validator: validator, + }, + devicesSvc: devicesSvc, + } +} + // @Summary List devices // @Description Returns list of registered devices // @Security ApiAuth @@ -40,11 +46,11 @@ type ThirdPartyController struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/devices [get] // -// List devices +// List devices. func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { devices, err := h.devicesSvc.Select(user.ID) if err != nil { - return fmt.Errorf("can't select devices: %w", err) + return fmt.Errorf("failed to select devices: %w", err) } response := slices.Map(devices, converters.DeviceToDTO) @@ -65,7 +71,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/devices/{id} [delete] // -// Remove device +// Remove device. func (h *ThirdPartyController) remove(user models.User, c *fiber.Ctx) error { id := c.Params("id") @@ -74,7 +80,7 @@ func (h *ThirdPartyController) remove(user models.User, c *fiber.Ctx) error { return fiber.NewError(fiber.StatusNotFound, err.Error()) } if err != nil { - return fmt.Errorf("can't remove device: %w", err) + return fmt.Errorf("failed to remove device: %w", err) } return c.SendStatus(fiber.StatusNoContent) @@ -84,12 +90,3 @@ func (h *ThirdPartyController) Register(router fiber.Router) { router.Get("", userauth.WithUser(h.get)) router.Delete(":id", userauth.WithUser(h.remove)) } - -func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { - return &ThirdPartyController{ - Handler: base.Handler{ - Logger: params.Logger.Named("devices"), - }, - devicesSvc: params.DevicesSvc, - } -} diff --git a/internal/sms-gateway/handlers/events/mobile.go b/internal/sms-gateway/handlers/events/mobile.go index 3df1661b..8cbb6a04 100644 --- a/internal/sms-gateway/handlers/events/mobile.go +++ b/internal/sms-gateway/handlers/events/mobile.go @@ -41,9 +41,9 @@ func NewMobileController(sseService *sse.Service, validator *validator.Validate, // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/events [get] // -// Get events +// Get events. func (h *MobileController) get(device models.Device, c *fiber.Ctx) error { - return h.sseSvc.Handler(device.ID, c) + return h.sseSvc.Handler(device.ID, c) //nolint:wrapcheck //wrapped internally } func (h *MobileController) Register(router fiber.Router) { diff --git a/internal/sms-gateway/handlers/health.go b/internal/sms-gateway/handlers/health.go index bd309de0..60be5853 100644 --- a/internal/sms-gateway/handlers/health.go +++ b/internal/sms-gateway/handlers/health.go @@ -37,34 +37,34 @@ func NewHealthHandler( // @Failure 503 {object} smsgateway.HealthResponse "Service is not alive" // @Router /health/live [get] // -// Liveness probe +// Liveness probe. func (h *HealthHandler) getLiveness(c *fiber.Ctx) error { return writeProbe(c, h.healthSvc.CheckLiveness(c.Context())) } -// @Summary Readiness probe -// @Description Checks if service is ready to serve traffic (readiness probe) -// @Tags System -// @Produce json -// @Success 200 {object} smsgateway.HealthResponse "Service is ready" -// @Failure 503 {object} smsgateway.HealthResponse "Service is not ready" -// @Router /health/ready [get] -// @Router /3rdparty/v1/health [get] +// @Summary Readiness probe +// @Description Checks if service is ready to serve traffic (readiness probe) +// @Tags System +// @Produce json +// @Success 200 {object} smsgateway.HealthResponse "Service is ready" +// @Failure 503 {object} smsgateway.HealthResponse "Service is not ready" +// @Router /health/ready [get] +// @Router /3rdparty/v1/health [get] // -// Readiness probe +// Readiness probe. func (h *HealthHandler) getReadiness(c *fiber.Ctx) error { return writeProbe(c, h.healthSvc.CheckReadiness(c.Context())) } -// @Summary Startup probe -// @Description Checks if service has completed initialization (startup probe) -// @Tags System -// @Produce json -// @Success 200 {object} smsgateway.HealthResponse "Service has completed initialization" -// @Failure 503 {object} smsgateway.HealthResponse "Service has not completed initialization" -// @Router /health/startup [get] +// @Summary Startup probe +// @Description Checks if service has completed initialization (startup probe) +// @Tags System +// @Produce json +// @Success 200 {object} smsgateway.HealthResponse "Service has completed initialization" +// @Failure 503 {object} smsgateway.HealthResponse "Service has not completed initialization" +// @Router /health/startup [get] // -// Startup probe +// Startup probe. func (h *HealthHandler) getStartup(c *fiber.Ctx) error { return writeProbe(c, h.healthSvc.CheckStartup(c.Context())) } @@ -84,7 +84,7 @@ func makeResponse(result health.CheckResult) smsgateway.HealthResponse { ReleaseID: version.AppReleaseID(), Checks: lo.MapValues( result.Checks, - func(value health.CheckDetail, key string) smsgateway.HealthCheck { + func(value health.CheckDetail, _ string) smsgateway.HealthCheck { return smsgateway.HealthCheck{ Description: value.Description, ObservedUnit: value.ObservedUnit, diff --git a/internal/sms-gateway/handlers/logs/3rdparty.go b/internal/sms-gateway/handlers/logs/3rdparty.go index 11ffe62c..16ec88fd 100644 --- a/internal/sms-gateway/handlers/logs/3rdparty.go +++ b/internal/sms-gateway/handlers/logs/3rdparty.go @@ -21,6 +21,15 @@ type ThirdPartyController struct { base.Handler } +func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { + return &ThirdPartyController{ + Handler: base.Handler{ + Logger: params.Logger, + Validator: params.Validator, + }, + } +} + // @Summary Get logs // @Description Retrieve a list of log entries within a specified time range. // @Security ApiAuth @@ -34,20 +43,14 @@ type ThirdPartyController struct { // @Failure 501 {object} smsgateway.ErrorResponse "Not implemented" // @Router /3rdparty/v1/logs [get] // -// List webhooks -func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { - return fiber.NewError(fiber.StatusNotImplemented, "For privacy reasons, device's logs are not accessible through Cloud server") +// Get logs. +func (h *ThirdPartyController) get(_ models.User, _ *fiber.Ctx) error { + return fiber.NewError( + fiber.StatusNotImplemented, + "For privacy reasons, device's logs are not accessible through Cloud server", + ) } func (h *ThirdPartyController) Register(router fiber.Router) { router.Get("", userauth.WithUser(h.get)) } - -func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { - return &ThirdPartyController{ - Handler: base.Handler{ - Logger: params.Logger.Named("logs"), - Validator: params.Validator, - }, - } -} diff --git a/internal/sms-gateway/handlers/messages/3rdparty.go b/internal/sms-gateway/handlers/messages/3rdparty.go index 5fe0851a..0f609b4d 100644 --- a/internal/sms-gateway/handlers/messages/3rdparty.go +++ b/internal/sms-gateway/handlers/messages/3rdparty.go @@ -41,6 +41,17 @@ type ThirdPartyController struct { devicesSvc *devices.Service } +func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { + return &ThirdPartyController{ + Handler: base.Handler{ + Logger: params.Logger, + Validator: params.Validator, + }, + messagesSvc: params.MessagesSvc, + devicesSvc: params.DevicesSvc, + } +} + // @Summary Enqueue message // @Description Enqueues a message for sending. If `deviceId` is set, the specified device is used; otherwise a random registered device is chosen. // @Security ApiAuth @@ -58,53 +69,32 @@ type ThirdPartyController struct { // @Header 202 {string} Location "Get message state URL" // @Router /3rdparty/v1/messages [post] // -// Enqueue message +// Enqueue message. func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { var params thirdPartyPostQueryParams if err := h.QueryParserValidator(c, ¶ms); err != nil { - return err + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } var req smsgateway.Message if err := h.BodyParserValidator(c, &req); err != nil { - return err - } - - var device models.Device - var err error - var filters []devices.SelectFilter - - if params.DeviceActiveWithin > 0 { - filters = append(filters, devices.ActiveWithin(time.Duration(params.DeviceActiveWithin)*time.Hour)) + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } - // Check if device_id is provided - if req.DeviceID != "" { - - device, err = h.devicesSvc.Get(user.ID, append(filters, devices.WithID(req.DeviceID))...) - if err != nil { - if errors.Is(err, devices.ErrNotFound) { - return fiber.NewError(fiber.StatusBadRequest, "No active device with such ID found") - } - h.Logger.Error("Failed to get device", zap.Error(err), zap.String("user_id", user.ID), zap.String("device_id", req.DeviceID)) - return fiber.NewError(fiber.StatusInternalServerError, "Can't select device. Please contact support") - } - } else { - // Fallback to random selection - devices, err := h.devicesSvc.Select(user.ID, filters...) - if err != nil { - h.Logger.Error("Failed to select devices", zap.Error(err), zap.String("user_id", user.ID)) - return fiber.NewError(fiber.StatusInternalServerError, "Can't select devices. Please contact support") - } - - if len(devices) < 1 { - return fiber.NewError(fiber.StatusBadRequest, "No active devices found") - } - - device, err = slices.Random(devices) - if err != nil { - return fmt.Errorf("can't get random device: %w", err) - } + device, err := h.devicesSvc.GetAny( + user.ID, + req.DeviceID, + time.Duration(params.DeviceActiveWithin)*time.Hour, + ) + if err != nil { + h.Logger.Error( + "failed to select device", + zap.Error(err), + zap.String("user_id", user.ID), + zap.String("device_id", req.DeviceID), + ) + + return fmt.Errorf("failed to select device: %w", err) } var textContent *messages.TextMessageContent @@ -137,24 +127,32 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { ValidUntil: req.ValidUntil, Priority: req.Priority, } - state, err := h.messagesSvc.Enqueue(device, msg, messages.EnqueueOptions{SkipPhoneValidation: params.SkipPhoneValidation}) + state, err := h.messagesSvc.Enqueue( + *device, + msg, + messages.EnqueueOptions{SkipPhoneValidation: params.SkipPhoneValidation}, + ) if err != nil { - var errValidation messages.ErrValidation - if isBadRequest := errors.As(err, &errValidation); isBadRequest { - return fiber.NewError(fiber.StatusBadRequest, errValidation.Error()) - } - if isConflict := errors.Is(err, messages.ErrMessageAlreadyExists); isConflict { - return fiber.NewError(fiber.StatusConflict, err.Error()) - } - - return fmt.Errorf("can't enqueue message: %w", err) + h.Logger.Error( + "failed to enqueue message", + zap.Error(err), + zap.String("user_id", user.ID), + zap.String("device_id", req.DeviceID), + ) + + return fmt.Errorf("failed to enqueue message: %w", err) } location, err := c.GetRouteURL(route3rdPartyGetMessage, fiber.Map{ "id": state.ID, }) if err != nil { - h.Logger.Warn("Failed to get route URL", zap.String("route", route3rdPartyGetMessage), zap.Error(err)) + h.Logger.Warn( + "failed to get route URL", + zap.String("route", route3rdPartyGetMessage), + zap.String("id", state.ID), + zap.Error(err), + ) } else { c.Location(location) } @@ -188,17 +186,17 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages [get] // -// Get message history +// Get message history. func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { - params := thirdPartyGetQueryParams{} - if err := h.QueryParserValidator(c, ¶ms); err != nil { - return err + params := new(thirdPartyGetQueryParams) + if err := h.QueryParserValidator(c, params); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } messages, total, err := h.messagesSvc.SelectStates(user, params.ToFilter(), params.ToOptions()) if err != nil { - h.Logger.Error("Failed to get message history", zap.Error(err), zap.String("user_id", user.ID)) - return fiber.NewError(fiber.StatusInternalServerError, "Failed to retrieve message history") + h.Logger.Error("failed to get message history", zap.Error(err), zap.String("user_id", user.ID)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to retrieve message history") } c.Set("X-Total-Count", strconv.Itoa(int(total))) @@ -219,7 +217,7 @@ func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages/{id} [get] // -// Get message state +// Get message state. func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { id := c.Params("id") @@ -229,7 +227,8 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { return fiber.NewError(fiber.StatusNotFound, err.Error()) } - return err + h.Logger.Error("failed to get message state", zap.Error(err), zap.String("user_id", user.ID)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to get message state") } return c.JSON(converters.MessageStateToDTO(*state)) @@ -248,11 +247,11 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages/inbox/export [post] // -// Export inbox +// Export inbox. func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error { - req := smsgateway.MessagesExportRequest{} - if err := h.BodyParserValidator(c, &req); err != nil { - return err + req := new(smsgateway.MessagesExportRequest) + if err := h.BodyParserValidator(c, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } device, err := h.devicesSvc.Get(user.ID, devices.WithID(req.DeviceID)) @@ -261,31 +260,64 @@ func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) e return fiber.NewError(fiber.StatusBadRequest, "Invalid device ID") } - return err + h.Logger.Error("failed to get device", zap.Error(err), zap.String("user_id", user.ID)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to get device") } - if err := h.messagesSvc.ExportInbox(device, req.Since, req.Until); err != nil { - return err + if expErr := h.messagesSvc.ExportInbox(device, req.Since, req.Until); expErr != nil { + h.Logger.Error("failed to export inbox", zap.Error(expErr), zap.String("user_id", user.ID)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to export inbox") } return c.SendStatus(fiber.StatusAccepted) } +func (h *ThirdPartyController) errorHandler(c *fiber.Ctx) error { + err := c.Next() + if err == nil { + return nil + } + + var fiberError *fiber.Error + if errors.As(err, &fiberError) { + return fiberError + } + + var msgValidationError messages.ValidationError + switch { + case errors.As(err, &msgValidationError): + fallthrough + case errors.Is(err, messages.ErrMultipleMessagesFound): + fallthrough + case errors.Is(err, messages.ErrNoContent): + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + + case errors.Is(err, messages.ErrMessageNotFound): + return fiber.NewError(fiber.StatusNotFound, err.Error()) + + case errors.Is(err, messages.ErrMessageAlreadyExists): + return fiber.NewError(fiber.StatusConflict, err.Error()) + + case errors.Is(err, devices.ErrNotFound): + fallthrough + case errors.Is(err, devices.ErrInvalidFilter): + fallthrough + case errors.Is(err, devices.ErrInvalidUser): + fallthrough + case errors.Is(err, devices.ErrMoreThanOne): + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } + + h.Logger.Error("failed to handle request", zap.Error(err)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to handle request") +} + func (h *ThirdPartyController) Register(router fiber.Router) { + router.Use(h.errorHandler) + router.Get("", userauth.WithUser(h.list)) router.Post("", userauth.WithUser(h.post)) router.Get(":id", userauth.WithUser(h.get)).Name(route3rdPartyGetMessage) router.Post("inbox/export", userauth.WithUser(h.postInboxExport)) } - -func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { - return &ThirdPartyController{ - Handler: base.Handler{ - Logger: params.Logger.Named("messages"), - Validator: params.Validator, - }, - messagesSvc: params.MessagesSvc, - devicesSvc: params.DevicesSvc, - } -} diff --git a/internal/sms-gateway/handlers/messages/mobile.go b/internal/sms-gateway/handlers/messages/mobile.go index 173bc772..d6f174a6 100644 --- a/internal/sms-gateway/handlers/messages/mobile.go +++ b/internal/sms-gateway/handlers/messages/mobile.go @@ -32,29 +32,39 @@ type MobileController struct { messagesSvc *messages.Service } +func NewMobileController(params mobileControllerParams) *MobileController { + return &MobileController{ + Handler: base.Handler{ + Logger: params.Logger, + Validator: params.Validator, + }, + messagesSvc: params.MessagesSvc, + } +} + // @Summary Get messages for sending // @Description Returns list of pending messages // @Security MobileToken // @Tags Device, Messages // @Accept json // @Produce json -// @Param order query string false "Message processing order: lifo (default) or fifo" Enums(lifo,fifo) default(lifo) +// @Param order query string false "Message processing order: lifo (default) or fifo" Enums(lifo,fifo) default(lifo) // @Success 200 {object} smsgateway.MobileGetMessagesResponse "List of pending messages" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/message [get] // -// Get messages for sending +// Get messages for sending. func (h *MobileController) list(device models.Device, c *fiber.Ctx) error { // Get and validate order parameter - params := mobileGetQueryParams{} - if err := h.QueryParserValidator(c, ¶ms); err != nil { - return err + params := new(mobileGetQueryParams) + if err := h.QueryParserValidator(c, params); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } msgs, err := h.messagesSvc.SelectPending(device.ID, params.OrderOrDefault()) if err != nil { - return fmt.Errorf("can't get messages: %w", err) + return fmt.Errorf("failed to get messages: %w", err) } return c.JSON( @@ -79,11 +89,11 @@ func (h *MobileController) list(device models.Device, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/message [patch] // -// Update message state +// Update message state. func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error { req := smsgateway.MobilePatchMessageRequest{} if err := h.BodyParserValidator(c, &req); err != nil { - return err + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } for _, v := range req { @@ -96,7 +106,7 @@ func (h *MobileController) patch(device models.Device, c *fiber.Ctx) error { err := h.messagesSvc.UpdateState(&device, messageState) if err != nil && !errors.Is(err, messages.ErrMessageNotFound) { - h.Logger.Error("Can't update message status", + h.Logger.Error("failed to update message status", zap.String("message_id", v.ID), zap.Error(err), ) @@ -110,13 +120,3 @@ func (h *MobileController) Register(router fiber.Router) { router.Get("", deviceauth.WithDevice(h.list)) router.Patch("", deviceauth.WithDevice(h.patch)) } - -func NewMobileController(params mobileControllerParams) *MobileController { - return &MobileController{ - Handler: base.Handler{ - Logger: params.Logger.Named("messages"), - Validator: params.Validator, - }, - messagesSvc: params.MessagesSvc, - } -} diff --git a/internal/sms-gateway/handlers/messages/params.go b/internal/sms-gateway/handlers/messages/params.go index b98de6aa..9a16c028 100644 --- a/internal/sms-gateway/handlers/messages/params.go +++ b/internal/sms-gateway/handlers/messages/params.go @@ -1,7 +1,7 @@ package messages import ( - "fmt" + "errors" "time" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/messages" @@ -9,28 +9,28 @@ import ( type thirdPartyPostQueryParams struct { SkipPhoneValidation bool `query:"skipPhoneValidation"` - DeviceActiveWithin uint `query:"deviceActiveWithin"` + DeviceActiveWithin int `query:"deviceActiveWithin" validate:"omitempty,min=1"` } type thirdPartyGetQueryParams struct { - StartDate string `query:"from" validate:"omitempty,datetime=2006-01-02T15:04:05Z07:00"` - EndDate string `query:"to" validate:"omitempty,datetime=2006-01-02T15:04:05Z07:00"` - State string `query:"state" validate:"omitempty,oneof=Pending Processed Sent Delivered Failed"` + StartDate string `query:"from" validate:"omitempty,datetime=2006-01-02T15:04:05Z07:00"` + EndDate string `query:"to" validate:"omitempty,datetime=2006-01-02T15:04:05Z07:00"` + State string `query:"state" validate:"omitempty,oneof=Pending Processed Sent Delivered Failed"` DeviceID string `query:"deviceId" validate:"omitempty,len=21"` - Limit int `query:"limit" validate:"omitempty,min=1,max=100"` - Offset int `query:"offset" validate:"omitempty,min=0"` + Limit int `query:"limit" validate:"omitempty,min=1,max=100"` + Offset int `query:"offset" validate:"omitempty,min=0"` } func (p *thirdPartyGetQueryParams) Validate() error { if p.StartDate != "" && p.EndDate != "" && p.StartDate > p.EndDate { - return fmt.Errorf("`from` date must be before `to` date") + return errors.New("`from` date must be before `to` date") //nolint:err113 // won't be used directly } return nil } -func (p *thirdPartyGetQueryParams) ToFilter() messages.MessagesSelectFilter { - filter := messages.MessagesSelectFilter{} +func (p *thirdPartyGetQueryParams) ToFilter() messages.SelectFilter { + var filter messages.SelectFilter if p.StartDate != "" { if t, err := time.Parse(time.RFC3339, p.StartDate); err == nil { @@ -55,14 +55,15 @@ func (p *thirdPartyGetQueryParams) ToFilter() messages.MessagesSelectFilter { return filter } -func (p *thirdPartyGetQueryParams) ToOptions() messages.MessagesSelectOptions { - options := messages.MessagesSelectOptions{ - WithRecipients: true, - WithStates: true, - } +func (p *thirdPartyGetQueryParams) ToOptions() messages.SelectOptions { + const maxLimit = 100 + + var options messages.SelectOptions + options.WithRecipients = true + options.WithStates = true if p.Limit > 0 { - options.Limit = min(p.Limit, 100) + options.Limit = min(p.Limit, maxLimit) } else { options.Limit = 50 } @@ -75,13 +76,12 @@ func (p *thirdPartyGetQueryParams) ToOptions() messages.MessagesSelectOptions { } type mobileGetQueryParams struct { - Order messages.MessagesOrder `query:"order" validate:"omitempty,oneof=lifo fifo"` + Order messages.Order `query:"order" validate:"omitempty,oneof=lifo fifo"` } -func (p *mobileGetQueryParams) OrderOrDefault() messages.MessagesOrder { +func (p *mobileGetQueryParams) OrderOrDefault() messages.Order { if p.Order != "" { return p.Order } return messages.MessagesOrderLIFO - } diff --git a/internal/sms-gateway/handlers/middlewares/userauth/userauth.go b/internal/sms-gateway/handlers/middlewares/userauth/userauth.go index 6b88fbdb..0b69e3a9 100644 --- a/internal/sms-gateway/handlers/middlewares/userauth/userauth.go +++ b/internal/sms-gateway/handlers/middlewares/userauth/userauth.go @@ -87,16 +87,16 @@ func NewCode(authSvc *auth.Service) fiber.Handler { // It returns true if the Locals contain a user under the key LocalsUser, // otherwise returns false. func HasUser(c *fiber.Ctx) bool { - return c.Locals(localsUser) != nil + return GetUser(c) != nil } // GetUser returns the user stored in the Locals under the key LocalsUser. -// It is a convenience function that wraps the call to c.Locals(LocalsUser) and -// casts the result to models.User. -// -// It panics if the value stored in Locals is not a models.User. -func GetUser(c *fiber.Ctx) models.User { - return c.Locals(localsUser).(models.User) +func GetUser(c *fiber.Ctx) *models.User { + if user, ok := c.Locals(localsUser).(*models.User); ok { + return user + } + + return nil } // UserRequired is a middleware that ensures a user is present in the request's Locals. @@ -113,13 +113,13 @@ func UserRequired() fiber.Handler { } // WithUser is a decorator that provides the current user to the handler. -// It assumes that the user is stored in the Locals under the key LocalsUser. -// If the user is not present, it will panic. -// -// It is a convenience function that wraps the call to GetUser and calls the -// handler with the user as the first argument. func WithUser(handler func(models.User, *fiber.Ctx) error) fiber.Handler { return func(c *fiber.Ctx) error { - return handler(GetUser(c), c) + user := GetUser(c) + if user == nil { + return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized") + } + + return handler(*user, c) } } diff --git a/internal/sms-gateway/handlers/mobile.go b/internal/sms-gateway/handlers/mobile.go index d9f04fc5..1f53bf21 100644 --- a/internal/sms-gateway/handlers/mobile.go +++ b/internal/sms-gateway/handlers/mobile.go @@ -62,10 +62,11 @@ type mobileHandler struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/device [get] // -// Get device information +// Get device information. func (h *mobileHandler) getDevice(device models.Device, c *fiber.Ctx) error { res := smsgateway.MobileDeviceResponse{ ExternalIP: c.IP(), + Device: nil, } if !device.IsEmpty() { @@ -91,16 +92,17 @@ func (h *mobileHandler) getDevice(device models.Device, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/device [post] // -// Register device -func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) { - req := smsgateway.MobileRegisterRequest{} +// Register device. +func (h *mobileHandler) postDevice(c *fiber.Ctx) error { + req := new(smsgateway.MobileRegisterRequest) - if err = h.BodyParserValidator(c, &req); err != nil { - return err + if err := h.BodyParserValidator(c, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } var ( - user models.User + err error + user *models.User login string password string ) @@ -115,13 +117,13 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) { user, err = h.authSvc.RegisterUser(login, password) if err != nil { - return fmt.Errorf("can't create user: %w", err) + return fmt.Errorf("failed to create user: %w", err) } } device, err := h.authSvc.RegisterDevice(user, req.Name, req.PushToken) if err != nil { - return fmt.Errorf("can't register device: %w", err) + return fmt.Errorf("failed to register device: %w", err) } return c.Status(fiber.StatusCreated). @@ -145,12 +147,12 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) (err error) { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/device [patch] // -// Update device +// Update device. func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { - req := smsgateway.MobileUpdateRequest{} + req := new(smsgateway.MobileUpdateRequest) - if err := h.BodyParserValidator(c, &req); err != nil { - return err + if err := h.BodyParserValidator(c, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } if req.Id != device.ID { @@ -158,7 +160,8 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { } if err := h.devicesSvc.UpdatePushToken(req.Id, req.PushToken); err != nil { - return err + h.Logger.Error("failed to update device", zap.Error(err), zap.String("device_id", req.Id)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to update device") } return c.SendStatus(fiber.StatusNoContent) @@ -174,11 +177,12 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/user/code [get] // -// Get user code +// Get user code. func (h *mobileHandler) getUserCode(user models.User, c *fiber.Ctx) error { code, err := h.authSvc.GenerateUserCode(user.ID) if err != nil { - return err + h.Logger.Error("failed to generate user code", zap.Error(err), zap.String("user_id", user.ID)) + return fiber.NewError(fiber.StatusInternalServerError, "failed to generate user code") } return c.JSON(smsgateway.MobileUserCodeResponse{ @@ -200,12 +204,12 @@ func (h *mobileHandler) getUserCode(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/user/password [patch] // -// Change password +// Change password. func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error { - req := smsgateway.MobileChangePasswordRequest{} + req := new(smsgateway.MobileChangePasswordRequest) - if err := h.BodyParserValidator(c, &req); err != nil { - return err + if err := h.BodyParserValidator(c, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil { @@ -229,9 +233,13 @@ func (h *mobileHandler) Register(router fiber.Router) { // 2. User is already authenticated - allowing device registration for existing users return h.authSvc.IsPublic() || userauth.HasUser(c) }, - Validator: func(c *fiber.Ctx, token string) (bool, error) { + Validator: func(_ *fiber.Ctx, token string) (bool, error) { err := h.authSvc.AuthorizeRegistration(token) - return err == nil, err + if err != nil { + return false, fmt.Errorf("authorization failed: %w", err) + } + + return true, nil }, }), h.postDevice, @@ -264,7 +272,8 @@ func (h *mobileHandler) Register(router fiber.Router) { } func newMobileHandler(params mobileHandlerParams) *mobileHandler { - idGen, _ := nanoid.Standard(21) + const idGenSize = 21 + idGen, _ := nanoid.Standard(idGenSize) return &mobileHandler{ Handler: base.Handler{Logger: params.Logger, Validator: params.Validator}, diff --git a/internal/sms-gateway/handlers/module.go b/internal/sms-gateway/handlers/module.go index 52eb7173..83440302 100644 --- a/internal/sms-gateway/handlers/module.go +++ b/internal/sms-gateway/handlers/module.go @@ -12,28 +12,30 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "handlers", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("handlers") - }), - fx.Provide( - http.AsRootHandler(newRootHandler), - http.AsApiHandler(newThirdPartyHandler), - http.AsApiHandler(newMobileHandler), - http.AsApiHandler(newUpstreamHandler), - ), - fx.Provide( - NewHealthHandler, - messages.NewThirdPartyController, - messages.NewMobileController, - webhooks.NewThirdPartyController, - webhooks.NewMobileController, - devices.NewThirdPartyController, - settings.NewThirdPartyController, - settings.NewMobileController, - logs.NewThirdPartyController, - events.NewMobileController, - fx.Private, - ), -) +func Module() fx.Option { + return fx.Module( + "handlers", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("handlers") + }), + fx.Provide( + http.AsRootHandler(newRootHandler), + http.AsApiHandler(newThirdPartyHandler), + http.AsApiHandler(newMobileHandler), + http.AsApiHandler(newUpstreamHandler), + ), + fx.Provide( + NewHealthHandler, + messages.NewThirdPartyController, + messages.NewMobileController, + webhooks.NewThirdPartyController, + webhooks.NewMobileController, + devices.NewThirdPartyController, + settings.NewThirdPartyController, + settings.NewMobileController, + logs.NewThirdPartyController, + events.NewMobileController, + fx.Private, + ), + ) +} diff --git a/internal/sms-gateway/handlers/root.go b/internal/sms-gateway/handlers/root.go index 778d28a2..434f7e8c 100644 --- a/internal/sms-gateway/handlers/root.go +++ b/internal/sms-gateway/handlers/root.go @@ -25,7 +25,7 @@ func (h *rootHandler) Register(app *fiber.App) { c.Set(fiber.HeaderLocation, path.Join(h.config.PublicPath, after)) } - return err + return err //nolint:wrapcheck // passed through to fiber's error handler }) } diff --git a/internal/sms-gateway/handlers/settings/3rdparty.go b/internal/sms-gateway/handlers/settings/3rdparty.go index a95879de..6c784915 100644 --- a/internal/sms-gateway/handlers/settings/3rdparty.go +++ b/internal/sms-gateway/handlers/settings/3rdparty.go @@ -32,6 +32,17 @@ type ThirdPartyController struct { settingsSvc *settings.Service } +func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { + return &ThirdPartyController{ + Handler: base.Handler{ + Logger: params.Logger, + Validator: params.Validator, + }, + devicesSvc: params.DevicesSvc, + settingsSvc: params.SettingsSvc, + } +} + // @Summary Get settings // @Description Returns settings for a specific user // @Security ApiAuth @@ -42,11 +53,11 @@ type ThirdPartyController struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/settings [get] // -// Get settings +// Get settings. func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { settings, err := h.settingsSvc.GetSettings(user.ID, true) if err != nil { - return fmt.Errorf("can't get settings: %w", err) + return fmt.Errorf("failed to get settings: %w", err) } return c.JSON(settings) @@ -65,22 +76,22 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/settings [put] // -// Update settings +// Update settings. func (h *ThirdPartyController) put(user models.User, c *fiber.Ctx) error { - if err := h.BodyParserValidator(c, &smsgateway.DeviceSettings{}); err != nil { + if err := h.BodyParserValidator(c, new(smsgateway.DeviceSettings)); err != nil { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid settings format: %v", err)) } - settings := make(map[string]any, 8) + settings := make(map[string]any) if err := c.BodyParser(&settings); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Failed to parse request body: %v", err)) + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse request body: %v", err)) } updated, err := h.settingsSvc.ReplaceSettings(user.ID, settings) if err != nil { - return fmt.Errorf("can't update settings: %w", err) + return fmt.Errorf("failed to update settings: %w", err) } return c.JSON(updated) @@ -99,21 +110,21 @@ func (h *ThirdPartyController) put(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/settings [patch] // -// Partially update settings +// Partially update settings. func (h *ThirdPartyController) patch(user models.User, c *fiber.Ctx) error { - if err := h.BodyParserValidator(c, &smsgateway.DeviceSettings{}); err != nil { + if err := h.BodyParserValidator(c, new(smsgateway.DeviceSettings)); err != nil { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid settings format: %v", err)) } - settings := make(map[string]any, 8) + settings := make(map[string]any) if err := c.BodyParser(&settings); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Failed to parse request body: %v", err)) + return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse request body: %v", err)) } updated, err := h.settingsSvc.UpdateSettings(user.ID, settings) if err != nil { - return fmt.Errorf("can't update settings: %w", err) + return fmt.Errorf("failed to update settings: %w", err) } return c.JSON(updated) @@ -124,14 +135,3 @@ func (h *ThirdPartyController) Register(app fiber.Router) { app.Patch("", userauth.WithUser(h.patch)) app.Put("", userauth.WithUser(h.put)) } - -func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { - return &ThirdPartyController{ - Handler: base.Handler{ - Logger: params.Logger.Named("settings"), - Validator: params.Validator, - }, - devicesSvc: params.DevicesSvc, - settingsSvc: params.SettingsSvc, - } -} diff --git a/internal/sms-gateway/handlers/settings/mobile.go b/internal/sms-gateway/handlers/settings/mobile.go index 24ad84db..38c4fb32 100644 --- a/internal/sms-gateway/handlers/settings/mobile.go +++ b/internal/sms-gateway/handlers/settings/mobile.go @@ -1,27 +1,16 @@ package settings import ( - "fmt" - "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/deviceauth" "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/settings" + "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" - "go.uber.org/fx" "go.uber.org/zap" ) -type mobileControllerParams struct { - fx.In - - DevicesSvc *devices.Service - SettingsSvc *settings.Service - - Logger *zap.Logger -} - type MobileController struct { base.Handler @@ -29,6 +18,22 @@ type MobileController struct { settingsSvc *settings.Service } +func NewMobileController( + devicesSvc *devices.Service, + settingsSvc *settings.Service, + logger *zap.Logger, + validator *validator.Validate, +) *MobileController { + return &MobileController{ + Handler: base.Handler{ + Logger: logger, + Validator: validator, + }, + devicesSvc: devicesSvc, + settingsSvc: settingsSvc, + } +} + // @Summary Get settings // @Description Returns settings for a device // @Security MobileToken @@ -39,11 +44,17 @@ type MobileController struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/settings [get] // -// Get settings +// Get settings. func (h *MobileController) get(device models.Device, c *fiber.Ctx) error { settings, err := h.settingsSvc.GetSettings(device.UserID, false) if err != nil { - return fmt.Errorf("can't get settings for device %s (user ID: %s): %w", device.ID, device.UserID, err) + h.Logger.Error( + "failed to get settings", + zap.Error(err), + zap.String("device_id", device.ID), + zap.String("user_id", device.UserID), + ) + return fiber.NewError(fiber.StatusInternalServerError, "failed to get settings") } return c.JSON(settings) @@ -52,13 +63,3 @@ func (h *MobileController) get(device models.Device, c *fiber.Ctx) error { func (h *MobileController) Register(router fiber.Router) { router.Get("", deviceauth.WithDevice(h.get)) } - -func NewMobileController(params mobileControllerParams) *MobileController { - return &MobileController{ - Handler: base.Handler{ - Logger: params.Logger.Named("settings"), - }, - devicesSvc: params.DevicesSvc, - settingsSvc: params.SettingsSvc, - } -} diff --git a/internal/sms-gateway/handlers/upstream.go b/internal/sms-gateway/handlers/upstream.go index fb010c8a..4690147d 100644 --- a/internal/sms-gateway/handlers/upstream.go +++ b/internal/sms-gateway/handlers/upstream.go @@ -51,7 +51,7 @@ func newUpstreamHandler(params upstreamHandlerParams) *upstreamHandler { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /upstream/v1/push [post] // -// Send push notifications +// Send push notifications. func (h *upstreamHandler) postPush(c *fiber.Ctx) error { req := smsgateway.UpstreamPushRequest{} @@ -65,7 +65,7 @@ func (h *upstreamHandler) postPush(c *fiber.Ctx) error { for _, v := range req { if err := h.ValidateStruct(v); err != nil { - return err + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } event := push.Event{ @@ -74,7 +74,7 @@ func (h *upstreamHandler) postPush(c *fiber.Ctx) error { } if err := h.pushSvc.Enqueue(v.Token, event); err != nil { - h.Logger.Error("Can't push message", zap.Error(err)) + h.Logger.Error("failed to push message", zap.Error(err)) } } @@ -94,9 +94,13 @@ func (h *upstreamHandler) Register(router fiber.Router) { router = router.Group("/upstream/v1") + const ( + rateLimit = 5 + rateTime = 60 * time.Second + ) router.Post("/push", limiter.New(limiter.Config{ - Max: 5, - Expiration: 60 * time.Second, + Max: rateLimit, + Expiration: rateTime, LimiterMiddleware: limiter.SlidingWindow{}, }), h.postPush) } diff --git a/internal/sms-gateway/handlers/webhooks/3rdparty.go b/internal/sms-gateway/handlers/webhooks/3rdparty.go index 450f0430..a7152c17 100644 --- a/internal/sms-gateway/handlers/webhooks/3rdparty.go +++ b/internal/sms-gateway/handlers/webhooks/3rdparty.go @@ -29,6 +29,16 @@ type ThirdPartyController struct { webhooksSvc *webhooks.Service } +func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { + return &ThirdPartyController{ + Handler: base.Handler{ + Logger: params.Logger, + Validator: params.Validator, + }, + webhooksSvc: params.WebhooksSvc, + } +} + // @Summary List webhooks // @Description Returns list of registered webhooks // @Security ApiAuth @@ -39,11 +49,11 @@ type ThirdPartyController struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks [get] // -// List webhooks +// List webhooks. func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { items, err := h.webhooksSvc.Select(user.ID) if err != nil { - return fmt.Errorf("can't select webhooks: %w", err) + return fmt.Errorf("failed to select webhooks: %w", err) } return c.JSON(items) @@ -62,12 +72,12 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks [post] // -// Register webhook +// Register webhook. func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { - dto := smsgateway.Webhook{} + dto := new(smsgateway.Webhook) - if err := h.BodyParserValidator(c, &dto); err != nil { - return err + if err := h.BodyParserValidator(c, dto); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) } if err := h.webhooksSvc.Replace(user.ID, dto); err != nil { @@ -75,7 +85,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { return fiber.NewError(fiber.StatusBadRequest, err.Error()) } - return fmt.Errorf("can't write webhook: %w", err) + return fmt.Errorf("failed to write webhook: %w", err) } return c.Status(fiber.StatusCreated).JSON(dto) @@ -92,12 +102,12 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks/{id} [delete] // -// Delete webhook +// Delete webhook. func (h *ThirdPartyController) delete(user models.User, c *fiber.Ctx) error { id := c.Params("id") if err := h.webhooksSvc.Delete(user.ID, webhooks.WithExtID(id)); err != nil { - return fmt.Errorf("can't delete webhook: %w", err) + return fmt.Errorf("failed to delete webhook: %w", err) } return c.SendStatus(fiber.StatusNoContent) @@ -108,13 +118,3 @@ func (h *ThirdPartyController) Register(router fiber.Router) { router.Post("", userauth.WithUser(h.post)) router.Delete("/:id", userauth.WithUser(h.delete)) } - -func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyController { - return &ThirdPartyController{ - Handler: base.Handler{ - Logger: params.Logger.Named("webhooks"), - Validator: params.Validator, - }, - webhooksSvc: params.WebhooksSvc, - } -} diff --git a/internal/sms-gateway/handlers/webhooks/mobile.go b/internal/sms-gateway/handlers/webhooks/mobile.go index 8d89dc1d..bfbaefd7 100644 --- a/internal/sms-gateway/handlers/webhooks/mobile.go +++ b/internal/sms-gateway/handlers/webhooks/mobile.go @@ -7,25 +7,31 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/deviceauth" "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/webhooks" + "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" - "go.uber.org/fx" "go.uber.org/zap" ) -type mobileControllerParams struct { - fx.In - - WebhooksServices *webhooks.Service - - Logger *zap.Logger -} - type MobileController struct { base.Handler webhooksSvc *webhooks.Service } +func NewMobileController( + webhooksSvc *webhooks.Service, + logger *zap.Logger, + validator *validator.Validate, +) *MobileController { + return &MobileController{ + Handler: base.Handler{ + Logger: logger, + Validator: validator, + }, + webhooksSvc: webhooksSvc, + } +} + // @Summary List webhooks // @Description Returns list of registered webhooks for device // @Security MobileToken @@ -36,11 +42,11 @@ type MobileController struct { // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /mobile/v1/webhooks [get] // -// List webhooks +// List webhooks. func (h *MobileController) get(device models.Device, c *fiber.Ctx) error { items, err := h.webhooksSvc.Select(device.UserID, webhooks.WithDeviceID(device.ID, false)) if err != nil { - return fmt.Errorf("can't select webhooks: %w", err) + return fmt.Errorf("failed to select webhooks: %w", err) } return c.JSON(items) @@ -49,12 +55,3 @@ func (h *MobileController) get(device models.Device, c *fiber.Ctx) error { func (h *MobileController) Register(router fiber.Router) { router.Get("", deviceauth.WithDevice(h.get)) } - -func NewMobileController(params mobileControllerParams) *MobileController { - return &MobileController{ - Handler: base.Handler{ - Logger: params.Logger.Named("webhooks"), - }, - webhooksSvc: params.WebhooksServices, - } -} diff --git a/internal/sms-gateway/models/migration.go b/internal/sms-gateway/models/migration.go index 1ab9851c..aee9173a 100644 --- a/internal/sms-gateway/models/migration.go +++ b/internal/sms-gateway/models/migration.go @@ -2,6 +2,7 @@ package models import ( "embed" + "fmt" "gorm.io/gorm" ) @@ -10,5 +11,8 @@ import ( var migrations embed.FS func Migrate(db *gorm.DB) error { - return db.AutoMigrate(&User{}, &Device{}) + if err := db.AutoMigrate(new(User), new(Device)); err != nil { + return fmt.Errorf("models migration failed: %w", err) + } + return nil } diff --git a/internal/sms-gateway/models/models.go b/internal/sms-gateway/models/models.go index 058963ea..dd28bb42 100644 --- a/internal/sms-gateway/models/models.go +++ b/internal/sms-gateway/models/models.go @@ -11,18 +11,29 @@ type TimedModel struct { type SoftDeletableModel struct { TimedModel + DeletedAt *time.Time `gorm:"<-:update"` } type User struct { + SoftDeletableModel + ID string `gorm:"primaryKey;type:varchar(32)"` PasswordHash string `gorm:"not null;type:varchar(72)"` Devices []Device `gorm:"-,foreignKey:UserID;constraint:OnDelete:CASCADE"` +} - SoftDeletableModel +func NewUser(id, passwordHash string) *User { + //nolint:exhaustruct // pertial constructor + return &User{ + ID: id, + PasswordHash: passwordHash, + } } type Device struct { + SoftDeletableModel + ID string `gorm:"primaryKey;type:char(21)"` Name *string `gorm:"type:varchar(128)"` AuthToken string `gorm:"not null;uniqueIndex;type:char(21)"` @@ -31,8 +42,14 @@ type Device struct { LastSeen time.Time `gorm:"not null;autocreatetime:false;default:CURRENT_TIMESTAMP(3);index:idx_devices_last_seen"` UserID string `gorm:"not null;type:varchar(32)"` +} - SoftDeletableModel +func NewDevice(name, pushToken *string) *Device { + //nolint:exhaustruct // partial constructor + return &Device{ + Name: name, + PushToken: pushToken, + } } func (d *Device) IsEmpty() bool { diff --git a/internal/sms-gateway/models/module.go b/internal/sms-gateway/models/module.go index de6510f1..62582b38 100644 --- a/internal/sms-gateway/models/module.go +++ b/internal/sms-gateway/models/module.go @@ -4,6 +4,7 @@ import ( "github.com/capcom6/go-infra-fx/db" ) +//nolint:gochecknoinits // framework-specific func init() { db.RegisterMigration(Migrate) db.RegisterGoose(migrations) diff --git a/internal/sms-gateway/modules/db/health.go b/internal/sms-gateway/modules/db/health.go index ea4f3bc3..0e9078ee 100644 --- a/internal/sms-gateway/modules/db/health.go +++ b/internal/sms-gateway/modules/db/health.go @@ -17,6 +17,8 @@ type health struct { func newHealth(db *sql.DB) *health { return &health{ db: db, + + failedPings: atomic.Int64{}, } } @@ -26,8 +28,8 @@ func (h *health) Name() string { } // LiveProbe implements HealthProvider. -func (h *health) LiveProbe(ctx context.Context) (healthmod.Checks, error) { - return nil, nil +func (h *health) LiveProbe(_ context.Context) (healthmod.Checks, error) { + return nil, nil //nolint:nilnil // empty result } // ReadyProbe implements HealthProvider. @@ -53,8 +55,8 @@ func (h *health) ReadyProbe(ctx context.Context) (healthmod.Checks, error) { } // StartedProbe implements HealthProvider. -func (h *health) StartedProbe(ctx context.Context) (healthmod.Checks, error) { - return nil, nil +func (h *health) StartedProbe(_ context.Context) (healthmod.Checks, error) { + return nil, nil //nolint:nilnil // empty result } -var _ healthmod.HealthProvider = (*health)(nil) +var _ healthmod.Provider = (*health)(nil) diff --git a/internal/sms-gateway/modules/db/module.go b/internal/sms-gateway/modules/db/module.go index 7a932e3d..e155581a 100644 --- a/internal/sms-gateway/modules/db/module.go +++ b/internal/sms-gateway/modules/db/module.go @@ -7,6 +7,10 @@ import ( healthmod "github.com/android-sms-gateway/server/pkg/health" ) +const ( + idSize = 21 +) + type IDGen func() string func Module() fx.Option { @@ -16,7 +20,7 @@ func Module() fx.Option { healthmod.AsHealthProvider(newHealth), ), fx.Provide(func() (IDGen, error) { - return nanoid.Standard(21) + return nanoid.Standard(idSize) }), ) } diff --git a/internal/sms-gateway/modules/devices/repository.go b/internal/sms-gateway/modules/devices/repository.go index 6081b5b5..4164e017 100644 --- a/internal/sms-gateway/modules/devices/repository.go +++ b/internal/sms-gateway/modules/devices/repository.go @@ -11,7 +11,7 @@ import ( ) var ( - ErrNotFound = gorm.ErrRecordNotFound + ErrNotFound = errors.New("record not found") ErrInvalidFilter = errors.New("invalid filter") ErrMoreThanOne = errors.New("more than one record") ) @@ -43,7 +43,7 @@ func (r *Repository) Select(filter ...SelectFilter) ([]models.Device, error) { // error during the query, it returns false and the error. Otherwise, it returns // true and nil error. func (r *Repository) Exists(filters ...SelectFilter) (bool, error) { - err := newFilter(filters...).apply(r.db).Take(&models.Device{}).Error + err := newFilter(filters...).apply(r.db).Take(new(models.Device)).Error if errors.Is(err, gorm.ErrRecordNotFound) { return false, nil } @@ -75,7 +75,7 @@ func (r *Repository) Insert(device *models.Device) error { } func (r *Repository) UpdatePushToken(id, token string) error { - res := r.db.Model(&models.Device{}).Where("id = ?", id).Update("push_token", token) + res := r.db.Model((*models.Device)(nil)).Where("id = ?", id).Update("push_token", token) if res.Error != nil { return fmt.Errorf("failed to update device: %w", res.Error) } @@ -88,7 +88,7 @@ func (r *Repository) SetLastSeen(ctx context.Context, id string, lastSeen time.T return nil // ignore zero timestamps } res := r.db.WithContext(ctx). - Model(&models.Device{}). + Model((*models.Device)(nil)). Where("id = ? AND last_seen < ?", id, lastSeen). UpdateColumn("last_seen", lastSeen) if res.Error != nil { @@ -105,14 +105,14 @@ func (r *Repository) Remove(filter ...SelectFilter) error { } f := newFilter(filter...) - return f.apply(r.db).Delete(&models.Device{}).Error + return f.apply(r.db).Delete(new(models.Device)).Error } func (r *Repository) Cleanup(ctx context.Context, until time.Time) (int64, error) { res := r.db. WithContext(ctx). Where("last_seen < ?", until). - Delete(&models.Device{}) + Delete(new(models.Device)) return res.RowsAffected, res.Error } diff --git a/internal/sms-gateway/modules/devices/repository_filter.go b/internal/sms-gateway/modules/devices/repository_filter.go index c86c3fb9..b20aaab1 100644 --- a/internal/sms-gateway/modules/devices/repository_filter.go +++ b/internal/sms-gateway/modules/devices/repository_filter.go @@ -40,7 +40,7 @@ type selectFilter struct { } func newFilter(filters ...SelectFilter) *selectFilter { - f := &selectFilter{} + f := new(selectFilter) f.merge(filters...) return f } diff --git a/internal/sms-gateway/modules/devices/service.go b/internal/sms-gateway/modules/devices/service.go index d306ecc1..d433539a 100644 --- a/internal/sms-gateway/modules/devices/service.go +++ b/internal/sms-gateway/modules/devices/service.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "math/rand/v2" "time" "github.com/android-sms-gateway/server/internal/sms-gateway/models" @@ -76,6 +77,35 @@ func (s *Service) Get(userID string, filter ...SelectFilter) (models.Device, err return s.devices.Get(filter...) } +func (s *Service) GetAny(userID string, deviceID string, duration time.Duration) (*models.Device, error) { + filter := []SelectFilter{ + WithUserID(userID), + } + if deviceID != "" { + filter = append(filter, WithID(deviceID)) + } + if duration > 0 { + filter = append(filter, ActiveWithin(duration)) + } + + devices, err := s.devices.Select(filter...) + if err != nil { + return nil, err + } + + if len(devices) == 0 { + return nil, ErrNotFound + } + + if len(devices) == 1 { + return &devices[0], nil + } + + idx := rand.IntN(len(devices)) //nolint:gosec //not critical + + return &devices[idx], nil +} + // GetByToken returns a device by token. // // This method is used to retrieve a device by its auth token. If the device @@ -88,8 +118,8 @@ func (s *Service) GetByToken(token string) (models.Device, error) { return device, err } - if err := s.cache.Set(device); err != nil { - s.logger.Error("can't cache device", zap.String("device_id", device.ID), zap.Error(err)) + if setErr := s.cache.Set(device); setErr != nil { + s.logger.Error("failed to cache device", zap.String("device_id", device.ID), zap.Error(setErr)) } } @@ -98,7 +128,7 @@ func (s *Service) GetByToken(token string) (models.Device, error) { func (s *Service) UpdatePushToken(id string, token string) error { if err := s.cache.DeleteByID(id); err != nil { - s.logger.Error("can't invalidate cache", + s.logger.Error("failed to invalidate cache", zap.String("device_id", id), zap.Error(err), ) @@ -123,7 +153,7 @@ func (s *Service) SetLastSeen(ctx context.Context, batch map[string]time.Time) e } if err := s.devices.SetLastSeen(ctx, deviceID, lastSeen); err != nil { multiErr = errors.Join(multiErr, fmt.Errorf("device %s: %w", deviceID, err)) - s.logger.Error("can't set last seen", + s.logger.Error("failed to set last seen", zap.String("device_id", deviceID), zap.Time("last_seen", lastSeen), zap.Error(err), @@ -147,16 +177,16 @@ func (s *Service) Remove(userID string, filter ...SelectFilter) error { } for _, device := range devices { - if err := s.cache.DeleteByID(device.ID); err != nil { - s.logger.Error("can't invalidate cache", + if cacheErr := s.cache.DeleteByID(device.ID); cacheErr != nil { + s.logger.Error("failed to invalidate cache", zap.String("device_id", device.ID), - zap.Error(err), + zap.Error(cacheErr), ) } } - if err := s.devices.Remove(filter...); err != nil { - return err + if rmErr := s.devices.Remove(filter...); rmErr != nil { + return rmErr } return nil diff --git a/internal/sms-gateway/modules/events/errors.go b/internal/sms-gateway/modules/events/errors.go new file mode 100644 index 00000000..6d5684a4 --- /dev/null +++ b/internal/sms-gateway/modules/events/errors.go @@ -0,0 +1,7 @@ +package events + +import "errors" + +var ( + ErrValidationFailed = errors.New("validation failed") +) diff --git a/internal/sms-gateway/modules/events/metrics.go b/internal/sms-gateway/modules/events/metrics.go index e2734679..37b28db3 100644 --- a/internal/sms-gateway/modules/events/metrics.go +++ b/internal/sms-gateway/modules/events/metrics.go @@ -5,7 +5,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -// Metric constants +// Metric constants. const ( MetricEnqueuedTotal = "enqueued_total" MetricSentTotal = "sent_total" @@ -26,14 +26,14 @@ const ( EventTypeUnknown = "unknown" ) -// metrics contains all Prometheus metrics for the events module +// metrics contains all Prometheus metrics for the events module. type metrics struct { enqueuedCounter *prometheus.CounterVec sentCounter *prometheus.CounterVec failedCounter *prometheus.CounterVec } -// newMetrics creates and initializes all events metrics +// newMetrics creates and initializes all events metrics. func newMetrics() *metrics { return &metrics{ enqueuedCounter: promauto.NewCounterVec(prometheus.CounterOpts{ @@ -57,17 +57,17 @@ func newMetrics() *metrics { } } -// IncrementEnqueued increments the enqueued counter for the given event type +// IncrementEnqueued increments the enqueued counter for the given event type. func (m *metrics) IncrementEnqueued(eventType string) { m.enqueuedCounter.WithLabelValues(eventType).Inc() } -// IncrementSent increments the sent counter for the given event type and delivery type +// IncrementSent increments the sent counter for the given event type and delivery type. func (m *metrics) IncrementSent(eventType string, deliveryType string) { m.sentCounter.WithLabelValues(eventType, deliveryType).Inc() } -// IncrementFailed increments the failed counter for the given event type, delivery type, and reason +// IncrementFailed increments the failed counter for the given event type, delivery type, and reason. func (m *metrics) IncrementFailed(eventType string, deliveryType string, reason string) { m.failedCounter.WithLabelValues(eventType, deliveryType, reason).Inc() } diff --git a/internal/sms-gateway/modules/events/module.go b/internal/sms-gateway/modules/events/module.go index 8f6cf71a..96f07049 100644 --- a/internal/sms-gateway/modules/events/module.go +++ b/internal/sms-gateway/modules/events/module.go @@ -7,31 +7,33 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "events", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("events") - }), - fx.Provide(newMetrics, fx.Private), - fx.Provide(NewService), - fx.Invoke(func(lc fx.Lifecycle, svc *Service, logger *zap.Logger, sh fx.Shutdowner) { - ctx, cancel := context.WithCancel(context.Background()) - lc.Append(fx.Hook{ - OnStart: func(_ context.Context) error { - go func() { - if err := svc.Run(ctx); err != nil { - logger.Error("Error running events service", zap.Error(err)) - if err := sh.Shutdown(fx.ExitCode(1)); err != nil { - logger.Error("Failed to shutdown", zap.Error(err)) +func Module() fx.Option { + return fx.Module( + "events", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("events") + }), + fx.Provide(newMetrics, fx.Private), + fx.Provide(NewService), + fx.Invoke(func(lc fx.Lifecycle, svc *Service, logger *zap.Logger, sh fx.Shutdowner) { + ctx, cancel := context.WithCancel(context.Background()) + lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + go func() { + if err := svc.Run(ctx); err != nil { + logger.Error("error running events service", zap.Error(err)) + if shErr := sh.Shutdown(fx.ExitCode(1)); shErr != nil { + logger.Error("failed to shutdown", zap.Error(shErr)) + } } - } - }() - return nil - }, - OnStop: func(_ context.Context) error { - cancel() - return nil - }, - }) - }), -) + }() + return nil + }, + OnStop: func(_ context.Context) error { + cancel() + return nil + }, + }) + }), + ) +} diff --git a/internal/sms-gateway/modules/events/service.go b/internal/sms-gateway/modules/events/service.go index 384dbfa1..7bfac39b 100644 --- a/internal/sms-gateway/modules/events/service.go +++ b/internal/sms-gateway/modules/events/service.go @@ -13,7 +13,8 @@ import ( ) const ( - pubsubTopic = "events" + pubsubTopic = "events" + pubsubTimeout = 5 * time.Second ) type Service struct { @@ -29,7 +30,14 @@ type Service struct { logger *zap.Logger } -func NewService(devicesSvc *devices.Service, sseSvc *sse.Service, pushSvc *push.Service, pubsub pubsub.PubSub, metrics *metrics, logger *zap.Logger) *Service { +func NewService( + devicesSvc *devices.Service, + sseSvc *sse.Service, + pushSvc *push.Service, + pubsub pubsub.PubSub, + metrics *metrics, + logger *zap.Logger, +) *Service { return &Service{ deviceSvc: devicesSvc, sseSvc: sseSvc, @@ -45,10 +53,10 @@ func NewService(devicesSvc *devices.Service, sseSvc *sse.Service, pushSvc *push. func (s *Service) Notify(userID string, deviceID *string, event Event) error { if event.EventType == "" { - return fmt.Errorf("event type is empty") + return fmt.Errorf("%w: event type is empty", ErrValidationFailed) } - subCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + subCtx, cancel := context.WithTimeout(context.Background(), pubsubTimeout) defer cancel() wrapper := eventWrapper{ @@ -60,12 +68,12 @@ func (s *Service) Notify(userID string, deviceID *string, event Event) error { wrapperBytes, err := wrapper.serialize() if err != nil { s.metrics.IncrementFailed(string(event.EventType), DeliveryTypeUnknown, FailureReasonSerializationError) - return fmt.Errorf("can't serialize event wrapper: %w", err) + return fmt.Errorf("failed to serialize event wrapper: %w", err) } - if err := s.pubsub.Publish(subCtx, pubsubTopic, wrapperBytes); err != nil { + if pubErr := s.pubsub.Publish(subCtx, pubsubTopic, wrapperBytes); pubErr != nil { s.metrics.IncrementFailed(string(event.EventType), DeliveryTypeUnknown, FailureReasonPublishError) - return fmt.Errorf("can't publish event: %w", err) + return fmt.Errorf("failed to publish event: %w", pubErr) } s.metrics.IncrementEnqueued(string(event.EventType)) @@ -76,7 +84,7 @@ func (s *Service) Notify(userID string, deviceID *string, event Event) error { func (s *Service) Run(ctx context.Context) error { sub, err := s.pubsub.Subscribe(ctx, pubsubTopic) if err != nil { - return fmt.Errorf("can't subscribe to pubsub: %w", err) + return fmt.Errorf("failed to subscribe to pubsub: %w", err) } defer sub.Close() @@ -92,9 +100,9 @@ func (s *Service) Run(ctx context.Context) error { return nil } wrapper := new(eventWrapper) - if err := wrapper.deserialize(msg.Data); err != nil { + if jsonErr := wrapper.deserialize(msg.Data); jsonErr != nil { s.metrics.IncrementFailed(EventTypeUnknown, DeliveryTypeUnknown, FailureReasonSerializationError) - s.logger.Error("Failed to deserialize event wrapper", zap.Error(err)) + s.logger.Error("failed to deserialize event wrapper", zap.Error(jsonErr)) continue } s.processEvent(wrapper) @@ -111,12 +119,12 @@ func (s *Service) processEvent(wrapper *eventWrapper) { devices, err := s.deviceSvc.Select(wrapper.UserID, filters...) if err != nil { - s.logger.Error("Failed to select devices", zap.String("user_id", wrapper.UserID), zap.Error(err)) + s.logger.Error("failed to select devices", zap.String("user_id", wrapper.UserID), zap.Error(err)) return } if len(devices) == 0 { - s.logger.Info("No devices found for user", zap.String("user_id", wrapper.UserID)) + s.logger.Info("no devices found for user", zap.String("user_id", wrapper.UserID)) return } @@ -124,12 +132,21 @@ func (s *Service) processEvent(wrapper *eventWrapper) { for _, device := range devices { if device.PushToken != nil && *device.PushToken != "" { // Device has push token, use push service - if err := s.pushSvc.Enqueue(*device.PushToken, push.Event{ + if enqErr := s.pushSvc.Enqueue(*device.PushToken, push.Event{ Type: wrapper.Event.EventType, Data: wrapper.Event.Data, - }); err != nil { - s.logger.Error("Failed to enqueue push notification", zap.String("user_id", wrapper.UserID), zap.String("device_id", device.ID), zap.Error(err)) - s.metrics.IncrementFailed(string(wrapper.Event.EventType), DeliveryTypePush, FailureReasonProviderFailed) + }); enqErr != nil { + s.logger.Error( + "failed to enqueue push notification", + zap.String("user_id", wrapper.UserID), + zap.String("device_id", device.ID), + zap.Error(enqErr), + ) + s.metrics.IncrementFailed( + string(wrapper.Event.EventType), + DeliveryTypePush, + FailureReasonProviderFailed, + ) } else { s.metrics.IncrementSent(string(wrapper.Event.EventType), DeliveryTypePush) } @@ -137,11 +154,16 @@ func (s *Service) processEvent(wrapper *eventWrapper) { } // No push token, use SSE service - if err := s.sseSvc.Send(device.ID, sse.Event{ + if sseErr := s.sseSvc.Send(device.ID, sse.Event{ Type: wrapper.Event.EventType, Data: wrapper.Event.Data, - }); err != nil { - s.logger.Error("Failed to send SSE notification", zap.String("user_id", wrapper.UserID), zap.String("device_id", device.ID), zap.Error(err)) + }); sseErr != nil { + s.logger.Error( + "failed to send SSE notification", + zap.String("user_id", wrapper.UserID), + zap.String("device_id", device.ID), + zap.Error(sseErr), + ) s.metrics.IncrementFailed(string(wrapper.Event.EventType), DeliveryTypeSSE, FailureReasonProviderFailed) } else { s.metrics.IncrementSent(string(wrapper.Event.EventType), DeliveryTypeSSE) diff --git a/internal/sms-gateway/modules/events/types.go b/internal/sms-gateway/modules/events/types.go index 76e4d89e..feb77ac0 100644 --- a/internal/sms-gateway/modules/events/types.go +++ b/internal/sms-gateway/modules/events/types.go @@ -2,6 +2,7 @@ package events import ( "encoding/json" + "fmt" "github.com/android-sms-gateway/client-go/smsgateway" ) @@ -25,9 +26,18 @@ type eventWrapper struct { } func (w *eventWrapper) serialize() ([]byte, error) { - return json.Marshal(w) + data, err := json.Marshal(w) + if err != nil { + return nil, fmt.Errorf("failed to marshal event: %w", err) + } + + return data, nil } func (w *eventWrapper) deserialize(data []byte) error { - return json.Unmarshal(data, w) + if err := json.Unmarshal(data, w); err != nil { + return fmt.Errorf("failed to unmarshal event: %w", err) + } + + return nil } diff --git a/internal/sms-gateway/modules/messages/cache.go b/internal/sms-gateway/modules/messages/cache.go index ca708623..ca072e62 100644 --- a/internal/sms-gateway/modules/messages/cache.go +++ b/internal/sms-gateway/modules/messages/cache.go @@ -27,7 +27,7 @@ func newCache(config Config, storage cacheImpl.Cache) *cache { } } -func (c *cache) Set(ctx context.Context, userID, ID string, message *MessageStateOut) error { +func (c *cache) Set(ctx context.Context, userID, id string, message *MessageStateOut) error { var ( err error data []byte @@ -36,32 +36,36 @@ func (c *cache) Set(ctx context.Context, userID, ID string, message *MessageStat if message != nil { data, err = json.Marshal(message) if err != nil { - return fmt.Errorf("can't marshal message: %w", err) + return fmt.Errorf("failed to marshal message: %w", err) } } ctx, cancel := context.WithTimeout(ctx, cacheTimeout) defer cancel() - return c.storage.Set(ctx, userID+":"+ID, data, cacheImpl.WithTTL(c.ttl)) + if setErr := c.storage.Set(ctx, userID+":"+id, data, cacheImpl.WithTTL(c.ttl)); setErr != nil { + return fmt.Errorf("failed to set message in cache: %w", setErr) + } + + return nil } -func (c *cache) Get(ctx context.Context, userID, ID string) (*MessageStateOut, error) { +func (c *cache) Get(ctx context.Context, userID, id string) (*MessageStateOut, error) { ctx, cancel := context.WithTimeout(ctx, cacheTimeout) defer cancel() - data, err := c.storage.Get(ctx, userID+":"+ID, cacheImpl.AndSetTTL(c.ttl)) + data, err := c.storage.Get(ctx, userID+":"+id, cacheImpl.AndSetTTL(c.ttl)) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get message from cache: %w", err) } if len(data) == 0 { - return nil, nil + return nil, nil //nolint:nilnil //empty cached value is used for caching "Not Found" } message := new(MessageStateOut) - if err := json.Unmarshal(data, message); err != nil { - return nil, fmt.Errorf("can't unmarshal message: %w", err) + if jsonErr := json.Unmarshal(data, message); jsonErr != nil { + return nil, fmt.Errorf("failed to unmarshal message: %w", jsonErr) } return message, nil diff --git a/internal/sms-gateway/modules/messages/converters.go b/internal/sms-gateway/modules/messages/converters.go index 221dd4dd..b636e8db 100644 --- a/internal/sms-gateway/modules/messages/converters.go +++ b/internal/sms-gateway/modules/messages/converters.go @@ -10,7 +10,7 @@ import ( ) func messageToDomain(input Message) (MessageOut, error) { - var ttl *uint64 = nil + var ttl *uint64 if input.ValidUntil != nil { secondsUntil := uint64(math.Max(0, time.Until(*input.ValidUntil).Seconds())) ttl = &secondsUntil @@ -18,11 +18,11 @@ func messageToDomain(input Message) (MessageOut, error) { textContent, err := input.GetTextContent() if err != nil { - return MessageOut{}, fmt.Errorf("can't get text content: %w", err) + return MessageOut{}, fmt.Errorf("failed to get text content: %w", err) } dataContent, err := input.GetDataContent() if err != nil { - return MessageOut{}, fmt.Errorf("can't get data content: %w", err) + return MessageOut{}, fmt.Errorf("failed to get data content: %w", err) } return MessageOut{ diff --git a/internal/sms-gateway/modules/messages/errors.go b/internal/sms-gateway/modules/messages/errors.go index 195139de..570c8fd9 100644 --- a/internal/sms-gateway/modules/messages/errors.go +++ b/internal/sms-gateway/modules/messages/errors.go @@ -1,7 +1,13 @@ package messages -type ErrValidation string +import "errors" -func (e ErrValidation) Error() string { +var ( + ErrNoContent = errors.New("no text or data content") +) + +type ValidationError string + +func (e ValidationError) Error() string { return string(e) } diff --git a/internal/sms-gateway/modules/messages/models.go b/internal/sms-gateway/modules/messages/models.go index b64c6af5..d643dba3 100644 --- a/internal/sms-gateway/modules/messages/models.go +++ b/internal/sms-gateway/modules/messages/models.go @@ -6,6 +6,7 @@ import ( "time" "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/samber/lo" "gorm.io/gorm" ) @@ -33,6 +34,8 @@ type DataMessageContent struct { } type Message struct { + models.SoftDeletableModel + ID uint64 `gorm:"primaryKey;type:BIGINT UNSIGNED;autoIncrement"` DeviceID string `gorm:"not null;type:char(21);uniqueIndex:unq_messages_id_device,priority:2;index:idx_messages_device_state"` ExtID string `gorm:"not null;type:varchar(36);uniqueIndex:unq_messages_id_device,priority:1"` @@ -50,14 +53,39 @@ type Message struct { Device models.Device `gorm:"foreignKey:DeviceID;constraint:OnDelete:CASCADE"` Recipients []MessageRecipient `gorm:"foreignKey:MessageID;constraint:OnDelete:CASCADE"` States []MessageState `gorm:"foreignKey:MessageID;constraint:OnDelete:CASCADE"` +} - models.SoftDeletableModel +func NewMessage( + extID string, + deviceID string, + phoneNumbers []string, + priority int8, + simNumber *uint8, + validUntil *time.Time, + withDeliveryReport bool, + isEncrypted bool, +) *Message { + //nolint:exhaustruct // partial constructor + return &Message{ + ExtID: extID, + DeviceID: deviceID, + Recipients: lo.Map(phoneNumbers, func(item string, _ int) MessageRecipient { + return newMessageRecipient(item, ProcessingStatePending, nil) + }), + Priority: priority, + SimNumber: simNumber, + ValidUntil: validUntil, + WithDeliveryReport: withDeliveryReport, + IsEncrypted: isEncrypted, + + State: ProcessingStatePending, + } } func (m *Message) SetTextContent(content TextMessageContent) error { contentJSON, err := json.Marshal(content) if err != nil { - return err + return fmt.Errorf("failed to marshal: %w", err) } m.Type = MessageTypeText @@ -68,23 +96,23 @@ func (m *Message) SetTextContent(content TextMessageContent) error { func (m *Message) GetTextContent() (*TextMessageContent, error) { if m.Type != MessageTypeText { - return nil, nil + return nil, nil //nolint:nilnil // special meaning } - content := TextMessageContent{} + content := new(TextMessageContent) - err := json.Unmarshal([]byte(m.Content), &content) + err := json.Unmarshal([]byte(m.Content), content) if err != nil { return nil, fmt.Errorf("failed to unmarshal text content: %w", err) } - return &content, nil + return content, nil } func (m *Message) SetDataContent(content DataMessageContent) error { contentJSON, err := json.Marshal(content) if err != nil { - return err + return fmt.Errorf("failed to marshal: %w", err) } m.Type = MessageTypeData @@ -95,17 +123,17 @@ func (m *Message) SetDataContent(content DataMessageContent) error { func (m *Message) GetDataContent() (*DataMessageContent, error) { if m.Type != MessageTypeData { - return nil, nil + return nil, nil //nolint:nilnil // special meaning } - content := DataMessageContent{} + content := new(DataMessageContent) - err := json.Unmarshal([]byte(m.Content), &content) + err := json.Unmarshal([]byte(m.Content), content) if err != nil { return nil, fmt.Errorf("failed to unmarshal data content: %w", err) } - return &content, nil + return content, nil } type MessageRecipient struct { @@ -116,6 +144,16 @@ type MessageRecipient struct { Error *string `gorm:"type:varchar(256)"` } +func newMessageRecipient(phoneNumber string, state ProcessingState, err *string) MessageRecipient { + return MessageRecipient{ + ID: 0, + MessageID: 0, + PhoneNumber: phoneNumber, + State: state, + Error: err, + } +} + type MessageState struct { ID uint64 `gorm:"primaryKey;type:BIGINT UNSIGNED;autoIncrement"` MessageID uint64 `gorm:"not null;type:BIGINT UNSIGNED;uniqueIndex:unq_message_states_message_id_state,priority:1"` @@ -124,5 +162,8 @@ type MessageState struct { } func Migrate(db *gorm.DB) error { - return db.AutoMigrate(&Message{}, &MessageRecipient{}, &MessageState{}) + if err := db.AutoMigrate(new(Message), new(MessageRecipient), new(MessageState)); err != nil { + return fmt.Errorf("messages migration failed: %w", err) + } + return nil } diff --git a/internal/sms-gateway/modules/messages/module.go b/internal/sms-gateway/modules/messages/module.go index ec60ae71..9f73ffb2 100644 --- a/internal/sms-gateway/modules/messages/module.go +++ b/internal/sms-gateway/modules/messages/module.go @@ -24,6 +24,7 @@ func Module() fx.Option { ) } +//nolint:gochecknoinits //backward compatibility func init() { db.RegisterMigration(Migrate) } diff --git a/internal/sms-gateway/modules/messages/repository.go b/internal/sms-gateway/modules/messages/repository.go index 6f781598..f8ae1ff4 100644 --- a/internal/sms-gateway/modules/messages/repository.go +++ b/internal/sms-gateway/modules/messages/repository.go @@ -13,7 +13,7 @@ import ( const maxPendingBatch = 100 -var ErrMessageNotFound = gorm.ErrRecordNotFound +var ErrMessageNotFound = errors.New("message not found") var ErrMessageAlreadyExists = errors.New("duplicate id") var ErrMultipleMessagesFound = errors.New("multiple messages found") @@ -27,8 +27,8 @@ func NewRepository(db *gorm.DB) *Repository { } } -func (r *Repository) Select(filter MessagesSelectFilter, options MessagesSelectOptions) ([]Message, int64, error) { - query := r.db.Model(&Message{}) +func (r *Repository) Select(filter SelectFilter, options SelectOptions) ([]Message, int64, error) { + query := r.db.Model((*Message)(nil)) // Apply date range filter if !filter.StartDate.IsZero() { @@ -94,29 +94,25 @@ func (r *Repository) Select(filter MessagesSelectFilter, options MessagesSelectO messages := make([]Message, 0, min(options.Limit, int(total))) if err := query.Find(&messages).Error; err != nil { - return nil, 0, fmt.Errorf("can't select messages: %w", err) + return nil, 0, fmt.Errorf("failed to select messages: %w", err) } return messages, total, nil } -func (r *Repository) SelectPending(deviceID string, order MessagesOrder) ([]Message, error) { - messages, _, err := r.Select(MessagesSelectFilter{ - DeviceID: deviceID, - State: ProcessingStatePending, - }, MessagesSelectOptions{ - WithRecipients: true, - Limit: maxPendingBatch, - OrderBy: order, - }) +func (r *Repository) SelectPending(deviceID string, order Order) ([]Message, error) { + messages, _, err := r.Select( + *new(SelectFilter).WithDeviceID(deviceID).WithState(ProcessingStatePending), + *new(SelectOptions).IncludeRecipients().WithLimit(maxPendingBatch).WithOrderBy(order), + ) return messages, err } -func (r *Repository) Get(filter MessagesSelectFilter, options MessagesSelectOptions) (Message, error) { +func (r *Repository) Get(filter SelectFilter, options SelectOptions) (Message, error) { messages, _, err := r.Select(filter, options) if err != nil { - return Message{}, fmt.Errorf("can't get message: %w", err) + return Message{}, fmt.Errorf("failed to get message: %w", err) } if len(messages) == 0 { @@ -144,7 +140,7 @@ func (r *Repository) Insert(message *Message) error { } func (r *Repository) UpdateState(message *Message) error { - return r.db.Transaction(func(tx *gorm.DB) error { + err := r.db.Transaction(func(tx *gorm.DB) error { if err := tx.Model(message).Select("State").Updates(message).Error; err != nil { return err } @@ -159,7 +155,7 @@ func (r *Repository) UpdateState(message *Message) error { } for _, v := range message.Recipients { - if err := tx.Model(&MessageRecipient{}). + if err := tx.Model((*MessageRecipient)(nil)). Where("message_id = ? AND phone_number = ?", message.ID, v.PhoneNumber). Select("state", "error"). Updates(map[string]any{"state": v.State, "error": v.Error}).Error; err != nil { @@ -169,6 +165,12 @@ func (r *Repository) UpdateState(message *Message) error { return nil }) + + if err != nil { + return fmt.Errorf("failed to update message state: %w", err) + } + + return nil } func (r *Repository) HashProcessed(ctx context.Context, ids []uint64) (int64, error) { @@ -195,6 +197,6 @@ func (r *Repository) Cleanup(ctx context.Context, until time.Time) (int64, error WithContext(ctx). Where("state <> ?", ProcessingStatePending). Where("created_at < ?", until). - Delete(&Message{}) + Delete(new(Message)) return res.RowsAffected, res.Error } diff --git a/internal/sms-gateway/modules/messages/repository_filter.go b/internal/sms-gateway/modules/messages/repository_filter.go index 6b9b9216..24bd9d2f 100644 --- a/internal/sms-gateway/modules/messages/repository_filter.go +++ b/internal/sms-gateway/modules/messages/repository_filter.go @@ -2,18 +2,18 @@ package messages import "time" -// MessagesOrder defines supported ordering for message selection. +// Order defines supported ordering for message selection. // Valid values: "lifo" (default), "fifo". -type MessagesOrder string +type Order string const ( // MessagesOrderLIFO orders messages newest-first within the same priority (default). - MessagesOrderLIFO MessagesOrder = "lifo" + MessagesOrderLIFO Order = "lifo" // MessagesOrderFIFO orders messages oldest-first within the same priority. - MessagesOrderFIFO MessagesOrder = "fifo" + MessagesOrderFIFO Order = "fifo" ) -type MessagesSelectFilter struct { +type SelectFilter struct { ExtID string UserID string DeviceID string @@ -22,15 +22,71 @@ type MessagesSelectFilter struct { State ProcessingState } -type MessagesSelectOptions struct { +func (f *SelectFilter) WithExtID(extID string) *SelectFilter { + f.ExtID = extID + return f +} + +func (f *SelectFilter) WithUserID(userID string) *SelectFilter { + f.UserID = userID + return f +} + +func (f *SelectFilter) WithDeviceID(deviceID string) *SelectFilter { + f.DeviceID = deviceID + return f +} + +func (f *SelectFilter) WithDateRange(start, end time.Time) *SelectFilter { + f.StartDate = start + f.EndDate = end + return f +} + +func (f *SelectFilter) WithState(state ProcessingState) *SelectFilter { + f.State = state + return f +} + +type SelectOptions struct { WithRecipients bool WithDevice bool WithStates bool // OrderBy sets the retrieval order for pending messages. // Empty (zero) value defaults to "lifo". - OrderBy MessagesOrder + OrderBy Order Limit int Offset int } + +func (o *SelectOptions) WithLimit(limit int) *SelectOptions { + o.Limit = limit + return o +} + +func (o *SelectOptions) WithOffset(offset int) *SelectOptions { + o.Offset = offset + return o +} + +func (o *SelectOptions) WithOrderBy(order Order) *SelectOptions { + o.OrderBy = order + return o +} + +func (o *SelectOptions) IncludeRecipients() *SelectOptions { + o.WithRecipients = true + return o +} + +func (o *SelectOptions) IncludeDevice() *SelectOptions { + o.WithDevice = true + return o +} + +func (o *SelectOptions) IncludeStates() *SelectOptions { + o.WithStates = true + return o +} diff --git a/internal/sms-gateway/modules/messages/service.go b/internal/sms-gateway/modules/messages/service.go index ecb67ee7..6a4c3653 100644 --- a/internal/sms-gateway/modules/messages/service.go +++ b/internal/sms-gateway/modules/messages/service.go @@ -15,12 +15,8 @@ import ( "github.com/capcom6/go-helpers/anys" "github.com/capcom6/go-helpers/slices" "github.com/nyaruka/phonenumbers" + "github.com/samber/lo" "go.uber.org/zap" - "golang.org/x/exp/maps" -) - -const ( - ErrorTTLExpired = "TTL expired" ) type EnqueueOptions struct { @@ -74,7 +70,7 @@ func (s *Service) RunBackgroundTasks(ctx context.Context, wg *sync.WaitGroup) { }() } -func (s *Service) SelectPending(deviceID string, order MessagesOrder) ([]MessageOut, error) { +func (s *Service) SelectPending(deviceID string, order Order) ([]MessageOut, error) { if order == "" { order = MessagesOrderLIFO } @@ -84,11 +80,14 @@ func (s *Service) SelectPending(deviceID string, order MessagesOrder) ([]Message return nil, err } - return slices.MapOrError(messages, messageToDomain) + return slices.MapOrError(messages, messageToDomain) //nolint:wrapcheck // already wrapped } func (s *Service) UpdateState(device *models.Device, message MessageStateIn) error { - existing, err := s.messages.Get(MessagesSelectFilter{ExtID: message.ID, DeviceID: device.ID}, MessagesSelectOptions{}) + existing, err := s.messages.Get( + *new(SelectFilter).WithExtID(message.ID).WithDeviceID(device.ID), + SelectOptions{}, //nolint:exhaustruct // not needed + ) if err != nil { return err } @@ -98,21 +97,25 @@ func (s *Service) UpdateState(device *models.Device, message MessageStateIn) err } existing.State = message.State - existing.States = slices.Map(maps.Keys(message.States), func(key string) MessageState { - return MessageState{ - MessageID: existing.ID, - State: ProcessingState(key), - UpdatedAt: message.States[key], - } - }) + existing.States = lo.MapToSlice( + message.States, + func(key string, value time.Time) MessageState { + return MessageState{ + ID: 0, + MessageID: existing.ID, + State: ProcessingState(key), + UpdatedAt: value, + } + }, + ) existing.Recipients = s.recipientsStateToModel(message.Recipients, existing.IsHashed) - if err := s.messages.UpdateState(&existing); err != nil { - return err + if updErr := s.messages.UpdateState(&existing); updErr != nil { + return updErr } - if err := s.cache.Set(context.Background(), device.UserID, existing.ExtID, anys.AsPointer(modelToMessageState(existing))); err != nil { - s.logger.Warn("can't cache message", zap.String("id", existing.ExtID), zap.Error(err)) + if cacheErr := s.cache.Set(context.Background(), device.UserID, existing.ExtID, anys.AsPointer(modelToMessageState(existing))); cacheErr != nil { + s.logger.Warn("failed to cache message", zap.String("id", existing.ExtID), zap.Error(cacheErr)) } s.hashingWorker.Enqueue(existing.ID) s.metrics.IncTotal(string(existing.State)) @@ -120,19 +123,23 @@ func (s *Service) UpdateState(device *models.Device, message MessageStateIn) err return nil } -func (s *Service) SelectStates(user models.User, filter MessagesSelectFilter, options MessagesSelectOptions) ([]MessageStateOut, int64, error) { +func (s *Service) SelectStates( + user models.User, + filter SelectFilter, + options SelectOptions, +) ([]MessageStateOut, int64, error) { filter.UserID = user.ID messages, total, err := s.messages.Select(filter, options) if err != nil { - return nil, 0, fmt.Errorf("can't select messages: %w", err) + return nil, 0, fmt.Errorf("failed to select messages: %w", err) } return slices.Map(messages, modelToMessageState), total, nil } -func (s *Service) GetState(user models.User, ID string) (*MessageStateOut, error) { - dto, err := s.cache.Get(context.Background(), user.ID, ID) +func (s *Service) GetState(user models.User, id string) (*MessageStateOut, error) { + dto, err := s.cache.Get(context.Background(), user.ID, id) if err == nil { s.metrics.IncCache(true) @@ -145,13 +152,13 @@ func (s *Service) GetState(user models.User, ID string) (*MessageStateOut, error s.metrics.IncCache(false) message, err := s.messages.Get( - MessagesSelectFilter{ExtID: ID, UserID: user.ID}, - MessagesSelectOptions{WithRecipients: true, WithDevice: true, WithStates: true}, + *new(SelectFilter).WithExtID(id).WithUserID(user.ID), + *new(SelectOptions).IncludeRecipients().IncludeDevice().IncludeStates(), ) if err != nil { if errors.Is(err, ErrMessageNotFound) { - if err := s.cache.Set(context.Background(), user.ID, ID, nil); err != nil { - s.logger.Warn("can't cache message", zap.String("id", ID), zap.Error(err)) + if cacheErr := s.cache.Set(context.Background(), user.ID, id, nil); cacheErr != nil { + s.logger.Warn("failed to cache message", zap.String("id", id), zap.Error(cacheErr)) } } @@ -159,22 +166,58 @@ func (s *Service) GetState(user models.User, ID string) (*MessageStateOut, error } dto = anys.AsPointer(modelToMessageState(message)) - if err := s.cache.Set(context.Background(), user.ID, ID, dto); err != nil { - s.logger.Warn("can't cache message", zap.String("id", ID), zap.Error(err)) + if cacheErr := s.cache.Set(context.Background(), user.ID, id, dto); cacheErr != nil { + s.logger.Warn("failed to cache message", zap.String("id", id), zap.Error(cacheErr)) } return dto, nil } -func (s *Service) Enqueue(device models.Device, message MessageIn, opts EnqueueOptions) (MessageStateOut, error) { - state := MessageStateOut{ +func (s *Service) Enqueue(device models.Device, message MessageIn, opts EnqueueOptions) (*MessageStateOut, error) { + msg, err := s.prepareMessage(device, message, opts) + if err != nil { + return nil, err + } + + state := &MessageStateOut{ DeviceID: device.ID, MessageStateIn: MessageStateIn{ - State: ProcessingStatePending, - Recipients: make([]smsgateway.RecipientState, len(message.PhoneNumbers)), + ID: msg.ExtID, + State: ProcessingStatePending, + Recipients: lo.Map( + msg.Recipients, + func(item MessageRecipient, _ int) smsgateway.RecipientState { return modelToRecipientState(item) }, + ), + States: map[string]time.Time{}, }, + IsHashed: false, + IsEncrypted: msg.IsEncrypted, + } + + if insErr := s.messages.Insert(msg); insErr != nil { + return state, insErr + } + + if cacheErr := s.cache.Set(context.Background(), device.UserID, msg.ExtID, anys.AsPointer(modelToMessageState(*msg))); cacheErr != nil { + s.logger.Warn("failed to cache message", zap.String("id", msg.ExtID), zap.Error(cacheErr)) } + s.metrics.IncTotal(string(msg.State)) + + go func(userID, deviceID string) { + if ntfErr := s.eventsSvc.Notify(userID, &deviceID, events.NewMessageEnqueuedEvent()); ntfErr != nil { + s.logger.Error( + "failed to notify device", + zap.Error(ntfErr), + zap.String("user_id", userID), + zap.String("device_id", deviceID), + ) + } + }(device.UserID, device.ID) + + return state, nil +} +func (s *Service) prepareMessage(device models.Device, message MessageIn, opts EnqueueOptions) (*Message, error) { var phone string var err error for i, v := range message.PhoneNumbers { @@ -182,92 +225,64 @@ func (s *Service) Enqueue(device models.Device, message MessageIn, opts EnqueueO phone = v } else { if phone, err = cleanPhoneNumber(v); err != nil { - return state, fmt.Errorf("can't use phone in row %d: %w", i+1, err) + return nil, fmt.Errorf("failed to use phone in row %d: %w", i+1, err) } } message.PhoneNumbers[i] = phone - - state.Recipients[i] = smsgateway.RecipientState{ - PhoneNumber: phone, - State: smsgateway.ProcessingStatePending, - } } validUntil := message.ValidUntil if message.TTL != nil && *message.TTL > 0 { - validUntil = anys.AsPointer(time.Now().Add(time.Duration(*message.TTL) * time.Second)) + //nolint:gosec // not a problem + validUntil = anys.AsPointer( + time.Now().Add(time.Duration(*message.TTL) * time.Second), + ) } - msg := Message{ - ExtID: message.ID, - Recipients: s.recipientsToModel(message.PhoneNumbers), - IsEncrypted: message.IsEncrypted, - - DeviceID: device.ID, - - SimNumber: message.SimNumber, - WithDeliveryReport: anys.OrDefault(message.WithDeliveryReport, true), - - Priority: int8(message.Priority), - ValidUntil: validUntil, - } + msg := NewMessage( + message.ID, + device.ID, + message.PhoneNumbers, + int8(message.Priority), + message.SimNumber, + validUntil, + anys.OrDefault(message.WithDeliveryReport, true), + message.IsEncrypted, + ) - if message.TextContent != nil { - if err := msg.SetTextContent(*message.TextContent); err != nil { - return state, fmt.Errorf("can't set text content: %w", err) + switch { + case message.TextContent != nil: + if setErr := msg.SetTextContent(*message.TextContent); setErr != nil { + return nil, fmt.Errorf("failed to set text content: %w", setErr) } - } else if message.DataContent != nil { - if err := msg.SetDataContent(*message.DataContent); err != nil { - return state, fmt.Errorf("can't set data content: %w", err) + case message.DataContent != nil: + if setErr := msg.SetDataContent(*message.DataContent); setErr != nil { + return nil, fmt.Errorf("failed to set data content: %w", setErr) } - } else { - return state, errors.New("no text or data content") + default: + return nil, ErrNoContent } if msg.ExtID == "" { msg.ExtID = s.idgen() } - state.ID = msg.ExtID - - if err := s.messages.Insert(&msg); err != nil { - return state, err - } - - if err := s.cache.Set(context.Background(), device.UserID, message.ID, anys.AsPointer(modelToMessageState(msg))); err != nil { - s.logger.Warn("can't cache message", zap.String("id", msg.ExtID), zap.Error(err)) - } - s.metrics.IncTotal(string(msg.State)) - - go func(userID, deviceID string) { - if err := s.eventsSvc.Notify(userID, &deviceID, events.NewMessageEnqueuedEvent()); err != nil { - s.logger.Error("can't notify device", zap.Error(err), zap.String("user_id", userID), zap.String("device_id", deviceID)) - } - }(device.UserID, device.ID) - return state, nil + return msg, nil } func (s *Service) ExportInbox(device models.Device, since, until time.Time) error { event := events.NewMessagesExportRequestedEvent(since, until) - return s.eventsSvc.Notify(device.UserID, &device.ID, event) -} - -/////////////////////////////////////////////////////////////////////////////// - -func (s *Service) recipientsToModel(input []string) []MessageRecipient { - output := make([]MessageRecipient, len(input)) - - for i, v := range input { - output[i] = MessageRecipient{ - PhoneNumber: v, - } + if err := s.eventsSvc.Notify(device.UserID, &device.ID, event); err != nil { + return fmt.Errorf("failed to notify device: %w", err) } - return output + return nil } +/////////////////////////////////////////////////////////////////////////////// + func (s *Service) recipientsStateToModel(input []smsgateway.RecipientState, hash bool) []MessageRecipient { output := make([]MessageRecipient, len(input)) @@ -286,11 +301,11 @@ func (s *Service) recipientsStateToModel(input []smsgateway.RecipientState, hash phoneNumber = fmt.Sprintf("%x", sha256.Sum256([]byte(phoneNumber)))[:16] } - output[i] = MessageRecipient{ - PhoneNumber: phoneNumber, - State: ProcessingState(v.State), - Error: v.Error, - } + output[i] = newMessageRecipient( + phoneNumber, + ProcessingState(v.State), + v.Error, + ) } return output @@ -326,16 +341,16 @@ func modelToRecipientState(input MessageRecipient) smsgateway.RecipientState { func cleanPhoneNumber(input string) (string, error) { phone, err := phonenumbers.Parse(input, "RU") if err != nil { - return input, ErrValidation(fmt.Sprintf("can't parse phone number: %s", err.Error())) + return input, ValidationError(fmt.Sprintf("failed to parse phone number: %s", err.Error())) } if !phonenumbers.IsValidNumber(phone) { - return input, ErrValidation("invalid phone number") + return input, ValidationError("invalid phone number") } phoneNumberType := phonenumbers.GetNumberType(phone) if phoneNumberType != phonenumbers.MOBILE && phoneNumberType != phonenumbers.FIXED_LINE_OR_MOBILE { - return input, ErrValidation("not mobile phone number") + return input, ValidationError("not mobile phone number") } return phonenumbers.Format(phone, phonenumbers.E164), nil diff --git a/internal/sms-gateway/modules/messages/service_test.go b/internal/sms-gateway/modules/messages/service_test.go deleted file mode 100644 index 1113b70a..00000000 --- a/internal/sms-gateway/modules/messages/service_test.go +++ /dev/null @@ -1,158 +0,0 @@ -package messages - -import ( - "reflect" - "testing" - - "github.com/android-sms-gateway/client-go/smsgateway" -) - -func TestService_recipientsStateToModel(t *testing.T) { - type args struct { - input []smsgateway.RecipientState - hash bool - } - tests := []struct { - name string - s *Service - args args - want []MessageRecipient - }{ - { - name: "Without +", - s: &Service{}, - args: args{ - input: []smsgateway.RecipientState{ - { - PhoneNumber: "79990001234", - State: "", - }, - }, - }, - want: []MessageRecipient{ - { - MessageID: 0, - PhoneNumber: "+79990001234", - State: "", - }, - }, - }, - { - name: "With +", - s: &Service{}, - args: args{ - input: []smsgateway.RecipientState{ - { - PhoneNumber: "+79990001234", - State: "", - }, - }, - }, - want: []MessageRecipient{ - { - MessageID: 0, - PhoneNumber: "+79990001234", - State: "", - }, - }, - }, - { - name: "With hashing", - s: &Service{}, - args: args{ - input: []smsgateway.RecipientState{ - { - PhoneNumber: "+79990001234", - State: "", - }, - }, - hash: true, - }, - want: []MessageRecipient{ - { - MessageID: 0, - PhoneNumber: "62d17792b45c5307", - State: "", - }, - }, - }, - { - name: "Empty phone", - s: &Service{}, - args: args{ - input: []smsgateway.RecipientState{ - { - PhoneNumber: "", - State: "", - }, - }, - }, - want: []MessageRecipient{ - { - MessageID: 0, - PhoneNumber: "", - State: "", - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := tt.s.recipientsStateToModel(tt.args.input, tt.args.hash); !reflect.DeepEqual(got, tt.want) { - t.Errorf("MessagesService.recipientsStateToModel() = %v, want %v", got, tt.want) - } - }) - } -} - -func TestCleanPhoneNumber(t *testing.T) { - tests := []struct { - name string - input string - expected string - expectError bool - }{ - { - name: "Valid number with validation", - input: "+79161234567", - expected: "+79161234567", - expectError: false, - }, - { - name: "Invalid number with validation", - input: "+123!@#", - expected: "", - expectError: true, - }, - { - name: "Empty input with validation", - input: "", - expected: "", - expectError: true, - }, - { - name: "Long number with validation", - input: "+345906566798696", - expected: "", - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := cleanPhoneNumber(tt.input) - if tt.expectError { - if err == nil { - t.Errorf("Expected error, got nil") - } - } else { - if err != nil { - t.Errorf("Unexpected error: %v", err) - } - if result != tt.expected { - t.Errorf("Expected %s, got %s", tt.expected, result) - } - } - }) - } -} diff --git a/internal/sms-gateway/modules/messages/workers.go b/internal/sms-gateway/modules/messages/workers.go index ec039338..0fbf6901 100644 --- a/internal/sms-gateway/modules/messages/workers.go +++ b/internal/sms-gateway/modules/messages/workers.go @@ -2,11 +2,13 @@ package messages import ( "context" + "slices" "sync" "time" + "maps" + "go.uber.org/zap" - "golang.org/x/exp/maps" ) type hashingWorker struct { @@ -27,6 +29,7 @@ func newHashingWorker(config Config, messages *Repository, logger *zap.Logger) * logger: logger, queue: map[uint64]struct{}{}, + mux: sync.Mutex{}, } } @@ -46,7 +49,7 @@ func (t *hashingWorker) Run(ctx context.Context) { } } -// Enqueue adds a message ID to the processing queue to be hashed in the next batch +// Enqueue adds a message ID to the processing queue to be hashed in the next batch. func (t *hashingWorker) Enqueue(id uint64) { t.mux.Lock() t.queue[id] = struct{}{} @@ -56,8 +59,8 @@ func (t *hashingWorker) Enqueue(id uint64) { func (t *hashingWorker) process(ctx context.Context) { t.mux.Lock() - ids := maps.Keys(t.queue) - maps.Clear(t.queue) + ids := slices.AppendSeq(make([]uint64, 0, len(t.queue)), maps.Keys(t.queue)) + clear(t.queue) t.mux.Unlock() @@ -67,6 +70,6 @@ func (t *hashingWorker) process(ctx context.Context) { t.logger.Debug("Hashing messages...") if _, err := t.messages.HashProcessed(ctx, ids); err != nil { - t.logger.Error("Can't hash messages", zap.Error(err)) + t.logger.Error("failed to hash messages", zap.Error(err)) } } diff --git a/internal/sms-gateway/modules/metrics/handler.go b/internal/sms-gateway/modules/metrics/handler.go index e2281795..b62a56ba 100644 --- a/internal/sms-gateway/modules/metrics/handler.go +++ b/internal/sms-gateway/modules/metrics/handler.go @@ -5,16 +5,16 @@ import ( "github.com/gofiber/fiber/v2" ) -type HttpHandler struct { +type HTTPHandler struct { } -func (h *HttpHandler) Register(app *fiber.App) { +func (h *HTTPHandler) Register(app *fiber.App) { promhandler := fiberprometheus.New("") promhandler.RegisterAt(app, "/metrics") app.Use(promhandler.Middleware) } -func newHttpHandler() *HttpHandler { - return &HttpHandler{} +func newHTTPHandler() *HTTPHandler { + return &HTTPHandler{} } diff --git a/internal/sms-gateway/modules/metrics/module.go b/internal/sms-gateway/modules/metrics/module.go index 98790267..ace82c08 100644 --- a/internal/sms-gateway/modules/metrics/module.go +++ b/internal/sms-gateway/modules/metrics/module.go @@ -6,12 +6,14 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "metrics", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("metrics") - }), - fx.Provide( - http.AsRootHandler(newHttpHandler), - ), -) +func Module() fx.Option { + return fx.Module( + "metrics", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("metrics") + }), + fx.Provide( + http.AsRootHandler(newHTTPHandler), + ), + ) +} diff --git a/internal/sms-gateway/modules/push/client.go b/internal/sms-gateway/modules/push/client.go new file mode 100644 index 00000000..4a5a027a --- /dev/null +++ b/internal/sms-gateway/modules/push/client.go @@ -0,0 +1,34 @@ +package push + +import ( + "errors" + "fmt" + + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/fcm" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/upstream" +) + +var ErrInvalidPushMode = errors.New("invalid push mode") + +func newClient(config Config) (client.Client, error) { + var ( + c client.Client + err error + ) + + switch config.Mode { + case ModeFCM: + c, err = fcm.New(config.ClientOptions) + case ModeUpstream: + c, err = upstream.New(config.ClientOptions) + default: + return nil, fmt.Errorf("%w: %s", ErrInvalidPushMode, config.Mode) + } + + if err != nil { + return nil, fmt.Errorf("failed to create client: %w", err) + } + + return c, nil +} diff --git a/internal/sms-gateway/modules/push/types/types.go b/internal/sms-gateway/modules/push/client/types.go similarity index 55% rename from internal/sms-gateway/modules/push/types/types.go rename to internal/sms-gateway/modules/push/client/types.go index 48312a59..68553df1 100644 --- a/internal/sms-gateway/modules/push/types/types.go +++ b/internal/sms-gateway/modules/push/client/types.go @@ -1,9 +1,17 @@ -package types +package client import ( + "context" + "github.com/android-sms-gateway/client-go/smsgateway" ) +type Client interface { + Open(ctx context.Context) error + Send(ctx context.Context, messages []Message) ([]error, error) + Close(ctx context.Context) error +} + type Message struct { Token string Event Event diff --git a/internal/sms-gateway/modules/push/fcm/client.go b/internal/sms-gateway/modules/push/fcm/client.go index 44649544..10b63701 100644 --- a/internal/sms-gateway/modules/push/fcm/client.go +++ b/internal/sms-gateway/modules/push/fcm/client.go @@ -7,7 +7,7 @@ import ( firebase "firebase.google.com/go/v4" "firebase.google.com/go/v4/messaging" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/types" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" "google.golang.org/api/option" ) @@ -21,6 +21,8 @@ type Client struct { func New(options map[string]string) (*Client, error) { return &Client{ options: options, + client: nil, + mux: sync.Mutex{}, }, nil } @@ -34,31 +36,31 @@ func (c *Client) Open(ctx context.Context) error { creds := c.options["credentials"] if creds == "" { - return fmt.Errorf("no credentials provided") + return fmt.Errorf("%w: no credentials provided", ErrInitializationFailed) } opt := option.WithCredentialsJSON([]byte(creds)) app, err := firebase.NewApp(ctx, nil, opt) if err != nil { - return fmt.Errorf("can't create firebase app: %w", err) + return fmt.Errorf("%w: failed to create firebase app: %w", ErrInitializationFailed, err) } c.client, err = app.Messaging(ctx) if err != nil { - return fmt.Errorf("can't create firebase messaging client: %w", err) + return fmt.Errorf("%w: failed to create firebase messaging client: %w", ErrInitializationFailed, err) } return nil } -func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, error) { +func (c *Client) Send(ctx context.Context, messages []client.Message) ([]error, error) { errs := make([]error, len(messages)) for i, message := range messages { data, err := eventToMap(message.Event) if err != nil { - errs[i] = fmt.Errorf("can't marshal event: %w", err) + errs[i] = fmt.Errorf("failed to marshal event: %w", err) continue } @@ -70,14 +72,17 @@ func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, e Token: message.Token, }) if err != nil { - errs[i] = fmt.Errorf("can't send message: %w", err) + errs[i] = fmt.Errorf("failed to send message: %w", err) } } return errs, nil } -func (c *Client) Close(ctx context.Context) error { +func (c *Client) Close(_ context.Context) error { + c.mux.Lock() + defer c.mux.Unlock() + c.client = nil return nil diff --git a/internal/sms-gateway/modules/push/fcm/errors.go b/internal/sms-gateway/modules/push/fcm/errors.go new file mode 100644 index 00000000..278ca92f --- /dev/null +++ b/internal/sms-gateway/modules/push/fcm/errors.go @@ -0,0 +1,7 @@ +package fcm + +import "errors" + +var ( + ErrInitializationFailed = errors.New("initialization failed") +) diff --git a/internal/sms-gateway/modules/push/fcm/utils.go b/internal/sms-gateway/modules/push/fcm/utils.go index d8179505..5b8c9c19 100644 --- a/internal/sms-gateway/modules/push/fcm/utils.go +++ b/internal/sms-gateway/modules/push/fcm/utils.go @@ -4,13 +4,13 @@ import ( "encoding/json" "fmt" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/types" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" ) -func eventToMap(event types.Event) (map[string]string, error) { +func eventToMap(event client.Event) (map[string]string, error) { json, err := json.Marshal(event.Data) if err != nil { - return nil, fmt.Errorf("can't marshal event data: %w", err) + return nil, fmt.Errorf("failed to marshal event data: %w", err) } return map[string]string{ diff --git a/internal/sms-gateway/modules/push/module.go b/internal/sms-gateway/modules/push/module.go index e2e25951..556fe602 100644 --- a/internal/sms-gateway/modules/push/module.go +++ b/internal/sms-gateway/modules/push/module.go @@ -2,10 +2,8 @@ package push import ( "context" - "fmt" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/fcm" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/upstream" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" "go.uber.org/fx" "go.uber.org/zap" ) @@ -18,35 +16,21 @@ func Module() fx.Option { }), fx.Provide(newMetrics, fx.Private), fx.Provide( - func(cfg Config, lc fx.Lifecycle) (c client, err error) { - switch cfg.Mode { - case ModeFCM: - c, err = fcm.New(cfg.ClientOptions) - case ModeUpstream: - c, err = upstream.New(cfg.ClientOptions) - default: - return nil, fmt.Errorf("invalid push mode: %q", cfg.Mode) - } - - if err != nil { - return nil, err - } - - lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { - return c.Open(ctx) - }, - OnStop: func(ctx context.Context) error { - return c.Close(ctx) - }, - }) - - return c, nil - }, + newClient, fx.Private, ), fx.Provide( New, ), + fx.Invoke(func(lc fx.Lifecycle, c client.Client) { + lc.Append(fx.Hook{ + OnStart: func(ctx context.Context) error { + return c.Open(ctx) + }, + OnStop: func(ctx context.Context) error { + return c.Close(ctx) + }, + }) + }), ) } diff --git a/internal/sms-gateway/modules/push/service.go b/internal/sms-gateway/modules/push/service.go index 207b1b34..da6190f8 100644 --- a/internal/sms-gateway/modules/push/service.go +++ b/internal/sms-gateway/modules/push/service.go @@ -6,7 +6,7 @@ import ( "time" "github.com/android-sms-gateway/server/internal/sms-gateway/cache" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/types" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" cacheImpl "github.com/android-sms-gateway/server/pkg/cache" "github.com/samber/lo" @@ -16,6 +16,8 @@ import ( const ( cachePrefixEvents = "events:" cachePrefixBlacklist = "blacklist:" + + defaultDebounce = 5 * time.Second ) type Config struct { @@ -30,7 +32,7 @@ type Config struct { type Service struct { config Config - client client + client client.Client events cache.Cache blacklist cache.Cache @@ -40,23 +42,23 @@ type Service struct { func New( config Config, - client client, + client client.Client, cacheFactory cache.Factory, metrics *metrics, logger *zap.Logger, ) (*Service, error) { events, err := cacheFactory.New(cachePrefixEvents) if err != nil { - return nil, fmt.Errorf("can't create events cache: %w", err) + return nil, fmt.Errorf("failed to create events cache: %w", err) } blacklist, err := cacheFactory.New(cachePrefixBlacklist) if err != nil { - return nil, fmt.Errorf("can't create blacklist cache: %w", err) + return nil, fmt.Errorf("failed to create blacklist cache: %w", err) } config.Timeout = max(config.Timeout, time.Second) - config.Debounce = max(config.Debounce, 5*time.Second) + config.Debounce = max(config.Debounce, defaultDebounce) return &Service{ config: config, @@ -87,7 +89,7 @@ func (s *Service) Run(ctx context.Context) { } // Enqueue adds the data to the cache and immediately sends all messages if the debounce is 0. -func (s *Service) Enqueue(token string, event types.Event) error { +func (s *Service) Enqueue(token string, event Event) error { ctx, cancel := context.WithTimeout(context.Background(), s.config.Timeout) defer cancel() @@ -105,12 +107,12 @@ func (s *Service) Enqueue(token string, event types.Event) error { wrapperData, err := wrapper.serialize() if err != nil { s.metrics.IncError(1) - return fmt.Errorf("can't serialize event wrapper: %w", err) + return fmt.Errorf("failed to serialize event wrapper: %w", err) } - if err := s.events.Set(ctx, wrapper.key(), wrapperData); err != nil { + if setErr := s.events.Set(ctx, wrapper.key(), wrapperData); setErr != nil { s.metrics.IncError(1) - return fmt.Errorf("can't add message to cache: %w", err) + return fmt.Errorf("failed to add message to cache: %w", setErr) } s.metrics.IncEnqueued(string(event.Type)) @@ -122,7 +124,7 @@ func (s *Service) Enqueue(token string, event types.Event) error { func (s *Service) sendAll(ctx context.Context) { rawEvents, err := s.events.Drain(ctx) if err != nil { - s.logger.Error("Can't drain cache", zap.Error(err)) + s.logger.Error("failed to drain cache", zap.Error(err)) return } @@ -134,9 +136,9 @@ func (s *Service) sendAll(ctx context.Context) { lo.Values(rawEvents), func(value []byte, _ int) (*eventWrapper, bool) { wrapper := new(eventWrapper) - if err := wrapper.deserialize(value); err != nil { + if wrapErr := wrapper.deserialize(value); wrapErr != nil { s.metrics.IncError(1) - s.logger.Error("Failed to deserialize event wrapper", zap.Binary("value", value), zap.Error(err)) + s.logger.Error("failed to deserialize event wrapper", zap.Binary("value", value), zap.Error(wrapErr)) return nil, false } @@ -146,8 +148,8 @@ func (s *Service) sendAll(ctx context.Context) { messages := lo.Map( wrappers, - func(wrapper *eventWrapper, _ int) types.Message { - return types.Message{ + func(wrapper *eventWrapper, _ int) client.Message { + return client.Message{ Token: wrapper.Token, Event: wrapper.Event, } @@ -179,8 +181,8 @@ func (s *Service) sendAll(ctx context.Context) { failed := lo.Filter( wrappers, func(item *eventWrapper, index int) bool { - if err := errs[index]; err != nil { - s.logger.Error("failed to send message", zap.String("token", item.Token), zap.Error(err)) + if sendErr := errs[index]; sendErr != nil { + s.logger.Error("failed to send message", zap.String("token", item.Token), zap.Error(sendErr)) return true } diff --git a/internal/sms-gateway/modules/push/types.go b/internal/sms-gateway/modules/push/types.go index 72ef8da0..0c662278 100644 --- a/internal/sms-gateway/modules/push/types.go +++ b/internal/sms-gateway/modules/push/types.go @@ -1,10 +1,10 @@ package push import ( - "context" "encoding/json" + "fmt" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/types" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" ) type Mode string @@ -14,13 +14,7 @@ const ( ModeUpstream Mode = "upstream" ) -type Event = types.Event - -type client interface { - Open(ctx context.Context) error - Send(ctx context.Context, messages []types.Message) ([]error, error) - Close(ctx context.Context) error -} +type Event = client.Event type eventWrapper struct { Token string `json:"token"` @@ -33,9 +27,18 @@ func (e *eventWrapper) key() string { } func (e *eventWrapper) serialize() ([]byte, error) { - return json.Marshal(e) + data, err := json.Marshal(e) + if err != nil { + return nil, fmt.Errorf("failed to marshal event: %w", err) + } + + return data, nil } func (e *eventWrapper) deserialize(data []byte) error { - return json.Unmarshal(data, e) + if err := json.Unmarshal(data, e); err != nil { + return fmt.Errorf("failed to unmarshal event: %w", err) + } + + return nil } diff --git a/internal/sms-gateway/modules/push/upstream/client.go b/internal/sms-gateway/modules/push/upstream/client.go index bb81bd18..757f880c 100644 --- a/internal/sms-gateway/modules/push/upstream/client.go +++ b/internal/sms-gateway/modules/push/upstream/client.go @@ -4,17 +4,20 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "sync" "github.com/android-sms-gateway/client-go/smsgateway" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/types" + "github.com/android-sms-gateway/server/internal/sms-gateway/modules/push/client" "github.com/samber/lo" ) -const BASE_URL = "https://api.sms-gate.app/upstream/v1" +const baseURL = "https://api.sms-gate.app/upstream/v1" + +var ErrInvalidResponse = errors.New("invalid response") type Client struct { options map[string]string @@ -26,10 +29,12 @@ type Client struct { func New(options map[string]string) (*Client, error) { return &Client{ options: options, + client: nil, + mux: sync.Mutex{}, }, nil } -func (c *Client) Open(ctx context.Context) error { +func (c *Client) Open(_ context.Context) error { c.mux.Lock() defer c.mux.Unlock() @@ -42,10 +47,10 @@ func (c *Client) Open(ctx context.Context) error { return nil } -func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, error) { +func (c *Client) Send(ctx context.Context, messages []client.Message) ([]error, error) { payload := lo.Map( messages, - func(item types.Message, _ int) smsgateway.PushNotification { + func(item client.Message, _ int) smsgateway.PushNotification { return smsgateway.PushNotification{ Token: item.Token, Event: item.Event.Type, @@ -54,15 +59,15 @@ func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, e }, ) - payloadBytes, err := json.Marshal(smsgateway.UpstreamPushRequest(payload)) + payloadBytes, err := json.Marshal(smsgateway.UpstreamPushRequest(payload)) //nolint:unconvert //type checking if err != nil { - return nil, fmt.Errorf("can't marshal payload: %w", err) + return nil, fmt.Errorf("failed to marshal payload: %w", err) } - req, err := http.NewRequestWithContext(ctx, http.MethodPost, BASE_URL+"/push", bytes.NewReader(payloadBytes)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/push", bytes.NewReader(payloadBytes)) if err != nil { - return nil, fmt.Errorf("can't create request: %w", err) + return nil, fmt.Errorf("failed to create request: %w", err) } req.Header.Set("Content-Type", "application/json") @@ -70,7 +75,7 @@ func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, e resp, err := c.client.Do(req) if err != nil { - return nil, fmt.Errorf("can't send request: %w", err) + return nil, fmt.Errorf("failed to send request: %w", err) } defer func() { @@ -78,23 +83,26 @@ func (c *Client) Send(ctx context.Context, messages []types.Message) ([]error, e _ = resp.Body.Close() }() - if resp.StatusCode >= 400 { - return c.mapErrors(messages, fmt.Errorf("unexpected status code: %d", resp.StatusCode)), nil + if resp.StatusCode >= http.StatusBadRequest { + return c.mapErrors( + messages, + fmt.Errorf("%w: unexpected status code: %d", ErrInvalidResponse, resp.StatusCode), + ), nil } return nil, nil } -func (c *Client) mapErrors(messages []types.Message, err error) []error { +func (c *Client) mapErrors(messages []client.Message, err error) []error { return lo.Map( messages, - func(_ types.Message, _ int) error { + func(_ client.Message, _ int) error { return err }, ) } -func (c *Client) Close(ctx context.Context) error { +func (c *Client) Close(_ context.Context) error { c.mux.Lock() defer c.mux.Unlock() diff --git a/internal/sms-gateway/modules/settings/models.go b/internal/sms-gateway/modules/settings/models.go index f23f1a98..c547f64e 100644 --- a/internal/sms-gateway/modules/settings/models.go +++ b/internal/sms-gateway/modules/settings/models.go @@ -8,16 +8,24 @@ import ( ) type DeviceSettings struct { + models.TimedModel + UserID string `gorm:"primaryKey;not null;type:varchar(32)"` Settings map[string]any `gorm:"not null;type:json;serializer:json"` User models.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` +} - models.TimedModel +func NewDeviceSettings(userID string, settings map[string]any) *DeviceSettings { + //nolint:exhaustruct // partial constructor + return &DeviceSettings{ + UserID: userID, + Settings: settings, + } } func Migrate(db *gorm.DB) error { - if err := db.AutoMigrate(&DeviceSettings{}); err != nil { + if err := db.AutoMigrate(new(DeviceSettings)); err != nil { return fmt.Errorf("device_settings migration failed: %w", err) } return nil diff --git a/internal/sms-gateway/modules/settings/module.go b/internal/sms-gateway/modules/settings/module.go index 64fb2f1f..eeb870eb 100644 --- a/internal/sms-gateway/modules/settings/module.go +++ b/internal/sms-gateway/modules/settings/module.go @@ -6,20 +6,23 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "settings", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("settings") - }), - fx.Provide( - newRepository, - fx.Private, - ), - fx.Provide( - NewService, - ), -) +func Module() fx.Option { + return fx.Module( + "settings", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("settings") + }), + fx.Provide( + newRepository, + fx.Private, + ), + fx.Provide( + NewService, + ), + ) +} +//nolint:gochecknoinits //backward compatibility func init() { db.RegisterMigration(Migrate) } diff --git a/internal/sms-gateway/modules/settings/repository.go b/internal/sms-gateway/modules/settings/repository.go index dfa1aa66..8520ddb6 100644 --- a/internal/sms-gateway/modules/settings/repository.go +++ b/internal/sms-gateway/modules/settings/repository.go @@ -1,6 +1,8 @@ package settings import ( + "fmt" + "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -11,12 +13,13 @@ type repository struct { // GetSettings retrieves the device settings for a user by their userID. func (r *repository) GetSettings(userID string) (*DeviceSettings, error) { - settings := &DeviceSettings{ - Settings: map[string]any{}, - } + settings := new(DeviceSettings) err := r.db.Where("user_id = ?", userID).Limit(1).Find(settings).Error if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get settings: %w", err) + } + if settings.Settings == nil { + settings.Settings = map[string]any{} } return settings, nil @@ -26,8 +29,8 @@ func (r *repository) GetSettings(userID string) (*DeviceSettings, error) { func (r *repository) UpdateSettings(settings *DeviceSettings) (*DeviceSettings, error) { var updatedSettings *DeviceSettings err := r.db.Transaction(func(tx *gorm.DB) error { - source := &DeviceSettings{UserID: settings.UserID} - if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Limit(1).Find(source).Error; err != nil { + source := new(DeviceSettings) + if err := tx.Clauses(clause.Locking{Strength: clause.LockingStrengthUpdate}).Where("user_id = ?", settings.UserID).Limit(1).Find(source).Error; err != nil { return err } @@ -41,7 +44,8 @@ func (r *repository) UpdateSettings(settings *DeviceSettings) (*DeviceSettings, return err } - if err := tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(settings).Error; err != nil { + err = tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(settings).Error + if err != nil { return err } @@ -49,7 +53,11 @@ func (r *repository) UpdateSettings(settings *DeviceSettings) (*DeviceSettings, updatedSettings = settings return nil }) - return updatedSettings, err + if err != nil { + return updatedSettings, fmt.Errorf("failed to update settings: %w", err) + } + + return updatedSettings, nil } // ReplaceSettings replaces the settings for a user. @@ -59,7 +67,12 @@ func (r *repository) ReplaceSettings(settings *DeviceSettings) (*DeviceSettings, err := r.db.Transaction(func(tx *gorm.DB) error { return tx.Save(settings).Error }) - return settings, err + + if err != nil { + return settings, fmt.Errorf("failed to replace settings: %w", err) + } + + return settings, nil } func newRepository(db *gorm.DB) *repository { diff --git a/internal/sms-gateway/modules/settings/service.go b/internal/sms-gateway/modules/settings/service.go index 1df4830f..ad254c2c 100644 --- a/internal/sms-gateway/modules/settings/service.go +++ b/internal/sms-gateway/modules/settings/service.go @@ -53,10 +53,7 @@ func (s *Service) UpdateSettings(userID string, settings map[string]any) (map[st return nil, err } - updatedSettings, err := s.settings.UpdateSettings(&DeviceSettings{ - UserID: userID, - Settings: filtered, - }) + updatedSettings, err := s.settings.UpdateSettings(NewDeviceSettings(userID, filtered)) if err != nil { return nil, err } @@ -72,10 +69,7 @@ func (s *Service) ReplaceSettings(userID string, settings map[string]any) (map[s return nil, err } - updated, err := s.settings.ReplaceSettings(&DeviceSettings{ - UserID: userID, - Settings: filtered, - }) + updated, err := s.settings.ReplaceSettings(NewDeviceSettings(userID, filtered)) if err != nil { return nil, err } @@ -89,7 +83,7 @@ func (s *Service) ReplaceSettings(userID string, settings map[string]any) (map[s func (s *Service) notifyDevices(userID string) { go func(userID string) { if err := s.eventsSvc.Notify(userID, nil, events.NewSettingsUpdatedEvent()); err != nil { - s.logger.Error("can't notify devices", zap.Error(err)) + s.logger.Error("failed to notify devices", zap.Error(err)) } }(userID) } diff --git a/internal/sms-gateway/modules/settings/utils.go b/internal/sms-gateway/modules/settings/utils.go index 58c0cf68..bef1d17e 100644 --- a/internal/sms-gateway/modules/settings/utils.go +++ b/internal/sms-gateway/modules/settings/utils.go @@ -1,54 +1,65 @@ package settings -import "fmt" +import ( + "errors" + "fmt" +) -var rules = map[string]any{ - "encryption": map[string]any{ - "passphrase": "", - }, - "messages": map[string]any{ - "send_interval_min": "", - "send_interval_max": "", - "limit_period": "", - "limit_value": "", - "sim_selection_mode": "", - "log_lifetime_days": "", - }, - "ping": map[string]any{ - "interval_seconds": "", - }, - "logs": map[string]any{ - "lifetime_days": "", - }, - "webhooks": map[string]any{ - "internet_required": "", - "retry_count": "", - "signing_key": "", - }, -} +var ( + ErrInvalidField = errors.New("invalid field") +) -var rulesPublic = map[string]any{ - "encryption": map[string]any{}, - "messages": map[string]any{ - "send_interval_min": "", - "send_interval_max": "", - "limit_period": "", - "limit_value": "", - "sim_selection_mode": "", - "log_lifetime_days": "", - }, - "ping": map[string]any{ - "interval_seconds": "", - }, - "logs": map[string]any{ - "lifetime_days": "", - }, - "webhooks": map[string]any{ - "internet_required": "", - "retry_count": "", - }, -} +//nolint:gochecknoglobals // private constants +var ( + rules = map[string]any{ + "encryption": map[string]any{ + "passphrase": "", + }, + "messages": map[string]any{ + "send_interval_min": "", + "send_interval_max": "", + "limit_period": "", + "limit_value": "", + "sim_selection_mode": "", + "log_lifetime_days": "", + }, + "ping": map[string]any{ + "interval_seconds": "", + }, + "logs": map[string]any{ + "lifetime_days": "", + }, + "webhooks": map[string]any{ + "internet_required": "", + "retry_count": "", + "signing_key": "", + }, + } + + rulesPublic = map[string]any{ + "encryption": map[string]any{}, + "messages": map[string]any{ + "send_interval_min": "", + "send_interval_max": "", + "limit_period": "", + "limit_value": "", + "sim_selection_mode": "", + "log_lifetime_days": "", + }, + "ping": map[string]any{ + "interval_seconds": "", + }, + "logs": map[string]any{ + "lifetime_days": "", + }, + "webhooks": map[string]any{ + "internet_required": "", + "retry_count": "", + }, + } +) +//nolint:nestif,govet // keep as is func filterMap(m map[string]any, r map[string]any) (map[string]any, error) { var err error @@ -63,7 +74,7 @@ func filterMap(m map[string]any, r map[string]any) (map[string]any, error) { } else if m[field] == nil { continue } else { - return nil, fmt.Errorf("the field: '%s' is not a map to dive", field) + return nil, fmt.Errorf("%w: '%s' is not a map to dive", ErrInvalidField, field) } } else if _, ok := rule.(string); ok { if _, ok := m[field]; !ok { @@ -76,6 +87,7 @@ func filterMap(m map[string]any, r map[string]any) (map[string]any, error) { return result, nil } +//nolint:nestif,gocognit,govet // keep as is func appendMap(m1, m2 map[string]any, rules map[string]any) (map[string]any, error) { var err error @@ -98,7 +110,7 @@ func appendMap(m1, m2 map[string]any, rules map[string]any) (map[string]any, err } else if m2[field] == nil { continue } else { - return nil, fmt.Errorf("expected field '%s' to be a map, but got %T", field, m2[field]) + return nil, fmt.Errorf("%w: expected field '%s' to be a map, but got %T", ErrInvalidField, field, m2[field]) } } else if _, ok := rule.(string); ok { if _, ok := m2[field]; !ok { diff --git a/internal/sms-gateway/modules/sse/config.go b/internal/sms-gateway/modules/sse/config.go index 3b134438..9790a60a 100644 --- a/internal/sms-gateway/modules/sse/config.go +++ b/internal/sms-gateway/modules/sse/config.go @@ -2,7 +2,7 @@ package sse import "time" -type configOption func(*Config) +type Option func(*Config) type Config struct { keepAlivePeriod time.Duration @@ -10,12 +10,14 @@ type Config struct { const defaultKeepAlivePeriod = 15 * time.Second -var defaultConfig = Config{ - keepAlivePeriod: defaultKeepAlivePeriod, +func DefaultConfig() Config { + return Config{ + keepAlivePeriod: defaultKeepAlivePeriod, + } } -func NewConfig(opts ...configOption) Config { - c := defaultConfig +func NewConfig(opts ...Option) Config { + c := DefaultConfig() for _, opt := range opts { opt(&c) @@ -28,7 +30,7 @@ func (c *Config) KeepAlivePeriod() time.Duration { return c.keepAlivePeriod } -func WithKeepAlivePeriod(d time.Duration) configOption { +func WithKeepAlivePeriod(d time.Duration) Option { if d < 0 { d = defaultKeepAlivePeriod } diff --git a/internal/sms-gateway/modules/sse/errors.go b/internal/sms-gateway/modules/sse/errors.go new file mode 100644 index 00000000..da0673b7 --- /dev/null +++ b/internal/sms-gateway/modules/sse/errors.go @@ -0,0 +1,7 @@ +package sse + +import "errors" + +var ( + ErrNoConnection = errors.New("no connection") +) diff --git a/internal/sms-gateway/modules/sse/metrics.go b/internal/sms-gateway/modules/sse/metrics.go index d939e8a3..835621a9 100644 --- a/internal/sms-gateway/modules/sse/metrics.go +++ b/internal/sms-gateway/modules/sse/metrics.go @@ -5,7 +5,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -// Metric constants +// Metric constants. const ( MetricActiveConnections = "active_connections" MetricEventsSent = "events_sent_total" @@ -22,7 +22,7 @@ const ( ErrorTypeMarshalError = "marshal_error" ) -// metrics contains all Prometheus metrics for the SSE module +// metrics contains all Prometheus metrics for the SSE module. type metrics struct { activeConnections *prometheus.GaugeVec eventsSent *prometheus.CounterVec @@ -31,7 +31,7 @@ type metrics struct { keepalivesSent *prometheus.CounterVec } -// newMetrics creates and initializes all SSE metrics +// newMetrics creates and initializes all SSE metrics. func newMetrics() *metrics { metrics := &metrics{ activeConnections: promauto.NewGaugeVec(prometheus.GaugeOpts{ diff --git a/internal/sms-gateway/modules/sse/module.go b/internal/sms-gateway/modules/sse/module.go index fb18d52f..c3b3e4a1 100644 --- a/internal/sms-gateway/modules/sse/module.go +++ b/internal/sms-gateway/modules/sse/module.go @@ -7,23 +7,26 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "sse", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("sse") - }), - fx.Provide( - newMetrics, - fx.Private, - ), - fx.Provide( - NewService, - ), - fx.Invoke(func(lc fx.Lifecycle, svc *Service) { - lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { - return svc.Close(ctx) - }, - }) - }), -) +func Module() fx.Option { + return fx.Module( + "sse", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("sse") + }), + fx.Provide( + newMetrics, + fx.Private, + ), + fx.Provide( + NewService, + ), + fx.Invoke(func(lc fx.Lifecycle, svc *Service) { + lc.Append(fx.Hook{ + OnStart: nil, + OnStop: func(ctx context.Context) error { + return svc.Close(ctx) + }, + }) + }), + ) +} diff --git a/internal/sms-gateway/modules/sse/service.go b/internal/sms-gateway/modules/sse/service.go index e5cfe5e9..071e3ba7 100644 --- a/internal/sms-gateway/modules/sse/service.go +++ b/internal/sms-gateway/modules/sse/service.go @@ -14,6 +14,10 @@ import ( "go.uber.org/zap" ) +const ( + eventsBufferSize = 8 +) + type Service struct { config Config @@ -39,6 +43,7 @@ func NewService(config Config, logger *zap.Logger, metrics *metrics) *Service { return &Service{ config: config, + mu: sync.RWMutex{}, connections: make(map[string][]*sseConnection), logger: logger, @@ -54,14 +59,14 @@ func (s *Service) Send(deviceID string, event Event) error { if !exists { // Increment connection errors metric for no connection s.metrics.IncrementConnectionErrors(ErrorTypeNoConnection) - return fmt.Errorf("no connection for device %s", deviceID) + return fmt.Errorf("%w: device %s", ErrNoConnection, deviceID) } data, err := json.Marshal(event.Data) if err != nil { // Increment connection errors metric for marshaling error s.metrics.IncrementConnectionErrors(ErrorTypeMarshalError) - return fmt.Errorf("can't marshal event: %w", err) + return fmt.Errorf("failed to marshal event: %w", err) } sent := 0 @@ -71,9 +76,17 @@ func (s *Service) Send(deviceID string, event Event) error { // Message sent successfully sent++ case <-conn.closeSignal: - s.logger.Warn("Connection closed while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id)) + s.logger.Warn( + "Connection closed while sending event", + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), + ) default: - s.logger.Warn("Connection buffer full while sending event", zap.String("device_id", deviceID), zap.String("connection_id", conn.id)) + s.logger.Warn( + "Connection buffer full while sending event", + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), + ) // Increment connection errors metric for buffer full s.metrics.IncrementConnectionErrors(ErrorTypeBufferFull) } @@ -82,7 +95,7 @@ func (s *Service) Send(deviceID string, event Event) error { if sent == 0 { // Increment connection errors metric for no active connection s.metrics.IncrementConnectionErrors(ErrorTypeNoConnection) - return fmt.Errorf("no active connection for device %s", deviceID) + return fmt.Errorf("%w: device %s", ErrNoConnection, deviceID) } // Count events sent @@ -111,60 +124,71 @@ func (s *Service) Handler(deviceID string, c *fiber.Ctx) error { c.Set("Transfer-Encoding", "chunked") c.Status(fiber.StatusOK).Context().SetBodyStreamWriter(func(w *bufio.Writer) { - conn := s.registerConnection(deviceID) - defer s.removeConnection(deviceID, conn.id) - - // Conditionally create ticker - var ticker *time.Ticker - if s.config.keepAlivePeriod > 0 { - ticker = time.NewTicker(s.config.keepAlivePeriod) - defer ticker.Stop() - } + s.handleStream(deviceID, w) + }) - for { - select { - case event := <-conn.channel: - s.metrics.ObserveEventDeliveryLatency(func() { - if err := s.writeToStream(w, fmt.Sprintf("event: %s\ndata: %s", event.name, utils.UnsafeString(event.data))); err != nil { - s.logger.Warn("Failed to write event data", - zap.String("device_id", deviceID), - zap.String("connection_id", conn.id), - zap.Error(err)) - return - } - }) - // Conditionally handle ticker events - case <-func() <-chan time.Time { - if ticker != nil { - return ticker.C - } - // Return nil channel that never fires when disabled - return make(chan time.Time) - }(): - if err := s.writeToStream(w, ":keepalive"); err != nil { - s.logger.Warn("Failed to write keepalive", + return nil +} + +func (s *Service) handleStream(deviceID string, w *bufio.Writer) { + conn := s.registerConnection(deviceID) + defer s.removeConnection(deviceID, conn.id) + + var tickerChan <-chan time.Time + + // Conditionally create ticker + if s.config.keepAlivePeriod > 0 { + ticker := time.NewTicker(s.config.keepAlivePeriod) + defer ticker.Stop() + + tickerChan = ticker.C + } + + for { + select { + case event := <-conn.channel: + success := true + s.metrics.ObserveEventDeliveryLatency(func() { + if err := s.writeToStream(w, fmt.Sprintf("event: %s\ndata: %s", event.name, utils.UnsafeString(event.data))); err != nil { + s.logger.Warn("failed to write event data", zap.String("device_id", deviceID), zap.String("connection_id", conn.id), zap.Error(err)) - return + success = false } - // Count keepalives sent - s.metrics.IncrementKeepalivesSent() - case <-conn.closeSignal: + }) + + if !success { + return + } + // Conditionally handle ticker events + case <-tickerChan: + if err := s.writeToStream(w, ":keepalive"); err != nil { + s.logger.Warn("failed to write keepalive", + zap.String("device_id", deviceID), + zap.String("connection_id", conn.id), + zap.Error(err)) return } + // Count keepalives sent + s.metrics.IncrementKeepalivesSent() + case <-conn.closeSignal: + return } - }) - - return nil + } } func (s *Service) writeToStream(w *bufio.Writer, data string) error { if _, err := fmt.Fprintf(w, "%s\n\n", data); err != nil { s.metrics.IncrementConnectionErrors(ErrorTypeWriteFailure) - return err + return fmt.Errorf("failed to write to stream: %w", err) + } + if err := w.Flush(); err != nil { + s.metrics.IncrementConnectionErrors(ErrorTypeWriteFailure) + return fmt.Errorf("failed to flush stream: %w", err) } - return w.Flush() + + return nil } func (s *Service) registerConnection(deviceID string) *sseConnection { @@ -175,7 +199,7 @@ func (s *Service) registerConnection(deviceID string) *sseConnection { conn := &sseConnection{ id: connID, - channel: make(chan eventWrapper, 8), + channel: make(chan eventWrapper, eventsBufferSize), closeSignal: make(chan struct{}), } @@ -202,7 +226,11 @@ func (s *Service) removeConnection(deviceID, connID string) { if conn.id == connID { close(conn.closeSignal) s.connections[deviceID] = append(connections[:i], connections[i+1:]...) - s.logger.Info("Removing SSE connection", zap.String("device_id", deviceID), zap.String("connection_id", connID)) + s.logger.Info( + "Removing SSE connection", + zap.String("device_id", deviceID), + zap.String("connection_id", connID), + ) break } } diff --git a/internal/sms-gateway/modules/webhooks/errors.go b/internal/sms-gateway/modules/webhooks/errors.go index a599d987..8e4b014f 100644 --- a/internal/sms-gateway/modules/webhooks/errors.go +++ b/internal/sms-gateway/modules/webhooks/errors.go @@ -1,6 +1,13 @@ package webhooks -import "fmt" +import ( + "errors" + "fmt" +) + +var ( + ErrInvalidEvent = errors.New("invalid event") +) type ValidationError struct { Field string @@ -9,7 +16,7 @@ type ValidationError struct { } func (e ValidationError) Error() string { - return fmt.Sprintf("invalid `%s` = `%s`: %s", e.Field, e.Value, e.Err) + return fmt.Sprintf("invalid %q = %q: %s", e.Field, e.Value, e.Err) } func (e ValidationError) Unwrap() error { @@ -25,6 +32,5 @@ func newValidationError(field, value string, err error) ValidationError { } func IsValidationError(err error) bool { - _, ok := err.(ValidationError) - return ok + return errors.As(err, new(ValidationError)) } diff --git a/internal/sms-gateway/modules/webhooks/models.go b/internal/sms-gateway/modules/webhooks/models.go index a6a546ae..806bdd0d 100644 --- a/internal/sms-gateway/modules/webhooks/models.go +++ b/internal/sms-gateway/modules/webhooks/models.go @@ -1,27 +1,43 @@ package webhooks import ( + "fmt" + "github.com/android-sms-gateway/client-go/smsgateway" "github.com/android-sms-gateway/server/internal/sms-gateway/models" "gorm.io/gorm" ) type Webhook struct { - ID uint64 `json:"-" gorm:"->;primaryKey;type:BIGINT UNSIGNED;autoIncrement"` - ExtID string `json:"id" gorm:"not null;type:varchar(36);uniqueIndex:unq_webhooks_user_extid,priority:2"` - UserID string `json:"-" gorm:"<-:create;not null;type:varchar(32);uniqueIndex:unq_webhooks_user_extid,priority:1"` + models.SoftDeletableModel + + ID uint64 `json:"-" gorm:"->;primaryKey;type:BIGINT UNSIGNED;autoIncrement"` + ExtID string `json:"id" gorm:"not null;type:varchar(36);uniqueIndex:unq_webhooks_user_extid,priority:2"` + UserID string `json:"-" gorm:"<-:create;not null;type:varchar(32);uniqueIndex:unq_webhooks_user_extid,priority:1"` DeviceID *string `json:"device_id,omitempty" gorm:"type:varchar(21);index:idx_webhooks_device"` - URL string `json:"url" validate:"required,http_url" gorm:"not null;type:varchar(256)"` - Event smsgateway.WebhookEvent `json:"event" gorm:"not null;type:varchar(32)"` + URL string `json:"url" validate:"required,http_url" gorm:"not null;type:varchar(256)"` + Event smsgateway.WebhookEvent `json:"event" gorm:"not null;type:varchar(32)"` User models.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` Device *models.Device `gorm:"foreignKey:DeviceID;constraint:OnDelete:CASCADE"` +} - models.SoftDeletableModel +func newWebhook(extID string, url string, event smsgateway.WebhookEvent, userID string, deviceID *string) *Webhook { + //nolint:exhaustruct // partial constructor + return &Webhook{ + ExtID: extID, + URL: url, + Event: event, + UserID: userID, + DeviceID: deviceID, + } } func Migrate(db *gorm.DB) error { - return db.AutoMigrate(&Webhook{}) + if err := db.AutoMigrate(new(Webhook)); err != nil { + return fmt.Errorf("webhooks migration failed: %w", err) + } + return nil } diff --git a/internal/sms-gateway/modules/webhooks/module.go b/internal/sms-gateway/modules/webhooks/module.go index 709c57ba..6bdd73d1 100644 --- a/internal/sms-gateway/modules/webhooks/module.go +++ b/internal/sms-gateway/modules/webhooks/module.go @@ -6,17 +6,20 @@ import ( "go.uber.org/zap" ) -var Module = fx.Module( - "webhooks", - fx.Decorate(func(log *zap.Logger) *zap.Logger { - return log.Named("webhooks") - }), - fx.Provide(NewRepository, fx.Private), - fx.Provide( - NewService, - ), -) +func Module() fx.Option { + return fx.Module( + "webhooks", + fx.Decorate(func(log *zap.Logger) *zap.Logger { + return log.Named("webhooks") + }), + fx.Provide(NewRepository, fx.Private), + fx.Provide( + NewService, + ), + ) +} +//nolint:gochecknoinits //framework-specific func init() { db.RegisterMigration(Migrate) } diff --git a/internal/sms-gateway/modules/webhooks/repository.go b/internal/sms-gateway/modules/webhooks/repository.go index d0f63f10..2c31c4da 100644 --- a/internal/sms-gateway/modules/webhooks/repository.go +++ b/internal/sms-gateway/modules/webhooks/repository.go @@ -9,6 +9,12 @@ type Repository struct { db *gorm.DB } +func NewRepository(db *gorm.DB) *Repository { + return &Repository{ + db: db, + } +} + func (r *Repository) Select(filters ...SelectFilter) ([]*Webhook, error) { webhooks := []*Webhook{} if err := newFilter(filters...).apply(r.db).Find(&webhooks).Error; err != nil { @@ -25,11 +31,5 @@ func (r *Repository) Replace(webhook *Webhook) error { } func (r *Repository) Delete(filters ...SelectFilter) error { - return newFilter(filters...).apply(r.db).Delete(&Webhook{}).Error -} - -func NewRepository(db *gorm.DB) *Repository { - return &Repository{ - db: db, - } + return newFilter(filters...).apply(r.db).Delete(new(Webhook)).Error } diff --git a/internal/sms-gateway/modules/webhooks/repository_filter.go b/internal/sms-gateway/modules/webhooks/repository_filter.go index 87301a7a..9750e8af 100644 --- a/internal/sms-gateway/modules/webhooks/repository_filter.go +++ b/internal/sms-gateway/modules/webhooks/repository_filter.go @@ -24,7 +24,7 @@ type selectFilter struct { } func newFilter(filters ...SelectFilter) *selectFilter { - f := &selectFilter{} + f := new(selectFilter) f.merge(filters...) return f } diff --git a/internal/sms-gateway/modules/webhooks/service.go b/internal/sms-gateway/modules/webhooks/service.go index 0e902bd4..28dc5b39 100644 --- a/internal/sms-gateway/modules/webhooks/service.go +++ b/internal/sms-gateway/modules/webhooks/service.go @@ -53,7 +53,7 @@ func NewService(params ServiceParams) *Service { func (s *Service) _select(filters ...SelectFilter) ([]smsgateway.Webhook, error) { items, err := s.webhooks.Select(filters...) if err != nil { - return nil, fmt.Errorf("can't select webhooks: %w", err) + return nil, fmt.Errorf("failed to select webhooks: %w", err) } return slices.Map(items, webhookToDTO), nil @@ -69,9 +69,9 @@ func (s *Service) Select(userID string, filters ...SelectFilter) ([]smsgateway.W // Replace creates or updates a webhook for a given user. After replacing the webhook, // it asynchronously notifies all the user's devices. Returns an error if the operation fails. -func (s *Service) Replace(userID string, webhook smsgateway.Webhook) error { +func (s *Service) Replace(userID string, webhook *smsgateway.Webhook) error { if !smsgateway.IsValidWebhookEvent(webhook.Event) { - return newValidationError("event", string(webhook.Event), fmt.Errorf("enum value expected")) + return newValidationError("event", webhook.Event, ErrInvalidEvent) } if webhook.ID == "" { @@ -82,23 +82,23 @@ func (s *Service) Replace(userID string, webhook smsgateway.Webhook) error { if webhook.DeviceID != nil { ok, err := s.devicesSvc.Exists(userID, devices.WithID(*webhook.DeviceID)) if err != nil { - return fmt.Errorf("failed to select devices: %w", err) + return fmt.Errorf("failed to verify device ownership: %w", err) } if !ok { return newValidationError("device_id", *webhook.DeviceID, devices.ErrNotFound) } } - model := Webhook{ - ExtID: webhook.ID, - UserID: userID, - DeviceID: webhook.DeviceID, - URL: webhook.URL, - Event: webhook.Event, - } + model := newWebhook( + webhook.ID, + webhook.URL, + webhook.Event, + userID, + webhook.DeviceID, + ) - if err := s.webhooks.Replace(&model); err != nil { - return fmt.Errorf("can't replace webhook: %w", err) + if err := s.webhooks.Replace(model); err != nil { + return fmt.Errorf("failed to replace webhook: %w", err) } s.notifyDevices(userID, webhook.DeviceID) @@ -111,7 +111,7 @@ func (s *Service) Replace(userID string, webhook smsgateway.Webhook) error { func (s *Service) Delete(userID string, filters ...SelectFilter) error { filters = append(filters, WithUserID(userID)) if err := s.webhooks.Delete(filters...); err != nil { - return fmt.Errorf("can't delete webhooks: %w", err) + return fmt.Errorf("failed to delete webhooks: %w", err) } s.notifyDevices(userID, nil) @@ -123,7 +123,7 @@ func (s *Service) Delete(userID string, filters ...SelectFilter) error { func (s *Service) notifyDevices(userID string, deviceID *string) { go func(userID string, deviceID *string) { if err := s.eventsSvc.Notify(userID, deviceID, events.NewWebhooksUpdatedEvent()); err != nil { - s.logger.Error("can't notify devices", zap.Error(err)) + s.logger.Error("failed to notify devices", zap.Error(err)) } }(userID, deviceID) } diff --git a/internal/sms-gateway/online/metrics.go b/internal/sms-gateway/online/metrics.go index cc556f26..61d2ce3c 100644 --- a/internal/sms-gateway/online/metrics.go +++ b/internal/sms-gateway/online/metrics.go @@ -5,7 +5,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promauto" ) -// Metric constants +// Metric constants. const ( metricStatusSetTotal = "status_set_total" metricCacheOperations = "cache_operations_total" @@ -24,7 +24,7 @@ const ( statusError = "error" ) -// metrics contains all Prometheus metrics for the online module +// metrics contains all Prometheus metrics for the online module. type metrics struct { statusSetCounter *prometheus.CounterVec cacheOperations *prometheus.CounterVec @@ -34,7 +34,7 @@ type metrics struct { batchSize prometheus.Gauge } -// newMetrics creates and initializes all online metrics +// newMetrics creates and initializes all online metrics. func newMetrics() *metrics { return &metrics{ statusSetCounter: promauto.NewCounterVec(prometheus.CounterOpts{ @@ -83,7 +83,7 @@ func newMetrics() *metrics { } } -// IncrementStatusSet increments the status set counter +// IncrementStatusSet increments the status set counter. func (m *metrics) IncrementStatusSet(success bool) { status := statusSuccess if !success { @@ -92,31 +92,31 @@ func (m *metrics) IncrementStatusSet(success bool) { m.statusSetCounter.WithLabelValues(status).Inc() } -// IncrementCacheOperation increments cache operation counter +// IncrementCacheOperation increments cache operation counter. func (m *metrics) IncrementCacheOperation(operation, status string) { m.cacheOperations.WithLabelValues(operation, status).Inc() } -// ObserveCacheLatency observes cache operation latency +// ObserveCacheLatency observes cache operation latency. func (m *metrics) ObserveCacheLatency(f func()) { timer := prometheus.NewTimer(m.cacheLatency) f() timer.ObserveDuration() } -// ObservePersistenceLatency observes persistence operation latency +// ObservePersistenceLatency observes persistence operation latency. func (m *metrics) ObservePersistenceLatency(f func()) { timer := prometheus.NewTimer(m.persistenceLatency) f() timer.ObserveDuration() } -// IncrementPersistenceError increments persistence error counter +// IncrementPersistenceError increments persistence error counter. func (m *metrics) IncrementPersistenceError() { m.persistenceErrors.Inc() } -// SetBatchSize sets the current batch size +// SetBatchSize sets the current batch size. func (m *metrics) SetBatchSize(size int) { m.batchSize.Set(float64(size)) } diff --git a/internal/sms-gateway/online/service.go b/internal/sms-gateway/online/service.go index e8560631..bdc6d289 100644 --- a/internal/sms-gateway/online/service.go +++ b/internal/sms-gateway/online/service.go @@ -47,7 +47,7 @@ func (s *service) Run(ctx context.Context) { case <-ticker.C: s.logger.Debug("Persisting online status") if err := s.persist(ctx); err != nil { - s.logger.Error("Can't persist online status", zap.Error(err)) + s.logger.Error("failed to persist online status", zap.Error(err)) } } } @@ -62,7 +62,7 @@ func (s *service) SetOnline(ctx context.Context, deviceID string) { s.metrics.ObserveCacheLatency(func() { if err = s.cache.Set(ctx, deviceID, []byte(dt)); err != nil { s.metrics.IncrementCacheOperation(operationSet, statusError) - s.logger.Error("Can't set online status", zap.String("device_id", deviceID), zap.Error(err)) + s.logger.Error("failed to set online status", zap.String("device_id", deviceID), zap.Error(err)) s.metrics.IncrementStatusSet(false) } }) @@ -82,7 +82,7 @@ func (s *service) persist(ctx context.Context) error { s.metrics.ObservePersistenceLatency(func() { items, err := s.cache.Drain(ctx) if err != nil { - drainErr = fmt.Errorf("can't drain cache: %w", err) + drainErr = fmt.Errorf("failed to drain cache: %w", err) s.metrics.IncrementCacheOperation(operationDrain, statusError) return } @@ -96,9 +96,9 @@ func (s *service) persist(ctx context.Context) error { s.logger.Debug("Drained cache", zap.Int("count", len(items))) timestamps := maps.MapValues(items, func(v []byte) time.Time { - t, err := time.Parse(time.RFC3339, string(v)) - if err != nil { - s.logger.Warn("Can't parse last seen", zap.String("last_seen", string(v)), zap.Error(err)) + t, parseErr := time.Parse(time.RFC3339, string(v)) + if parseErr != nil { + s.logger.Warn("failed to parse last seen", zap.String("last_seen", string(v)), zap.Error(parseErr)) return time.Now().UTC() } @@ -107,8 +107,8 @@ func (s *service) persist(ctx context.Context) error { s.logger.Debug("Parsed last seen timestamps", zap.Int("count", len(timestamps))) - if err := s.devicesSvc.SetLastSeen(ctx, timestamps); err != nil { - persistErr = fmt.Errorf("can't set last seen: %w", err) + if seenErr := s.devicesSvc.SetLastSeen(ctx, timestamps); seenErr != nil { + persistErr = fmt.Errorf("failed to set last seen: %w", seenErr) s.metrics.IncrementPersistenceError() return } diff --git a/internal/sms-gateway/openapi/docs.go b/internal/sms-gateway/openapi/docs.go index d86627b8..69a3f69c 100644 --- a/internal/sms-gateway/openapi/docs.go +++ b/internal/sms-gateway/openapi/docs.go @@ -123,23 +123,23 @@ const docTemplate = `{ }, "/3rdparty/v1/health": { "get": { - "description": "Checks if service is healthy", + "description": "Checks if service is ready to serve traffic (readiness probe)", "produces": [ "application/json" ], "tags": [ "System" ], - "summary": "Health check", + "summary": "Readiness probe", "responses": { "200": { - "description": "Health check result", + "description": "Service is ready", "schema": { "$ref": "#/definitions/smsgateway.HealthResponse" } }, - "500": { - "description": "Service is unhealthy", + "503": { + "description": "Service is not ready", "schema": { "$ref": "#/definitions/smsgateway.HealthResponse" } @@ -785,6 +785,84 @@ const docTemplate = `{ } } } + }, + "/health/live": { + "get": { + "description": "Checks if service is running (liveness probe)", + "produces": [ + "application/json" + ], + "tags": [ + "System" + ], + "summary": "Liveness probe", + "responses": { + "200": { + "description": "Service is alive", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + }, + "503": { + "description": "Service is not alive", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + } + } + } + }, + "/health/ready": { + "get": { + "description": "Checks if service is ready to serve traffic (readiness probe)", + "produces": [ + "application/json" + ], + "tags": [ + "System" + ], + "summary": "Readiness probe", + "responses": { + "200": { + "description": "Service is ready", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + }, + "503": { + "description": "Service is not ready", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + } + } + } + }, + "/health/startup": { + "get": { + "description": "Checks if service has completed initialization (startup probe)", + "produces": [ + "application/json" + ], + "tags": [ + "System" + ], + "summary": "Startup probe", + "responses": { + "200": { + "description": "Service has completed initialization", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + }, + "503": { + "description": "Service has not completed initialization", + "schema": { + "$ref": "#/definitions/smsgateway.HealthResponse" + } + } + } + } } }, "definitions": { @@ -1200,6 +1278,7 @@ const docTemplate = `{ }, "smsgateway.MessagePriority": { "type": "integer", + "format": "int32", "enum": [ -128, 0, @@ -1209,6 +1288,12 @@ const docTemplate = `{ "x-enum-comments": { "PriorityBypassThreshold": "Threshold at which messages bypass limits and delays" }, + "x-enum-descriptions": [ + "", + "", + "Threshold at which messages bypass limits and delays", + "" + ], "x-enum-varnames": [ "PriorityMinimum", "PriorityDefault", @@ -1326,6 +1411,13 @@ const docTemplate = `{ "ProcessingStateProcessed": "Processed (received by device)", "ProcessingStateSent": "Sent" }, + "x-enum-descriptions": [ + "Pending", + "Processed (received by device)", + "Sent", + "Delivered", + "Failed" + ], "x-enum-varnames": [ "ProcessingStatePending", "ProcessingStateProcessed", @@ -1591,7 +1683,7 @@ var SwaggerInfo = &swag.Spec{ Host: "api.sms-gate.app", BasePath: "", Schemes: []string{"https"}, - Title: "SMS Gateway for Androidβ„’ API", + Title: "SMSGate API", Description: "This API provides programmatic access to sending SMS messages on Android devices. Features include sending SMS, checking message status, device management, webhook configuration, and system health checks.", InfoInstanceName: "swagger", SwaggerTemplate: docTemplate, diff --git a/internal/sms-gateway/pubsub/module.go b/internal/sms-gateway/pubsub/module.go index 4f3d0bca..2f0b1dca 100644 --- a/internal/sms-gateway/pubsub/module.go +++ b/internal/sms-gateway/pubsub/module.go @@ -2,6 +2,7 @@ package pubsub import ( "context" + "fmt" "go.uber.org/fx" "go.uber.org/zap" @@ -16,10 +17,13 @@ func Module() fx.Option { fx.Provide(New), fx.Invoke(func(ps PubSub, logger *zap.Logger, lc fx.Lifecycle) { lc.Append(fx.Hook{ + OnStart: func(_ context.Context) error { + return nil + }, OnStop: func(_ context.Context) error { if err := ps.Close(); err != nil { logger.Error("pubsub close failed", zap.Error(err)) - return err + return fmt.Errorf("failed to close pubsub: %w", err) } return nil }, diff --git a/internal/sms-gateway/pubsub/pubsub.go b/internal/sms-gateway/pubsub/pubsub.go index 84ca2e9f..8f6cdc2e 100644 --- a/internal/sms-gateway/pubsub/pubsub.go +++ b/internal/sms-gateway/pubsub/pubsub.go @@ -1,6 +1,7 @@ package pubsub import ( + "errors" "fmt" "net/url" @@ -13,6 +14,8 @@ const ( type PubSub = pubsub.PubSub +var ErrInvalidScheme = errors.New("invalid scheme") + func New(config Config) (PubSub, error) { if config.URL == "" { config.URL = "memory://" @@ -20,22 +23,29 @@ func New(config Config) (PubSub, error) { u, err := url.Parse(config.URL) if err != nil { - return nil, fmt.Errorf("can't parse url: %w", err) + return nil, fmt.Errorf("failed to parse url: %w", err) } opts := []pubsub.Option{} opts = append(opts, pubsub.WithBufferSize(config.BufferSize)) + var pubSub PubSub switch u.Scheme { case "memory": - return pubsub.NewMemory(opts...), nil + pubSub, err = pubsub.NewMemory(opts...), nil case "redis": - return pubsub.NewRedis(pubsub.RedisConfig{ + pubSub, err = pubsub.NewRedis(pubsub.RedisConfig{ Client: nil, URL: config.URL, Prefix: topicPrefix, }, opts...) default: - return nil, fmt.Errorf("invalid scheme: %s", u.Scheme) + return nil, fmt.Errorf("%w: %s", ErrInvalidScheme, u.Scheme) } + + if err != nil { + return nil, fmt.Errorf("failed to create pubsub: %w", err) + } + + return pubSub, nil } diff --git a/internal/version/version.go b/internal/version/version.go index 8d5a9326..504a2153 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -4,7 +4,9 @@ import "strconv" const notSet string = "not set" -// these information will be collected when build, by `-ldflags "-X main.appVersion=0.1"` +// This information will be collected when build, by `-ldflags "-X main.appVersion=0.1"`. +// +//nolint:gochecknoglobals // build-time constant var ( AppVersion = notSet AppRelease = notSet diff --git a/internal/worker/app.go b/internal/worker/app.go index 36b2ae18..c75842ff 100644 --- a/internal/worker/app.go +++ b/internal/worker/app.go @@ -37,11 +37,11 @@ func module() fx.Option { server.Module(), fx.Invoke(func(logger *zap.Logger, lc fx.Lifecycle) { lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { + OnStart: func(_ context.Context) error { logger.Info("worker started") return nil }, - OnStop: func(ctx context.Context) error { + OnStop: func(_ context.Context) error { logger.Info("worker stopped") return nil }, diff --git a/internal/worker/config/config.go b/internal/worker/config/config.go index 574d3f5f..83c2d292 100644 --- a/internal/worker/config/config.go +++ b/internal/worker/config/config.go @@ -32,6 +32,7 @@ type DevicesCleanup struct { } func Default() Config { + //nolint:exhaustruct,mnd // default values return Config{ Tasks: Tasks{ MessagesHashing: MessagesHashing{ diff --git a/internal/worker/config/types.go b/internal/worker/config/types.go index 92ccfe2f..c22952fc 100644 --- a/internal/worker/config/types.go +++ b/internal/worker/config/types.go @@ -13,7 +13,7 @@ type Duration time.Duration func (d *Duration) UnmarshalText(text []byte) error { t, err := time.ParseDuration(string(text)) if err != nil { - return fmt.Errorf("can't parse duration: %w", err) + return fmt.Errorf("failed to parse duration: %w", err) } *d = Duration(t) return nil @@ -22,12 +22,12 @@ func (d *Duration) UnmarshalText(text []byte) error { func (d *Duration) UnmarshalYAML(value *yaml.Node) error { var s string if err := value.Decode(&s); err != nil { - return fmt.Errorf("can't unmarshal duration: %w", err) + return fmt.Errorf("failed to unmarshal duration: %w", err) } t, err := time.ParseDuration(s) if err != nil { - return fmt.Errorf("can't parse duration: %w", err) + return fmt.Errorf("failed to parse duration: %w", err) } *d = Duration(t) return nil diff --git a/internal/worker/executor/metrics.go b/internal/worker/executor/metrics.go index 8d4b9e6e..52ddb298 100644 --- a/internal/worker/executor/metrics.go +++ b/internal/worker/executor/metrics.go @@ -26,16 +26,19 @@ func newMetrics() *metrics { Namespace: "worker", Subsystem: "executor", Name: "active_tasks", + Help: "Number of active tasks", }), taskResult: promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "worker", Subsystem: "executor", Name: "task_result_total", + Help: "Task result, labeled by task name and result", }, []string{"task", "result"}), taskDuration: promauto.NewHistogramVec(prometheus.HistogramOpts{ Namespace: "worker", Subsystem: "executor", Name: "task_duration_seconds", + Help: "Task duration in seconds", Buckets: []float64{.001, .005, .01, .025, .05, .1, .25, .5, 1, 2.5, 5, 10}, }, []string{"task"}), } diff --git a/internal/worker/executor/service.go b/internal/worker/executor/service.go index b1c9cc16..47cb24a7 100644 --- a/internal/worker/executor/service.go +++ b/internal/worker/executor/service.go @@ -54,7 +54,12 @@ func (s *Service) Start() error { s.wg.Add(1) go func(index int, task PeriodicTask) { defer s.wg.Done() - s.logger.Info("starting task", zap.Int("index", index), zap.String("name", task.Name()), zap.Duration("interval", task.Interval())) + s.logger.Info( + "starting task", + zap.Int("index", index), + zap.String("name", task.Name()), + zap.Duration("interval", task.Interval()), + ) s.runTask(ctx, task) s.logger.Info("task stopped", zap.Int("index", index), zap.String("name", task.Name())) }(index, task) @@ -64,6 +69,7 @@ func (s *Service) Start() error { } func (s *Service) runTask(ctx context.Context, task PeriodicTask) { + //nolint:gosec // weak random is acceptable for scheduling jitter initialDelay := time.Duration(math.Floor(rand.Float64()*task.Interval().Seconds())) * time.Second s.logger.Info("initial delay", zap.String("name", task.Name()), zap.Duration("delay", initialDelay)) @@ -94,12 +100,12 @@ func (s *Service) execute(ctx context.Context, task PeriodicTask) { logger := s.logger.With(zap.String("name", task.Name())) if err := s.locker.AcquireLock(ctx, task.Name()); err != nil { - logger.Error("can't acquire lock", zap.String("name", task.Name()), zap.Error(err)) + logger.Error("failed to acquire lock", zap.String("name", task.Name()), zap.Error(err)) return } defer func() { if err := s.locker.ReleaseLock(ctx, task.Name()); err != nil { - logger.Error("can't release lock", zap.String("name", task.Name()), zap.Error(err)) + logger.Error("failed to release lock", zap.String("name", task.Name()), zap.Error(err)) } }() diff --git a/internal/worker/locker/mysql.go b/internal/worker/locker/mysql.go index 56ccdebe..969c6b0c 100644 --- a/internal/worker/locker/mysql.go +++ b/internal/worker/locker/mysql.go @@ -26,6 +26,7 @@ func NewMySQLLocker(db *sql.DB, prefix string, timeout time.Duration) Locker { prefix: prefix, timeout: timeout, + mu: sync.Mutex{}, conns: make(map[string]*sql.Conn), } } @@ -41,9 +42,9 @@ func (m *mySQLLocker) AcquireLock(ctx context.Context, key string) error { } var res sql.NullInt64 - if err := conn.QueryRowContext(ctx, "SELECT GET_LOCK(?, ?)", name, m.timeout.Seconds()).Scan(&res); err != nil { + if lockErr := conn.QueryRowContext(ctx, "SELECT GET_LOCK(?, ?)", name, m.timeout.Seconds()).Scan(&res); lockErr != nil { _ = conn.Close() - return fmt.Errorf("failed to get lock: %w", err) + return fmt.Errorf("failed to get lock: %w", lockErr) } if !res.Valid || res.Int64 != 1 { _ = conn.Close() @@ -70,7 +71,7 @@ func (m *mySQLLocker) ReleaseLock(ctx context.Context, key string) error { delete(m.conns, key) m.mu.Unlock() if conn == nil { - return fmt.Errorf("no held connection for key %q", key) + return fmt.Errorf("%w: no held connection for key %q", ErrLockNotAcquired, key) } var result sql.NullInt64 @@ -81,7 +82,7 @@ func (m *mySQLLocker) ReleaseLock(ctx context.Context, key string) error { return fmt.Errorf("failed to release lock: %w", err) } if !result.Valid || result.Int64 != 1 { - return fmt.Errorf("lock was not held or doesn't exist") + return fmt.Errorf("%w: lock was not held or doesn't exist", ErrLockNotAcquired) } return nil From 601a2e24a0090703c4e1fcfce244431d4fe49a94 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 18 Nov 2025 11:44:45 +0700 Subject: [PATCH 10/13] [tests] fix formatting --- test/e2e/mobile_test.go | 4 +++- test/e2e/utils_test.go | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/test/e2e/mobile_test.go b/test/e2e/mobile_test.go index b408ad1c..d7ebbc6e 100644 --- a/test/e2e/mobile_test.go +++ b/test/e2e/mobile_test.go @@ -190,7 +190,9 @@ func TestPublicDeviceRegisterWithCredentials(t *testing.T) { { name: "Valid Credentials", headers: map[string]string{ - "Authorization": "Basic " + base64.StdEncoding.EncodeToString([]byte(firstDevice.Login+":"+firstDevice.Password)), + "Authorization": "Basic " + base64.StdEncoding.EncodeToString( + []byte(firstDevice.Login+":"+firstDevice.Password), + ), }, expectedStatusCode: 201, expectedLogin: "", diff --git a/test/e2e/utils_test.go b/test/e2e/utils_test.go index 31a6cfcc..6e24d198 100644 --- a/test/e2e/utils_test.go +++ b/test/e2e/utils_test.go @@ -18,7 +18,11 @@ func (o *mobileDeviceRegisterOptions) withCredentials(username, password string) return o } -func mobileDeviceRegister(t *testing.T, client *resty.Client, opts ...*mobileDeviceRegisterOptions) mobileRegisterResponse { +func mobileDeviceRegister( + t *testing.T, + client *resty.Client, + opts ...*mobileDeviceRegisterOptions, +) mobileRegisterResponse { req := client.R() for _, opt := range opts { if opt.username != "" && opt.password != "" { From 7804f8b74c94d9522acf46d8d90e6a8ccc7a4752 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Thu, 20 Nov 2025 19:26:08 +0700 Subject: [PATCH 11/13] [sse] fix active connections metric --- internal/sms-gateway/modules/sse/service.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/sms-gateway/modules/sse/service.go b/internal/sms-gateway/modules/sse/service.go index 071e3ba7..57aa07c5 100644 --- a/internal/sms-gateway/modules/sse/service.go +++ b/internal/sms-gateway/modules/sse/service.go @@ -226,6 +226,9 @@ func (s *Service) removeConnection(deviceID, connID string) { if conn.id == connID { close(conn.closeSignal) s.connections[deviceID] = append(connections[:i], connections[i+1:]...) + + // Decrement active connections metric + s.metrics.DecrementActiveConnections() s.logger.Info( "Removing SSE connection", zap.String("device_id", deviceID), @@ -235,9 +238,6 @@ func (s *Service) removeConnection(deviceID, connID string) { } } - // Decrement active connections metric - s.metrics.DecrementActiveConnections() - if len(s.connections[deviceID]) == 0 { delete(s.connections, deviceID) } From 4ff63d0309ea28ac3f6d15f23353ba31ccc0c77f Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Thu, 20 Nov 2025 19:26:23 +0700 Subject: [PATCH 12/13] [executor] optimize logging --- internal/worker/executor/service.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/worker/executor/service.go b/internal/worker/executor/service.go index 47cb24a7..81feba67 100644 --- a/internal/worker/executor/service.go +++ b/internal/worker/executor/service.go @@ -100,12 +100,12 @@ func (s *Service) execute(ctx context.Context, task PeriodicTask) { logger := s.logger.With(zap.String("name", task.Name())) if err := s.locker.AcquireLock(ctx, task.Name()); err != nil { - logger.Error("failed to acquire lock", zap.String("name", task.Name()), zap.Error(err)) + logger.Error("failed to acquire lock", zap.Error(err)) return } defer func() { if err := s.locker.ReleaseLock(ctx, task.Name()); err != nil { - logger.Error("failed to release lock", zap.String("name", task.Name()), zap.Error(err)) + logger.Error("failed to release lock", zap.Error(err)) } }() From 71fb64f26347f1d33609b6ae7947f555745f8588 Mon Sep 17 00:00:00 2001 From: Aleksandr Soloshenko Date: Tue, 25 Nov 2025 12:29:59 +0700 Subject: [PATCH 13/13] [e2e] add 3rdparty webhooks API tests --- test/e2e/webhooks_test.go | 625 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 625 insertions(+) create mode 100644 test/e2e/webhooks_test.go diff --git a/test/e2e/webhooks_test.go b/test/e2e/webhooks_test.go new file mode 100644 index 00000000..4e3f22a7 --- /dev/null +++ b/test/e2e/webhooks_test.go @@ -0,0 +1,625 @@ +package e2e + +import ( + "encoding/json" + "testing" + + "github.com/go-resty/resty/v2" +) + +type webhook struct { + ID string `json:"id"` + DeviceID string `json:"deviceId,omitempty"` + URL string `json:"url"` + Event string `json:"event"` +} + +func TestWebhooks_Get(t *testing.T) { + credentials := mobileDeviceRegister(t, publicMobileClient) + authorizedClient := publicUserClient.Clone().SetBasicAuth(credentials.Login, credentials.Password) + + cases := []struct { + name string + setup func() + expectedStatusCode int + request func() *resty.Request + validate func(t *testing.T, response *resty.Response) + }{ + { + name: "Successful retrieval of empty webhook list", + setup: func() { + // Start with empty webhook list + }, + expectedStatusCode: 200, + request: func() *resty.Request { + return authorizedClient.R() + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 200 { + t.Fatal(response.StatusCode(), response.String()) + } + + var result []webhook + if err := json.Unmarshal(response.Body(), &result); err != nil { + t.Fatal(err) + } + + // Verify response structure + if len(result) != 0 { + t.Errorf("expected empty webhook list, got %d webhooks", len(result)) + } + + // Verify response headers + if response.Header().Get("Content-Type") != "application/json" { + t.Error("expected Content-Type to be application/json") + } + }, + }, + { + name: "List webhooks after creation", + setup: func() { + // Create a webhook first + _, err := authorizedClient.R(). + SetBody(webhook{ + URL: "https://example.com/list-test", + Event: "sms:delivered", + }).Post("webhooks") + if err != nil { + t.Fatal(err) + } + }, + expectedStatusCode: 200, + request: func() *resty.Request { + return authorizedClient.R() + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 200 { + t.Fatal(response.StatusCode(), response.String()) + } + + var result []webhook + if err := json.Unmarshal(response.Body(), &result); err != nil { + t.Fatal(err) + } + + // Verify response structure + if len(result) == 0 { + t.Error("expected webhook list to contain created webhooks") + } + + // Verify webhook structure + for _, webhook := range result { + if webhook.ID == "" { + t.Error("webhook ID is empty") + } + if webhook.URL == "" { + t.Error("webhook URL is empty") + } + if webhook.Event == "" { + t.Error("webhook event is empty") + } + } + }, + }, + { + name: "Missing authentication", + setup: func() { + // No setup needed + }, + expectedStatusCode: 401, + request: func() *resty.Request { + return publicUserClient.R() + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 401 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Invalid credentials", + setup: func() { + // No setup needed + }, + expectedStatusCode: 401, + request: func() *resty.Request { + return publicUserClient.R().SetBasicAuth("invalid", "credentials") + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 401 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.setup() + + res, err := c.request().Get("webhooks") + if err != nil { + t.Fatal(err) + } + + if res.StatusCode() != c.expectedStatusCode { + t.Fatal(res.StatusCode(), res.String()) + } + + if c.validate != nil { + c.validate(t, res) + } + }) + } +} + +func TestWebhooks_Post(t *testing.T) { + credentials := mobileDeviceRegister(t, publicMobileClient) + authorizedClient := publicUserClient.Clone().SetBasicAuth(credentials.Login, credentials.Password) + + cases := []struct { + name string + setup func() + expectedStatusCode int + request func() *resty.Request + validate func(t *testing.T, response *resty.Response) + }{ + { + name: "Create webhook with valid data", + setup: func() { + // No setup needed + }, + expectedStatusCode: 201, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + URL: "https://example.com/webhook", + Event: "sms:received", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 201 { + t.Fatal(response.StatusCode(), response.String()) + } + + var result webhook + if err := json.Unmarshal(response.Body(), &result); err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + _, err := authorizedClient.R().Delete("webhooks/" + result.ID) + if err != nil { + t.Error(err) + } + }) + + // Verify response structure + if result.ID == "" { + t.Error("webhook ID is empty") + } + if result.URL != "https://example.com/webhook" { + t.Errorf("expected URL 'https://example.com/webhook', got '%s'", result.URL) + } + if result.Event != "sms:received" { + t.Errorf("expected event 'sms:received', got '%s'", result.Event) + } + }, + }, + { + name: "Create webhook with device_id", + setup: func() { + // No setup needed + }, + expectedStatusCode: 201, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + DeviceID: credentials.ID, + URL: "https://example.com/device-webhook", + Event: "sms:sent", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 201 { + t.Fatal(response.StatusCode(), response.String()) + } + + var result webhook + if err := json.Unmarshal(response.Body(), &result); err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + _, err := authorizedClient.R().Delete("webhooks/" + result.ID) + if err != nil { + t.Error(err) + } + }) + + // Verify response structure + if result.ID == "" { + t.Error("webhook ID is empty") + } + if result.DeviceID != credentials.ID { + t.Errorf("expected device_id '%s', got '%s'", credentials.ID, result.DeviceID) + } + if result.URL != "https://example.com/device-webhook" { + t.Errorf("expected URL 'https://example.com/device-webhook', got '%s'", result.URL) + } + if result.Event != "sms:sent" { + t.Errorf("expected event 'sms:sent', got '%s'", result.Event) + } + }, + }, + { + name: "Create webhook with different event types", + setup: func() { + // No setup needed + }, + expectedStatusCode: 201, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + URL: "https://example.com/data-webhook", + Event: "sms:data-received", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 201 { + t.Fatal(response.StatusCode(), response.String()) + } + + var result webhook + if err := json.Unmarshal(response.Body(), &result); err != nil { + t.Fatal(err) + } + + t.Cleanup(func() { + _, err := authorizedClient.R().Delete("webhooks/" + result.ID) + if err != nil { + t.Error(err) + } + }) + + // Verify response structure + if result.Event != "sms:data-received" { + t.Errorf("expected event 'sms:data-received', got '%s'", result.Event) + } + }, + }, + { + name: "Invalid URL format", + setup: func() { + // No setup needed + }, + expectedStatusCode: 400, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + URL: "invalid-url", + Event: "sms:received", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 400 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Missing required fields", + setup: func() { + // No setup needed + }, + expectedStatusCode: 400, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + // Missing URL and Event + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 400 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Invalid event type", + setup: func() { + // No setup needed + }, + expectedStatusCode: 400, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + URL: "https://example.com/webhook", + Event: "invalid:event", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 400 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Invalid device_id length", + setup: func() { + // No setup needed + }, + expectedStatusCode: 400, + request: func() *resty.Request { + return authorizedClient.R(). + SetBody(webhook{ + DeviceID: "invalid_length_device_id", + URL: "https://example.com/webhook", + Event: "sms:received", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 400 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Missing authentication", + setup: func() { + // No setup needed + }, + expectedStatusCode: 401, + request: func() *resty.Request { + return publicUserClient.R(). + SetBody(webhook{ + URL: "https://example.com/webhook", + Event: "sms:received", + }) + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 401 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Invalid credentials", + setup: func() { + // No setup needed + }, + expectedStatusCode: 401, + request: func() *resty.Request { + return publicUserClient.R(). + SetBody(webhook{ + URL: "https://example.com/webhook", + Event: "sms:received", + }).SetBasicAuth("invalid", "credentials") + }, + validate: func(t *testing.T, response *resty.Response) { + if response.StatusCode() != 401 { + t.Fatal(response.StatusCode(), response.String()) + } + + var err errorResponse + if err := json.Unmarshal(response.Body(), &err); err != nil { + t.Fatal(err) + } + + if err.Message == "" { + t.Error("expected error message in response") + } + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + c.setup() + + res, err := c.request().Post("webhooks") + if err != nil { + t.Fatal(err) + } + + if res.StatusCode() != c.expectedStatusCode { + t.Fatal(res.StatusCode(), res.String()) + } + + if c.validate != nil { + c.validate(t, res) + } + }) + } +} + +func TestWebhooks_Delete(t *testing.T) { + credentials := mobileDeviceRegister(t, publicMobileClient) + authorizedClient := publicUserClient.Clone().SetBasicAuth(credentials.Login, credentials.Password) + + cases := []struct { + name string + setup func() string + expectedStatusCode int + request func(id string) *resty.Request + validate func(t *testing.T, response *resty.Response) + }{ + { + name: "Remove webhook by ID", + setup: func() string { + // Create a webhook to delete (server will generate ID) + resp, err := authorizedClient.R(). + SetBody(webhook{ + URL: "https://example.com/delete-test", + Event: "sms:failed", + }).Post("webhooks") + if err != nil { + t.Fatal(err) + } + + var created webhook + if err := json.Unmarshal(resp.Body(), &created); err != nil { + t.Fatal(err) + } + + return created.ID + }, + expectedStatusCode: 204, + request: func(id string) *resty.Request { + return authorizedClient.R().SetPathParam("id", id) + }, + validate: func(t *testing.T, response *resty.Response) { + if len(response.Body()) != 0 { + t.Error("expected empty response body for 204 status") + } + }, + }, + { + name: "Remove non-existent webhook", + setup: func() string { + // No setup needed + return "" + }, + expectedStatusCode: 204, + request: func(id string) *resty.Request { + return authorizedClient.R().SetPathParam("id", "non-existent-id") + }, + validate: func(t *testing.T, response *resty.Response) { + if len(response.Body()) != 0 { + t.Error("expected empty response body for 204 status") + } + }, + }, + { + name: "Missing authentication", + setup: func() string { + // No setup needed + return "" + }, + expectedStatusCode: 401, + request: func(id string) *resty.Request { + return publicUserClient.R().SetPathParam("id", "test-id") + }, + validate: func(t *testing.T, response *resty.Response) { + var errResp errorResponse + if err := json.Unmarshal(response.Body(), &errResp); err != nil { + t.Fatal(err) + } + + if errResp.Message == "" { + t.Error("expected error message in response") + } + }, + }, + { + name: "Invalid credentials", + setup: func() string { + // No setup needed + return "" + }, + expectedStatusCode: 401, + request: func(id string) *resty.Request { + return publicUserClient.R().SetBasicAuth("invalid", "credentials").SetPathParam("id", "test-id") + }, + validate: func(t *testing.T, response *resty.Response) { + var errResp errorResponse + if err := json.Unmarshal(response.Body(), &errResp); err != nil { + t.Fatal(err) + } + + if errResp.Message == "" { + t.Error("expected error message in response") + } + }, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + webhookID := c.setup() + // Clean up the webhook if one was created + if webhookID != "" { + t.Cleanup(func() { + authorizedClient.R().Delete("webhooks/" + webhookID) + }) + } + + res, err := c.request(webhookID).Delete("webhooks/{id}") + if err != nil { + t.Fatal(err) + } + + if res.StatusCode() != c.expectedStatusCode { + t.Fatal(res.StatusCode(), res.String()) + } + + if c.validate != nil { + c.validate(t, res) + } + + }) + } +}