diff --git a/.golangci.yml b/.golangci.yml index d3c3aee3..08eb6726 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -252,12 +252,12 @@ linters: - ^os/exec.Cmd$ - ^reflect.StructField$ # public libs - - "^github.com/gofiber/.+Config$" - - "^gopkg.in/telebot.v4.LongPoller$" - - "^gopkg.in/telebot.v4.ReplyMarkup$" - - "^gopkg.in/telebot.v4.Settings$" + - ^firebase.google.com/go/v4/messaging.AndroidConfig$ + - ^firebase.google.com/go/v4/messaging.Message$ - ^github.com/aws/aws-sdk-go-v2/service/s3.+Input$ - ^github.com/aws/aws-sdk-go-v2/service/s3/types.ObjectIdentifier$ + - ^github.com/gofiber/.+Config$ + - ^github.com/golang-jwt/jwt/v5.+Claims$ - ^github.com/mitchellh/mapstructure.DecoderConfig$ - ^github.com/prometheus/client_golang/.+Opts$ - ^github.com/secsy/goftp.Config$ @@ -273,10 +273,11 @@ linters: - ^github.com/urfave/cli.v3.FlagBase$ - ^golang.org/x/tools/go/analysis.Analyzer$ - ^google.golang.org/protobuf/.+Options$ + - ^gopkg.in/telebot.v4.LongPoller$ + - ^gopkg.in/telebot.v4.ReplyMarkup$ + - ^gopkg.in/telebot.v4.Settings$ - ^gopkg.in/yaml.v3.Node$ - ^gorm.io/gorm/clause.+$ - - ^firebase.google.com/go/v4/messaging.Message$ - - ^firebase.google.com/go/v4/messaging.AndroidConfig$ # Allows empty structures in return statements. # Default: false allow-empty-returns: true diff --git a/api/mobile.http b/api/mobile.http index 9d8abf36..3f6d06ed 100644 --- a/api/mobile.http +++ b/api/mobile.http @@ -68,8 +68,8 @@ Authorization: Bearer {{mobileToken}} Content-Type: application/json { - "currentPassword": "wsmgz1akhoo24o", - "newPassword": "wsmgz1akhoo24o" + "currentPassword": "8f8ijpnuvemq7y", + "newPassword": "8f8ijpnuvemq7y" } ### diff --git a/api/requests.http b/api/requests.http index b614a5b0..0b87365d 100644 --- a/api/requests.http +++ b/api/requests.http @@ -1,6 +1,7 @@ @baseUrl={{$dotenv CLOUD__URL}} @credentials={{$dotenv CLOUD__CREDENTIALS}} @mobileToken={{$dotenv MOBILE__TOKEN}} +@jwtToken={{$dotenv JWT__TOKEN}} @phone={{$dotenv PHONE}} ### @@ -34,7 +35,8 @@ Authorization: Basic {{credentials}} ### POST {{baseUrl}}/3rdparty/v1/messages HTTP/1.1 Content-Type: application/json -Authorization: Basic {{credentials}} +# Authorization: Basic {{credentials}} +Authorization: Bearer {{jwtToken}} { "textMessage": { @@ -78,11 +80,13 @@ Authorization: Basic {{credentials}} ### GET {{baseUrl}}/3rdparty/v1/messages/Fc10ZyTRDVlqPjIm9Jbly HTTP/1.1 -Authorization: Basic {{credentials}} +# Authorization: Basic {{credentials}} +Authorization: Bearer {{jwtToken}} ### GET {{baseUrl}}/3rdparty/v1/messages HTTP/1.1 -Authorization: Basic {{credentials}} +# Authorization: Basic {{credentials}} +Authorization: Bearer {{jwtToken}} ### GET {{baseUrl}}/3rdparty/v1/messages?from=2025-01-01T00:00:00.000Z&to=2025-12-31T23:59:59Z&state=Pending&deviceId=fL2m4IirEvh9BvTf6TIB0&limit=50&offset=0 HTTP/1.1 @@ -101,7 +105,9 @@ Content-Type: application/json ### GET {{baseUrl}}/3rdparty/v1/devices HTTP/1.1 -Authorization: Basic {{credentials}} +# Authorization: Basic {{credentials}} +Authorization: Bearer {{jwtToken}} + ### DELETE {{baseUrl}}/3rdparty/v1/devices/gF0jEYiaG_x9sI1YFWa7a HTTP/1.1 @@ -191,6 +197,31 @@ Content-Type: application/json } } +### +POST {{baseUrl}}/3rdparty/v1/auth/token HTTP/1.1 +Authorization: Basic {{credentials}} +Content-Type: application/json + +{ + "ttl": 3600, + "scopes": [ + "messages:send", + "messages:read", + "devices:list", + "devices:write", + "webhooks:list", + "webhooks:write", + "settings:read", + "settings:write", + "logs:read" + ] +} + +### +DELETE {{baseUrl}}/3rdparty/v1/auth/token/w8pxz0a4Fwa4xgzyCvSeC HTTP/1.1 +Authorization: Basic {{credentials}} +Content-Type: application/json + ### GET http://localhost:3000/metrics HTTP/1.1 diff --git a/cmd/sms-gateway/main.go b/cmd/sms-gateway/main.go index ed896ee4..5ebf583d 100644 --- a/cmd/sms-gateway/main.go +++ b/cmd/sms-gateway/main.go @@ -14,6 +14,11 @@ const ( // @securitydefinitions.basic ApiAuth // @description User authentication +// @securitydefinitions.apikey JWTAuth +// @in header +// @name Authorization +// @description JWT authentication + // @securitydefinitions.apikey UserCode // @in header // @name Authorization diff --git a/configs/config.example.yml b/configs/config.example.yml index f50c77b9..15bdc788 100644 --- a/configs/config.example.yml +++ b/configs/config.example.yml @@ -38,15 +38,19 @@ cache: # cache config url: memory:// # cache url (memory:// or redis://) [CACHE__URL] pubsub: # pubsub config url: memory:// # pubsub url (memory:// or redis://) [PUBSUB__URL] +jwt: + secret: # jwt secret (leave empty to disable JWT functionality) [JWT__SECRET] + ttl: 24h # jwt ttl [JWT__TTL] + issuer: # jwt issuer [JWT__ISSUER] ## Worker Config ## tasks: # tasks config messages_hashing: - interval: 168h # task execution interval in hours [TASKS__MESSAGES_HASHING__INTERVAL] + interval: 168h # task execution interval [TASKS__MESSAGES_HASHING__INTERVAL] messages_cleanup: - interval: 24h # task execution interval in hours [TASKS__MESSAGES_CLEANUP__INTERVAL] - max_age: 720h # messages max age in hours [TASKS__MESSAGES_CLEANUP__MAX_AGE] + interval: 24h # task execution interval [TASKS__MESSAGES_CLEANUP__INTERVAL] + max_age: 720h # messages max age [TASKS__MESSAGES_CLEANUP__MAX_AGE] devices_cleanup: - interval: 24h # task execution interval in hours [TASKS__DEVICES_CLEANUP__INTERVAL] - max_age: 8760h # inactive devices max age in hours [TASKS__DEVICES_CLEANUP__MAX_AGE] + interval: 24h # task execution interval [TASKS__DEVICES_CLEANUP__INTERVAL] + max_age: 8760h # inactive devices max age [TASKS__DEVICES_CLEANUP__MAX_AGE] diff --git a/deployments/grafana/dashboards/jwt.json b/deployments/grafana/dashboards/jwt.json new file mode 100644 index 00000000..0e595256 --- /dev/null +++ b/deployments/grafana/dashboards/jwt.json @@ -0,0 +1,1192 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + }, + { + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": false, + "iconColor": "blue", + "name": "Last Updated", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "id": 0, + "links": [], + "panels": [ + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 8, + "x": 0, + "y": 0 + }, + "id": 1, + "options": { + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum(rate(sms_auth_jwt_tokens_issued_total{instance=~\"$instance\",job=~\"$job\"}[5m]))", + "legendFormat": "Tokens Issued", + "range": true, + "refId": "A" + } + ], + "title": "Total Tokens Issued", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "mappings": [], + "max": 100, + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "red", + "value": 0 + }, + { + "color": "green", + "value": 90 + } + ] + }, + "unit": "percent" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 8, + "x": 8, + "y": 0 + }, + "id": 2, + "options": { + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "100 * sum(rate(sms_auth_jwt_tokens_validated_total{status=\"success\",instance=~\"$instance\",job=~\"$job\"}[5m])) / clamp_min(sum(rate(sms_auth_jwt_tokens_validated_total{instance=~\"$instance\",job=~\"$job\"}[5m])), 1)", + "legendFormat": "Validation Success Rate", + "range": true, + "refId": "A" + } + ], + "title": "Validation Success Rate", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 8, + "x": 16, + "y": 0 + }, + "id": 3, + "options": { + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "text": {}, + "textMode": "auto" + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum(rate(sms_auth_jwt_tokens_revoked_total{instance=~\"$instance\",job=~\"$job\"}[5m]))", + "legendFormat": "Tokens Revoked", + "range": true, + "refId": "A" + } + ], + "title": "Revocation Rate", + "type": "stat" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "reqps" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 8 + }, + "id": 4, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum by (status) (rate(sms_auth_jwt_tokens_issued_total{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval]))", + "legendFormat": "{{status}}", + "range": true, + "refId": "A" + } + ], + "title": "Token Issuance Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "mappings": [], + "min": 0, + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "short" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 8 + }, + "id": 5, + "options": { + "displayMode": "gradient", + "legend": { + "calcs": [], + "displayMode": "list", + "placement": "bottom", + "showLegend": false + }, + "maxVizHeight": 300, + "minVizHeight": 75, + "minVizWidth": 75, + "namePlacement": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showThresholdLabels": false, + "showThresholdMarkers": true, + "showUnfilled": true, + "sizing": "auto", + "valueMode": "color" + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum by (status) (sms_auth_jwt_tokens_issued_total{status=\"error\",instance=~\"$instance\",job=~\"$job\"})", + "legendFormat": "{{status}}", + "range": true, + "refId": "A" + } + ], + "title": "Issuance Error Count", + "type": "bargauge" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 16 + }, + "id": 6, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(sms_auth_jwt_issuance_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p99", + "range": true, + "refId": "A" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(sms_auth_jwt_issuance_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p95", + "range": true, + "refId": "B" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(sms_auth_jwt_issuance_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p50", + "range": true, + "refId": "C" + } + ], + "title": "Token Issuance Latency Distribution", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "percentunit" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 24 + }, + "id": 7, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum(rate(sms_auth_jwt_tokens_validated_total{status=\"success\",instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) / clamp_min(sum(rate(sms_auth_jwt_tokens_validated_total{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])), 1)", + "legendFormat": "Success Rate", + "range": true, + "refId": "A" + } + ], + "title": "Validation Success Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + } + }, + "mappings": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 24 + }, + "id": 8, + "options": { + "legend": { + "displayMode": "list", + "placement": "bottom", + "showLegend": true + }, + "pieType": "pie", + "reduceOptions": { + "calcs": [ + "last" + ], + "fields": "", + "values": false + }, + "tooltip": { + "hideZeros": false, + "mode": "single", + "sort": "none" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "exemplar": false, + "expr": "sum by (status) (sms_auth_jwt_tokens_validated_total{instance=~\"$instance\",job=~\"$job\"})", + "format": "heatmap", + "instant": true, + "legendFormat": "{{status}}", + "range": false, + "refId": "A" + } + ], + "title": "Validation Status Classification", + "type": "piechart" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 32 + }, + "id": 9, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(sms_auth_jwt_validation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p99", + "range": true, + "refId": "A" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(sms_auth_jwt_validation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p95", + "range": true, + "refId": "B" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(sms_auth_jwt_validation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p50", + "range": true, + "refId": "C" + } + ], + "title": "Token Validation Latency", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "reqps" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 40 + }, + "id": 10, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum by (status) (rate(sms_auth_jwt_tokens_revoked_total{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval]))", + "legendFormat": "{{status}}", + "range": true, + "refId": "A" + } + ], + "title": "Token Revocation Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "percentunit" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 40 + }, + "id": 11, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "sum(rate(sms_auth_jwt_tokens_revoked_total{status=\"error\",instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) / clamp_min(sum(rate(sms_auth_jwt_tokens_revoked_total{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])), 1)", + "legendFormat": "Error Rate", + "range": true, + "refId": "A" + } + ], + "title": "Revocation Error Rate", + "type": "timeseries" + }, + { + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "fieldConfig": { + "defaults": { + "color": { + "mode": "palette-classic" + }, + "custom": { + "axisBorderShow": false, + "axisCenteredZero": false, + "axisColorMode": "text", + "axisLabel": "", + "axisPlacement": "auto", + "barAlignment": 0, + "barWidthFactor": 0.6, + "drawStyle": "line", + "fillOpacity": 0, + "gradientMode": "none", + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "insertNulls": false, + "lineInterpolation": "linear", + "lineWidth": 1, + "pointSize": 5, + "scaleDistribution": { + "type": "linear" + }, + "showPoints": "auto", + "spanNulls": false, + "stacking": { + "group": "A", + "mode": "none" + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": 0 + }, + { + "color": "red", + "value": 80 + } + ] + }, + "unit": "s" + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 24, + "x": 0, + "y": 48 + }, + "id": 12, + "options": { + "legend": { + "calcs": [ + "mean", + "max" + ], + "displayMode": "table", + "placement": "bottom", + "showLegend": true + }, + "tooltip": { + "hideZeros": false, + "mode": "multi", + "sort": "desc" + } + }, + "pluginVersion": "12.2.0", + "targets": [ + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.99, sum(rate(sms_auth_jwt_revocation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p99", + "range": true, + "refId": "A" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.95, sum(rate(sms_auth_jwt_revocation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p95", + "range": true, + "refId": "B" + }, + { + "datasource": { + "uid": "Prometheus" + }, + "editorMode": "code", + "expr": "histogram_quantile(0.50, sum(rate(sms_auth_jwt_revocation_duration_seconds_bucket{instance=~\"$instance\",job=~\"$job\"}[$__rate_interval])) by (le))", + "legendFormat": "p50", + "range": true, + "refId": "C" + } + ], + "title": "Token Revocation Latency", + "type": "timeseries" + } + ], + "preload": false, + "refresh": "auto", + "schemaVersion": 42, + "tags": [ + "auth", + "jwt" + ], + "templating": { + "list": [ + { + "current": { + "text": "All", + "value": [ + "$__all" + ] + }, + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "definition": "label_values(sms_auth_jwt_tokens_issued_total, instance)", + "includeAll": true, + "multi": true, + "name": "instance", + "options": [], + "query": { + "query": "label_values(sms_auth_jwt_tokens_issued_total, instance)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "type": "query" + }, + { + "current": { + "text": "All", + "value": [ + "$__all" + ] + }, + "datasource": { + "type": "prometheus", + "uid": "edqp0a73uh2bka" + }, + "definition": "label_values(sms_auth_jwt_tokens_issued_total, job)", + "includeAll": true, + "multi": true, + "name": "job", + "options": [], + "query": { + "query": "label_values(sms_auth_jwt_tokens_issued_total, job)", + "refId": "StandardVariableQuery" + }, + "refresh": 1, + "type": "query" + } + ] + }, + "time": { + "from": "now-24h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "JWT Authentication", + "uid": "jwt", + "version": 1 +} \ No newline at end of file diff --git a/deployments/prometheus/alerts/jwt-alerts.yml b/deployments/prometheus/alerts/jwt-alerts.yml new file mode 100644 index 00000000..13b91e48 --- /dev/null +++ b/deployments/prometheus/alerts/jwt-alerts.yml @@ -0,0 +1,42 @@ +groups: + - name: jwt-alerts + rules: + - alert: JWT_Validation_ErrorRate_High + expr: rate(sms_auth_jwt_tokens_validated_total{status="error"}[5m]) / clamp_min(rate(sms_auth_jwt_tokens_validated_total[5m]), 1e-9) > 0.1 + for: 5m + labels: + severity: warning + service: jwt + annotations: + summary: "High JWT validation error rate" + description: "JWT validation error rate is {{ $value | humanizePercentage }} (threshold: 10%)" + + - alert: JWT_Issuance_Latency_High + expr: histogram_quantile(0.99, rate(sms_auth_jwt_issuance_duration_seconds_bucket[5m])) > 0.5 + for: 10m + labels: + severity: warning + service: jwt + annotations: + summary: "High JWT issuance latency (p99)" + description: "JWT issuance p99 latency is {{ $value }}s (threshold: 0.5s)" + + - alert: JWT_Revocation_Failures + expr: rate(sms_auth_jwt_tokens_revoked_total{status="error"}[5m]) > 0 + for: 5m + labels: + severity: critical + service: jwt + annotations: + summary: "JWT revocation failures detected" + description: "JWT revocation errors occurring at rate {{ $value }}/s" + + - alert: JWT_Validation_Failures_High + expr: rate(sms_auth_jwt_tokens_validated_total{status="error"}[5m]) / max(rate(sms_auth_jwt_tokens_validated_total[5m]), 0.00001) > 0.1 + for: 5m + labels: + severity: warning + service: jwt + annotations: + summary: "High JWT validation failure rate" + description: "JWT validation failure rate is {{ $value | humanizePercentage }} (threshold: 10%)" diff --git a/go.mod b/go.mod index b7ee483b..0cbc35fc 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.3 require ( firebase.google.com/go/v4 v4.12.1 - github.com/android-sms-gateway/client-go v1.9.5 + github.com/android-sms-gateway/client-go v1.9.6-0.20251123133512-f7816d96f90a github.com/ansrivas/fiberprometheus/v2 v2.6.1 github.com/capcom6/go-helpers v0.3.0 github.com/capcom6/go-infra-fx v0.5.2 @@ -15,6 +15,7 @@ require ( github.com/go-sql-driver/mysql v1.7.1 github.com/gofiber/fiber/v2 v2.52.9 github.com/gofiber/swagger v1.1.1 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 github.com/jaevor/go-nanoid v1.3.0 github.com/nyaruka/phonenumbers v1.4.0 diff --git a/go.sum b/go.sum index 0552a6f7..682f2bdc 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/android-sms-gateway/client-go v1.9.5 h1:fHrE1Pi3rKUdPVMmI9evKW0iyjB5bMIhFRxyq1wVQ+o= github.com/android-sms-gateway/client-go v1.9.5/go.mod h1:DQsReciU1xcaVW3T5Z2bqslNdsAwCFCtghawmA6g6L4= +github.com/android-sms-gateway/client-go v1.9.6-0.20251123133512-f7816d96f90a h1:Tm1FDTqFRs1ZftaEmQqDdIXtMRZf2aGCp8t2BgXY/rs= +github.com/android-sms-gateway/client-go v1.9.6-0.20251123133512-f7816d96f90a/go.mod h1:DQsReciU1xcaVW3T5Z2bqslNdsAwCFCtghawmA6g6L4= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/ansrivas/fiberprometheus/v2 v2.6.1 h1:wac3pXaE6BYYTF04AC6K0ktk6vCD+MnDOJZ3SK66kXM= @@ -128,6 +130,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69 github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= diff --git a/internal/config/config.go b/internal/config/config.go index 656daa2a..df60cbf9 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,5 +1,7 @@ package config +import "time" + type GatewayMode string const ( @@ -17,6 +19,7 @@ type Config struct { Messages Messages `yaml:"messages"` // messages config Cache Cache `yaml:"cache"` // cache (memory or redis) config PubSub PubSub `yaml:"pubsub"` // pubsub (memory or redis) config + JWT JWT `yaml:"jwt"` // jwt config } type Gateway struct { @@ -86,6 +89,12 @@ type PubSub struct { BufferSize uint `yaml:"buffer_size" envconfig:"PUBSUB__BUFFER_SIZE"` } +type JWT struct { + Secret string `yaml:"secret" envconfig:"JWT__SECRET"` + TTL Duration `yaml:"ttl" envconfig:"JWT__TTL"` + Issuer string `yaml:"issuer" envconfig:"JWT__ISSUER"` +} + func Default() Config { //nolint:exhaustruct,mnd // default values return Config{ @@ -123,5 +132,9 @@ func Default() Config { URL: "memory://", BufferSize: 128, }, + JWT: JWT{ + TTL: Duration(time.Hour * 24), + Issuer: "sms-gate.app", + }, } } diff --git a/internal/config/module.go b/internal/config/module.go index 4d17c9bb..0418a974 100644 --- a/internal/config/module.go +++ b/internal/config/module.go @@ -6,6 +6,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/cache" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers" + "github.com/android-sms-gateway/server/internal/sms-gateway/jwt" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/messages" @@ -136,5 +137,12 @@ func Module() fx.Option { BufferSize: cfg.PubSub.BufferSize, } }), + fx.Provide(func(cfg Config) jwt.Config { + return jwt.Config{ + Secret: cfg.JWT.Secret, + TTL: time.Duration(cfg.JWT.TTL), + Issuer: cfg.JWT.Issuer, + } + }), ) } diff --git a/internal/config/types.go b/internal/config/types.go new file mode 100644 index 00000000..82e12439 --- /dev/null +++ b/internal/config/types.go @@ -0,0 +1,48 @@ +package config + +import ( + "encoding" + "fmt" + "time" + + "gopkg.in/yaml.v3" +) + +type Duration time.Duration + +// Duration returns the underlying time.Duration value. +func (d *Duration) Duration() time.Duration { + if d == nil { + return 0 + } + return time.Duration(*d) +} + +// String returns the string representation of the duration. +func (d *Duration) String() string { + if d == nil { + return "" + } + return time.Duration(*d).String() +} + +func (d *Duration) UnmarshalText(text []byte) error { + t, err := time.ParseDuration(string(text)) + if err != nil { + return fmt.Errorf("can't parse duration: %w", err) + } + *d = Duration(t) + return nil +} + +func (d *Duration) UnmarshalYAML(value *yaml.Node) error { + var s string + if err := value.Decode(&s); err != nil { + return fmt.Errorf("can't unmarshal duration: %w", err) + } + + return d.UnmarshalText([]byte(s)) +} + +var _ yaml.Unmarshaler = (*Duration)(nil) +var _ encoding.TextUnmarshaler = (*Duration)(nil) diff --git a/internal/sms-gateway/app.go b/internal/sms-gateway/app.go index dc649939..f34de636 100644 --- a/internal/sms-gateway/app.go +++ b/internal/sms-gateway/app.go @@ -7,6 +7,7 @@ import ( appconfig "github.com/android-sms-gateway/server/internal/config" "github.com/android-sms-gateway/server/internal/sms-gateway/cache" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers" + "github.com/android-sms-gateway/server/internal/sms-gateway/jwt" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" appdb "github.com/android-sms-gateway/server/internal/sms-gateway/modules/db" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" @@ -20,6 +21,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/online" "github.com/android-sms-gateway/server/internal/sms-gateway/openapi" "github.com/android-sms-gateway/server/internal/sms-gateway/pubsub" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/android-sms-gateway/server/pkg/health" "github.com/capcom6/go-infra-fx/cli" "github.com/capcom6/go-infra-fx/db" @@ -40,6 +42,7 @@ func Module() fx.Option { validator.Module, openapi.Module(), handlers.Module(), + users.Module(), auth.Module(), push.Module(), db.Module, @@ -54,6 +57,7 @@ func Module() fx.Option { metrics.Module(), sse.Module(), online.Module(), + jwt.Module(), ) } diff --git a/internal/sms-gateway/handlers/3rdparty.go b/internal/sms-gateway/handlers/3rdparty.go index 6acc1dd2..36d0ba74 100644 --- a/internal/sms-gateway/handlers/3rdparty.go +++ b/internal/sms-gateway/handlers/3rdparty.go @@ -5,43 +5,65 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/logs" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/messages" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/jwtauth" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/settings" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/thirdparty" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/webhooks" - "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" + "github.com/android-sms-gateway/server/internal/sms-gateway/jwt" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" - "go.uber.org/fx" "go.uber.org/zap" ) -type ThirdPartyHandlerParams struct { - fx.In - - HealthHandler *HealthHandler - MessagesHandler *messages.ThirdPartyController - WebhooksHandler *webhooks.ThirdPartyController - DevicesHandler *devices.ThirdPartyController - SettingsHandler *settings.ThirdPartyController - LogsHandler *logs.ThirdPartyController - - AuthSvc *auth.Service - - Logger *zap.Logger - Validator *validator.Validate -} - type thirdPartyHandler struct { base.Handler + usersSvc *users.Service + jwtSvc jwt.Service + healthHandler *HealthHandler messagesHandler *messages.ThirdPartyController webhooksHandler *webhooks.ThirdPartyController devicesHandler *devices.ThirdPartyController settingsHandler *settings.ThirdPartyController logsHandler *logs.ThirdPartyController + authHandler *thirdparty.AuthHandler +} - authSvc *auth.Service +func newThirdPartyHandler( + usersSvc *users.Service, + jwtService jwt.Service, + + healthHandler *HealthHandler, + messagesHandler *messages.ThirdPartyController, + webhooksHandler *webhooks.ThirdPartyController, + devicesHandler *devices.ThirdPartyController, + settingsHandler *settings.ThirdPartyController, + logsHandler *logs.ThirdPartyController, + authHandler *thirdparty.AuthHandler, + + logger *zap.Logger, + validator *validator.Validate, +) *thirdPartyHandler { + return &thirdPartyHandler{ + Handler: base.Handler{ + Logger: logger, + Validator: validator, + }, + + usersSvc: usersSvc, + jwtSvc: jwtService, + + healthHandler: healthHandler, + messagesHandler: messagesHandler, + webhooksHandler: webhooksHandler, + devicesHandler: devicesHandler, + settingsHandler: settingsHandler, + logsHandler: logsHandler, + authHandler: authHandler, + } } func (h *thirdPartyHandler) Register(router fiber.Router) { @@ -50,10 +72,13 @@ func (h *thirdPartyHandler) Register(router fiber.Router) { h.healthHandler.Register(router) router.Use( - userauth.NewBasic(h.authSvc), + userauth.NewBasic(h.usersSvc), + jwtauth.NewJWT(h.jwtSvc, h.usersSvc), userauth.UserRequired(), ) + h.authHandler.Register(router.Group("/auth")) + h.messagesHandler.Register(router.Group("/message")) // TODO: remove after 2025-12-31 h.messagesHandler.Register(router.Group("/messages")) @@ -66,16 +91,3 @@ func (h *thirdPartyHandler) Register(router fiber.Router) { h.logsHandler.Register(router.Group("/logs")) } - -func newThirdPartyHandler(params ThirdPartyHandlerParams) *thirdPartyHandler { - return &thirdPartyHandler{ - Handler: base.Handler{Logger: params.Logger.Named("ThirdPartyHandler"), Validator: params.Validator}, - healthHandler: params.HealthHandler, - messagesHandler: params.MessagesHandler, - webhooksHandler: params.WebhooksHandler, - devicesHandler: params.DevicesHandler, - settingsHandler: params.SettingsHandler, - logsHandler: params.LogsHandler, - authSvc: params.AuthSvc, - } -} diff --git a/internal/sms-gateway/handlers/base/handler.go b/internal/sms-gateway/handlers/base/handler.go index 6b574611..198b743d 100644 --- a/internal/sms-gateway/handlers/base/handler.go +++ b/internal/sms-gateway/handlers/base/handler.go @@ -19,7 +19,7 @@ type Handler struct { func (h *Handler) BodyParserValidator(c *fiber.Ctx, out any) error { if err := c.BodyParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse body: %s", err.Error())) + return fmt.Errorf("failed to parse body: %w", err) } return h.ValidateStruct(out) @@ -27,7 +27,7 @@ func (h *Handler) BodyParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { if err := c.QueryParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse query: %s", err.Error())) + return fmt.Errorf("failed to parse query: %w", err) } return h.ValidateStruct(out) @@ -35,7 +35,7 @@ func (h *Handler) QueryParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { if err := c.ParamsParser(out); err != nil { - return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("failed to parse params: %s", err.Error())) + return fmt.Errorf("failed to parse params: %w", err) } return h.ValidateStruct(out) @@ -44,13 +44,13 @@ func (h *Handler) ParamsParserValidator(c *fiber.Ctx, out any) error { func (h *Handler) ValidateStruct(out any) error { if h.Validator != nil { if err := h.Validator.Var(out, "required,dive"); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return fmt.Errorf("failed to validate: %w", err) } } if req, ok := out.(Validatable); ok { if err := req.Validate(); err != nil { - return fiber.NewError(fiber.StatusBadRequest, err.Error()) + return fmt.Errorf("failed to validate: %w", err) } } diff --git a/internal/sms-gateway/handlers/base/handler_test.go b/internal/sms-gateway/handlers/base/handler_test.go index 6b1ebc98..a257f46b 100644 --- a/internal/sms-gateway/handlers/base/handler_test.go +++ b/internal/sms-gateway/handlers/base/handler_test.go @@ -90,13 +90,13 @@ func TestHandler_BodyParserValidator(t *testing.T) { description: "Invalid request body - missing name", path: "/test", payload: &testRequestBody{Age: 25}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Invalid request body - age too low", path: "/test", payload: &testRequestBody{Name: "John Doe", Age: 17}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Valid request body - no validation", @@ -108,7 +108,7 @@ func TestHandler_BodyParserValidator(t *testing.T) { description: "No request body", path: "/test", payload: nil, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusUnprocessableEntity, }, } @@ -157,7 +157,7 @@ func TestHandler_QueryParserValidator(t *testing.T) { { description: "Invalid query parameters - non-integer age", path: "/test?name=John&age=abc", - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Valid query parameters", @@ -167,17 +167,17 @@ func TestHandler_QueryParserValidator(t *testing.T) { { description: "Invalid query parameters - missing name", path: "/test?age=25", - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Invalid query parameters - age too low", path: "/test?name=John&age=17", - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Invalid query parameters - missing age", path: "/test?name=John", - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, } @@ -234,7 +234,7 @@ func TestHandler_ParamsParserValidator(t *testing.T) { { description: "Invalid path parameters - invalid ID", path: "/test/invalid/John", - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, } @@ -285,13 +285,13 @@ func TestHandler_ValidateStruct(t *testing.T) { description: "Invalid struct with validator - missing required field", handler: handlerWithValidator, input: &testRequestBody{Age: 25}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Invalid struct with validator - custom validation fails", handler: handlerWithValidator, input: &testRequestBody{Name: "John Doe", Age: 17}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Valid struct without validator", @@ -303,7 +303,7 @@ func TestHandler_ValidateStruct(t *testing.T) { description: "Invalid struct without validator - custom validation fails", handler: handlerWithoutValidator, input: &testRequestBody{Name: "John Doe", Age: 17}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, { description: "Valid struct with Validatable interface", @@ -315,7 +315,7 @@ func TestHandler_ValidateStruct(t *testing.T) { description: "Invalid struct with Validatable interface", handler: handlerWithValidator, input: &testRequestQuery{Name: "John", Age: 17}, - expectedStatus: fiber.StatusBadRequest, + expectedStatus: fiber.StatusInternalServerError, }, } @@ -327,7 +327,7 @@ func TestHandler_ValidateStruct(t *testing.T) { t.Errorf("Expected no error, got %v", err) } - if test.expectedStatus == fiber.StatusBadRequest && err == nil { + if test.expectedStatus == fiber.StatusInternalServerError && err == nil { t.Errorf("Expected error, got nil") } }) diff --git a/internal/sms-gateway/handlers/devices/3rdparty.go b/internal/sms-gateway/handlers/devices/3rdparty.go index e44491e7..67bb2cf0 100644 --- a/internal/sms-gateway/handlers/devices/3rdparty.go +++ b/internal/sms-gateway/handlers/devices/3rdparty.go @@ -6,9 +6,10 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/converters" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/capcom6/go-helpers/slices" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" @@ -38,6 +39,7 @@ func NewThirdPartyController( // @Summary List devices // @Description Returns list of registered devices // @Security ApiAuth +// @Security JWTAuth // @Tags User, Devices // @Produce json // @Success 200 {object} []smsgateway.Device "Device list" @@ -47,7 +49,7 @@ func NewThirdPartyController( // @Router /3rdparty/v1/devices [get] // // List devices. -func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) get(user users.User, c *fiber.Ctx) error { devices, err := h.devicesSvc.Select(user.ID) if err != nil { return fmt.Errorf("failed to select devices: %w", err) @@ -61,6 +63,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Summary Remove device // @Description Removes device // @Security ApiAuth +// @Security JWTAuth // @Tags User, Devices // @Produce json // @Param id path string true "Device ID" @@ -72,7 +75,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Router /3rdparty/v1/devices/{id} [delete] // // Remove device. -func (h *ThirdPartyController) remove(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) remove(user users.User, c *fiber.Ctx) error { id := c.Params("id") err := h.devicesSvc.Remove(user.ID, devices.WithID(id)) @@ -87,6 +90,6 @@ func (h *ThirdPartyController) remove(user models.User, c *fiber.Ctx) error { } func (h *ThirdPartyController) Register(router fiber.Router) { - router.Get("", userauth.WithUser(h.get)) - router.Delete(":id", userauth.WithUser(h.remove)) + router.Get("", permissions.RequireScope(ScopeList), userauth.WithUser(h.get)) + router.Delete(":id", permissions.RequireScope(ScopeDelete), userauth.WithUser(h.remove)) } diff --git a/internal/sms-gateway/handlers/devices/permissions.go b/internal/sms-gateway/handlers/devices/permissions.go new file mode 100644 index 00000000..7f4eaf9d --- /dev/null +++ b/internal/sms-gateway/handlers/devices/permissions.go @@ -0,0 +1,6 @@ +package devices + +const ( + ScopeList = "devices:list" + ScopeDelete = "devices:delete" +) diff --git a/internal/sms-gateway/handlers/logs/3rdparty.go b/internal/sms-gateway/handlers/logs/3rdparty.go index 16ec88fd..f844efda 100644 --- a/internal/sms-gateway/handlers/logs/3rdparty.go +++ b/internal/sms-gateway/handlers/logs/3rdparty.go @@ -2,8 +2,9 @@ package logs import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" "go.uber.org/fx" @@ -33,6 +34,7 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Summary Get logs // @Description Retrieve a list of log entries within a specified time range. // @Security ApiAuth +// @Security JWTAuth // @Tags System, Logs // @Produce json // @Param from query string false "The start of the time range for the logs to retrieve. Logs created after this timestamp will be included." Format(date-time) @@ -44,7 +46,7 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Router /3rdparty/v1/logs [get] // // Get logs. -func (h *ThirdPartyController) get(_ models.User, _ *fiber.Ctx) error { +func (h *ThirdPartyController) get(_ users.User, _ *fiber.Ctx) error { return fiber.NewError( fiber.StatusNotImplemented, "For privacy reasons, device's logs are not accessible through Cloud server", @@ -52,5 +54,5 @@ func (h *ThirdPartyController) get(_ models.User, _ *fiber.Ctx) error { } func (h *ThirdPartyController) Register(router fiber.Router) { - router.Get("", userauth.WithUser(h.get)) + router.Get("", permissions.RequireScope(ScopeRead), userauth.WithUser(h.get)) } diff --git a/internal/sms-gateway/handlers/logs/permissions.go b/internal/sms-gateway/handlers/logs/permissions.go new file mode 100644 index 00000000..3b8781e7 --- /dev/null +++ b/internal/sms-gateway/handlers/logs/permissions.go @@ -0,0 +1,5 @@ +package logs + +const ( + ScopeRead = "logs:read" +) diff --git a/internal/sms-gateway/handlers/messages/3rdparty.go b/internal/sms-gateway/handlers/messages/3rdparty.go index 0f609b4d..658a715c 100644 --- a/internal/sms-gateway/handlers/messages/3rdparty.go +++ b/internal/sms-gateway/handlers/messages/3rdparty.go @@ -9,10 +9,11 @@ import ( "github.com/android-sms-gateway/client-go/smsgateway" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/converters" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/messages" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/capcom6/go-helpers/slices" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" @@ -55,6 +56,7 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Summary Enqueue message // @Description Enqueues a message for sending. If `deviceId` is set, the specified device is used; otherwise a random registered device is chosen. // @Security ApiAuth +// @Security JWTAuth // @Tags User, Messages // @Accept json // @Produce json @@ -64,13 +66,14 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Success 202 {object} smsgateway.GetMessageResponse "Message enqueued" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 409 {object} smsgateway.ErrorResponse "Message with such ID already exists" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Header 202 {string} Location "Get message state URL" // @Router /3rdparty/v1/messages [post] // // Enqueue message. -func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) post(user users.User, c *fiber.Ctx) error { var params thirdPartyPostQueryParams if err := h.QueryParserValidator(c, ¶ms); err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) @@ -172,6 +175,7 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { // @Summary Get messages // @Description Retrieves a list of messages with filtering and pagination // @Security ApiAuth +// @Security JWTAuth // @Tags User, Messages // @Produce json // @Param from query string false "Start date in RFC3339 format" Format(date-time) @@ -183,11 +187,12 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { // @Success 200 {object} smsgateway.GetMessagesResponse "A list of messages" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages [get] // // Get message history. -func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) list(user users.User, c *fiber.Ctx) error { params := new(thirdPartyGetQueryParams) if err := h.QueryParserValidator(c, params); err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) @@ -208,17 +213,19 @@ func (h *ThirdPartyController) list(user models.User, c *fiber.Ctx) error { // @Summary Get message state // @Description Returns message state by ID // @Security ApiAuth +// @Security JWTAuth // @Tags User, Messages // @Produce json // @Param id path string true "Message ID" // @Success 200 {object} smsgateway.GetMessageResponse "Message state" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages/{id} [get] // // Get message state. -func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) get(user users.User, c *fiber.Ctx) error { id := c.Params("id") state, err := h.messagesSvc.GetState(user, id) @@ -237,6 +244,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Summary Request inbox messages export // @Description Initiates process of inbox messages export via webhooks. For each message the `sms:received` webhook will be triggered. The webhooks will be triggered without specific order. // @Security ApiAuth +// @Security JWTAuth // @Tags User, Messages // @Accept json // @Produce json @@ -244,11 +252,12 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Success 202 {object} object "Inbox export request accepted" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/messages/inbox/export [post] // // Export inbox. -func (h *ThirdPartyController) postInboxExport(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) postInboxExport(user users.User, c *fiber.Ctx) error { req := new(smsgateway.MessagesExportRequest) if err := h.BodyParserValidator(c, req); err != nil { return fiber.NewError(fiber.StatusBadRequest, err.Error()) @@ -315,9 +324,9 @@ func (h *ThirdPartyController) errorHandler(c *fiber.Ctx) error { func (h *ThirdPartyController) Register(router fiber.Router) { router.Use(h.errorHandler) - router.Get("", userauth.WithUser(h.list)) - router.Post("", userauth.WithUser(h.post)) - router.Get(":id", userauth.WithUser(h.get)).Name(route3rdPartyGetMessage) + router.Get("", permissions.RequireScope(ScopeList), userauth.WithUser(h.list)) + router.Post("", permissions.RequireScope(ScopeSend), userauth.WithUser(h.post)) + router.Get(":id", permissions.RequireScope(ScopeRead), userauth.WithUser(h.get)).Name(route3rdPartyGetMessage) - router.Post("inbox/export", userauth.WithUser(h.postInboxExport)) + router.Post("inbox/export", permissions.RequireScope(ScopeExport), userauth.WithUser(h.postInboxExport)) } diff --git a/internal/sms-gateway/handlers/messages/permissions.go b/internal/sms-gateway/handlers/messages/permissions.go new file mode 100644 index 00000000..0ee4a714 --- /dev/null +++ b/internal/sms-gateway/handlers/messages/permissions.go @@ -0,0 +1,13 @@ +// Package messages defines permission scopes for message-related operations. +package messages + +const ( + // ScopeSend is the permission scope required for sending messages. + ScopeSend = "messages:send" + // ScopeRead is the permission scope required for reading individual messages. + ScopeRead = "messages:read" + // ScopeList is the permission scope required for listing messages. + ScopeList = "messages:list" + // ScopeExport is the permission scope required for exporting messages. + ScopeExport = "messages:export" +) diff --git a/internal/sms-gateway/handlers/middlewares/jwtauth/jwtauth.go b/internal/sms-gateway/handlers/middlewares/jwtauth/jwtauth.go new file mode 100644 index 00000000..dabf0327 --- /dev/null +++ b/internal/sms-gateway/handlers/middlewares/jwtauth/jwtauth.go @@ -0,0 +1,42 @@ +package jwtauth + +import ( + "errors" + "strings" + + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" + "github.com/android-sms-gateway/server/internal/sms-gateway/jwt" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" + "github.com/gofiber/fiber/v2" +) + +func NewJWT(jwtSvc jwt.Service, usersSvc *users.Service) fiber.Handler { + return func(c *fiber.Ctx) error { + token := c.Get("Authorization") + + if len(token) <= 7 || !strings.EqualFold(token[:7], "Bearer ") { + return c.Next() + } + + token = token[7:] + + claims, err := jwtSvc.ParseToken(c.Context(), token) + if err != nil { + return fiber.ErrUnauthorized + } + + user, err := usersSvc.GetByID(claims.UserID) + if err != nil { + if !errors.Is(err, users.ErrNotFound) { + return fiber.ErrInternalServerError + } + return fiber.ErrUnauthorized + } + + userauth.SetUser(c, *user) + permissions.SetScopes(c, claims.Scopes) + + return c.Next() + } +} diff --git a/internal/sms-gateway/handlers/middlewares/permissions/permissions.go b/internal/sms-gateway/handlers/middlewares/permissions/permissions.go new file mode 100644 index 00000000..f7adbe92 --- /dev/null +++ b/internal/sms-gateway/handlers/middlewares/permissions/permissions.go @@ -0,0 +1,36 @@ +package permissions + +import ( + "slices" + + "github.com/gofiber/fiber/v2" +) + +const ( + ScopeAll = "all:any" + + localsScopes = "user:scopes" +) + +func SetScopes(c *fiber.Ctx, scopes []string) { + c.Locals(localsScopes, scopes) +} + +func HasScope(c *fiber.Ctx, scope string) bool { + scopes, ok := c.Locals(localsScopes).([]string) + if !ok { + return false + } + + return slices.ContainsFunc(scopes, func(item string) bool { return item == scope || item == ScopeAll }) +} + +func RequireScope(scope string) fiber.Handler { + return func(c *fiber.Ctx) error { + if !HasScope(c, scope) { + return fiber.NewError(fiber.StatusForbidden, "scope required: "+scope) + } + + return c.Next() + } +} diff --git a/internal/sms-gateway/handlers/middlewares/userauth/userauth.go b/internal/sms-gateway/handlers/middlewares/userauth/userauth.go index 0b69e3a9..73a37e1f 100644 --- a/internal/sms-gateway/handlers/middlewares/userauth/userauth.go +++ b/internal/sms-gateway/handlers/middlewares/userauth/userauth.go @@ -4,20 +4,21 @@ import ( "encoding/base64" "strings" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" ) const localsUser = "user" -// NewBasic returns a middleware that will check if the request contains a valid -// "Authorization" header in the form of "Basic ". -// If the header is valid, the middleware will authorize the user and store the -// user in the request's Locals under the key LocalsUser. If the header is invalid, -// the middleware will call c.Next() and continue with the request. -func NewBasic(authSvc *auth.Service) fiber.Handler { +// NewBasic returns a middleware that optionally performs HTTP Basic authentication. +// If the "Authorization" header is missing or does not start with "Basic ", the request is passed through unchanged. +// If the header is present, the middleware expects a base64-encoded "username:password" payload, decodes it, +// validates the credentials format, and authenticates the user using the given users service. +// On invalid or failed authentication it returns 401 Unauthorized; on success it stores the user in Locals. +func NewBasic(usersSvc *users.Service) fiber.Handler { return func(c *fiber.Ctx) error { auth := c.Get(fiber.HeaderAuthorization) @@ -45,12 +46,13 @@ func NewBasic(authSvc *auth.Service) fiber.Handler { username := creds[:index] password := creds[index+1:] - user, err := authSvc.AuthorizeUser(username, password) + user, err := usersSvc.Login(c.Context(), username, password) if err != nil { return fiber.ErrUnauthorized } - c.Locals(localsUser, user) + SetUser(c, *user) + permissions.SetScopes(c, []string{permissions.ScopeAll}) return c.Next() } @@ -77,12 +79,16 @@ func NewCode(authSvc *auth.Service) fiber.Handler { return fiber.ErrUnauthorized } - c.Locals(localsUser, user) + SetUser(c, *user) return c.Next() } } +func SetUser(c *fiber.Ctx, user users.User) { + c.Locals(localsUser, user) +} + // HasUser checks if a user is present in the Locals of the given context. // It returns true if the Locals contain a user under the key LocalsUser, // otherwise returns false. @@ -90,18 +96,23 @@ func HasUser(c *fiber.Ctx) bool { return GetUser(c) != nil } -// GetUser returns the user stored in the Locals under the key LocalsUser. -func GetUser(c *fiber.Ctx) *models.User { - if user, ok := c.Locals(localsUser).(*models.User); ok { - return user +// GetUser returns the user stored in the Locals of the given context. +// It returns nil if the Locals do not contain a user under the key localsUser. +// The user is stored in Locals by the NewBasic and NewCode middlewares via SetUser, +// and is retrieved as a users.User value (exposed here as *users.User for convenience). +func GetUser(c *fiber.Ctx) *users.User { + user, ok := c.Locals(localsUser).(users.User) + if !ok { + return nil } - return nil + return &user } -// UserRequired is a middleware that ensures a user is present in the request's Locals. -// If a user is not found, it returns an unauthorized error, otherwise it passes control -// to the next handler in the stack. +// UserRequired is a middleware that checks if a user is present in the request's Locals. +// If the user is not present, it will return an unauthorized error. +// It is a convenience function that wraps the call to HasUser and calls the +// handler if the user is present. func UserRequired() fiber.Handler { return func(c *fiber.Ctx) error { if !HasUser(c) { @@ -113,11 +124,16 @@ func UserRequired() fiber.Handler { } // WithUser is a decorator that provides the current user to the handler. -func WithUser(handler func(models.User, *fiber.Ctx) error) fiber.Handler { +// It assumes that the user is stored in Locals under the key localsUser. +// If the user is not present, it returns 401 Unauthorized. +// +// It is a convenience function that wraps the call to GetUser and calls the +// handler with the user as the first argument. +func WithUser(handler func(users.User, *fiber.Ctx) error) fiber.Handler { return func(c *fiber.Ctx) error { user := GetUser(c) if user == nil { - return fiber.NewError(fiber.StatusUnauthorized, "Unauthorized") + return fiber.ErrUnauthorized } return handler(*user, c) diff --git a/internal/sms-gateway/handlers/mobile.go b/internal/sms-gateway/handlers/mobile.go index 1f53bf21..d88e572d 100644 --- a/internal/sms-gateway/handlers/mobile.go +++ b/internal/sms-gateway/handlers/mobile.go @@ -16,34 +16,20 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/auth" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/capcom6/go-helpers/anys" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/keyauth" "github.com/jaevor/go-nanoid" - "go.uber.org/fx" "go.uber.org/zap" ) -type mobileHandlerParams struct { - fx.In - - Logger *zap.Logger - Validator *validator.Validate - - AuthSvc *auth.Service - DevicesSvc *devices.Service - - MessagesCtrl *messages.MobileController - WebhooksCtrl *webhooks.MobileController - SettingsCtrl *settings.MobileController - EventsCtrl *events.MobileController -} - type mobileHandler struct { base.Handler authSvc *auth.Service + usersSvc *users.Service devicesSvc *devices.Service messagesCtrl *messages.MobileController @@ -54,6 +40,40 @@ type mobileHandler struct { idGen func() string } +func newMobileHandler( + authSvc *auth.Service, + usersSvc *users.Service, + devicesSvc *devices.Service, + + messagesCtrl *messages.MobileController, + webhooksCtrl *webhooks.MobileController, + settingsCtrl *settings.MobileController, + eventsCtrl *events.MobileController, + + logger *zap.Logger, + validator *validator.Validate, +) *mobileHandler { + const idLength = 21 + idGen, _ := nanoid.Standard(idLength) + + return &mobileHandler{ + Handler: base.Handler{ + Logger: logger, + Validator: validator, + }, + authSvc: authSvc, + usersSvc: usersSvc, + devicesSvc: devicesSvc, + + messagesCtrl: messagesCtrl, + webhooksCtrl: webhooksCtrl, + settingsCtrl: settingsCtrl, + eventsCtrl: eventsCtrl, + + idGen: idGen, + } +} + // @Summary Get device information // @Description Returns device information // @Tags Device @@ -102,26 +122,26 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) error { var ( err error - user *models.User - login string + user *users.User + username string password string ) - if userauth.HasUser(c) { - user = userauth.GetUser(c) - login = user.ID + if authUser := userauth.GetUser(c); authUser != nil { + user = authUser + username = user.ID } else { id := h.idGen() - login = strings.ToUpper(id[:6]) + username = strings.ToUpper(id[:6]) password = strings.ToLower(id[7:]) - user, err = h.authSvc.RegisterUser(login, password) + user, err = h.usersSvc.Create(username, password) if err != nil { return fmt.Errorf("failed to create user: %w", err) } } - device, err := h.authSvc.RegisterDevice(user, req.Name, req.PushToken) + device, err := h.authSvc.RegisterDevice(*user, req.Name, req.PushToken) if err != nil { return fmt.Errorf("failed to register device: %w", err) } @@ -130,7 +150,7 @@ func (h *mobileHandler) postDevice(c *fiber.Ctx) error { JSON(smsgateway.MobileRegisterResponse{ Id: device.ID, Token: device.AuthToken, - Login: login, + Login: username, Password: password, }) } @@ -178,7 +198,7 @@ func (h *mobileHandler) patchDevice(device models.Device, c *fiber.Ctx) error { // @Router /mobile/v1/user/code [get] // // Get user code. -func (h *mobileHandler) getUserCode(user models.User, c *fiber.Ctx) error { +func (h *mobileHandler) getUserCode(user users.User, c *fiber.Ctx) error { code, err := h.authSvc.GenerateUserCode(user.ID) if err != nil { h.Logger.Error("failed to generate user code", zap.Error(err), zap.String("user_id", user.ID)) @@ -212,9 +232,9 @@ func (h *mobileHandler) changePassword(device models.Device, c *fiber.Ctx) error return fiber.NewError(fiber.StatusBadRequest, err.Error()) } - if err := h.authSvc.ChangePassword(device.UserID, req.CurrentPassword, req.NewPassword); err != nil { + if err := h.usersSvc.ChangePassword(c.Context(), device.UserID, req.CurrentPassword, req.NewPassword); err != nil { h.Logger.Error("failed to change password", zap.Error(err)) - return fiber.NewError(fiber.StatusUnauthorized, "Invalid current password") + return fiber.NewError(fiber.StatusUnauthorized, "failed to change password") } return c.SendStatus(fiber.StatusNoContent) @@ -224,7 +244,7 @@ func (h *mobileHandler) Register(router fiber.Router) { router = router.Group("/mobile/v1") router.Post("/device", - userauth.NewBasic(h.authSvc), + userauth.NewBasic(h.usersSvc), userauth.NewCode(h.authSvc), keyauth.New(keyauth.Config{ Next: func(c *fiber.Ctx) bool { @@ -246,7 +266,7 @@ func (h *mobileHandler) Register(router fiber.Router) { ) router.Get("/user/code", - userauth.NewBasic(h.authSvc), + userauth.NewBasic(h.usersSvc), userauth.UserRequired(), userauth.WithUser(h.getUserCode), ) @@ -270,21 +290,3 @@ func (h *mobileHandler) Register(router fiber.Router) { h.settingsCtrl.Register(router.Group("/settings")) h.eventsCtrl.Register(router.Group("/events")) } - -func newMobileHandler(params mobileHandlerParams) *mobileHandler { - const idGenSize = 21 - idGen, _ := nanoid.Standard(idGenSize) - - return &mobileHandler{ - Handler: base.Handler{Logger: params.Logger, Validator: params.Validator}, - authSvc: params.AuthSvc, - - messagesCtrl: params.MessagesCtrl, - devicesSvc: params.DevicesSvc, - webhooksCtrl: params.WebhooksCtrl, - settingsCtrl: params.SettingsCtrl, - eventsCtrl: params.EventsCtrl, - - idGen: idGen, - } -} diff --git a/internal/sms-gateway/handlers/module.go b/internal/sms-gateway/handlers/module.go index 83440302..4be3dcb9 100644 --- a/internal/sms-gateway/handlers/module.go +++ b/internal/sms-gateway/handlers/module.go @@ -6,6 +6,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/logs" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/messages" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/settings" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/thirdparty" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/webhooks" "github.com/capcom6/go-infra-fx/http" "go.uber.org/fx" @@ -37,5 +38,6 @@ func Module() fx.Option { events.NewMobileController, fx.Private, ), + thirdparty.Module(), ) } diff --git a/internal/sms-gateway/handlers/settings/3rdparty.go b/internal/sms-gateway/handlers/settings/3rdparty.go index 6c784915..3d7b8ab5 100644 --- a/internal/sms-gateway/handlers/settings/3rdparty.go +++ b/internal/sms-gateway/handlers/settings/3rdparty.go @@ -5,10 +5,11 @@ import ( "github.com/android-sms-gateway/client-go/smsgateway" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/settings" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" "go.uber.org/fx" @@ -46,6 +47,7 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Summary Get settings // @Description Returns settings for a specific user // @Security ApiAuth +// @Security JWTAuth // @Tags User, Settings // @Produce json // @Success 200 {object} smsgateway.DeviceSettings "Settings" @@ -54,7 +56,7 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Router /3rdparty/v1/settings [get] // // Get settings. -func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) get(user users.User, c *fiber.Ctx) error { settings, err := h.settingsSvc.GetSettings(user.ID, true) if err != nil { return fmt.Errorf("failed to get settings: %w", err) @@ -66,6 +68,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Summary Replace settings // @Description Replaces settings // @Security ApiAuth +// @Security JWTAuth // @Tags User, Settings // @Accept json // @Produce json @@ -77,7 +80,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Router /3rdparty/v1/settings [put] // // Update settings. -func (h *ThirdPartyController) put(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) put(user users.User, c *fiber.Ctx) error { if err := h.BodyParserValidator(c, new(smsgateway.DeviceSettings)); err != nil { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid settings format: %v", err)) } @@ -100,6 +103,7 @@ func (h *ThirdPartyController) put(user models.User, c *fiber.Ctx) error { // @Summary Partially update settings // @Description Partially updates settings for a specific user // @Security ApiAuth +// @Security JWTAuth // @Tags User, Settings // @Accept json // @Produce json @@ -111,7 +115,7 @@ func (h *ThirdPartyController) put(user models.User, c *fiber.Ctx) error { // @Router /3rdparty/v1/settings [patch] // // Partially update settings. -func (h *ThirdPartyController) patch(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) patch(user users.User, c *fiber.Ctx) error { if err := h.BodyParserValidator(c, new(smsgateway.DeviceSettings)); err != nil { return fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("Invalid settings format: %v", err)) } @@ -131,7 +135,7 @@ func (h *ThirdPartyController) patch(user models.User, c *fiber.Ctx) error { } func (h *ThirdPartyController) Register(app fiber.Router) { - app.Get("", userauth.WithUser(h.get)) - app.Patch("", userauth.WithUser(h.patch)) - app.Put("", userauth.WithUser(h.put)) + app.Get("", permissions.RequireScope(ScopeRead), userauth.WithUser(h.get)) + app.Patch("", permissions.RequireScope(ScopeWrite), userauth.WithUser(h.patch)) + app.Put("", permissions.RequireScope(ScopeWrite), userauth.WithUser(h.put)) } diff --git a/internal/sms-gateway/handlers/settings/permissions.go b/internal/sms-gateway/handlers/settings/permissions.go new file mode 100644 index 00000000..4185fbcb --- /dev/null +++ b/internal/sms-gateway/handlers/settings/permissions.go @@ -0,0 +1,6 @@ +package settings + +const ( + ScopeRead = "settings:read" + ScopeWrite = "settings:write" +) diff --git a/internal/sms-gateway/handlers/thirdparty/auth.go b/internal/sms-gateway/handlers/thirdparty/auth.go new file mode 100644 index 00000000..e2e8d77e --- /dev/null +++ b/internal/sms-gateway/handlers/thirdparty/auth.go @@ -0,0 +1,130 @@ +package thirdparty + +import ( + "errors" + "fmt" + "time" + + "github.com/android-sms-gateway/client-go/smsgateway" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" + "github.com/android-sms-gateway/server/internal/sms-gateway/jwt" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" + "github.com/go-playground/validator/v10" + "github.com/gofiber/fiber/v2" + "go.uber.org/zap" +) + +type AuthHandler struct { + base.Handler + + jwtSvc jwt.Service +} + +func NewAuthHandler( + jwtSvc jwt.Service, + + logger *zap.Logger, + validator *validator.Validate, +) *AuthHandler { + return &AuthHandler{ + Handler: base.Handler{Logger: logger, Validator: validator}, + + jwtSvc: jwtSvc, + } +} + +func (h *AuthHandler) Register(router fiber.Router) { + router.Use(h.errorHandler) + router.Post("/token", permissions.RequireScope(ScopeTokensManage), userauth.WithUser(h.postToken)) + router.Delete("/token/:jti", permissions.RequireScope(ScopeTokensManage), userauth.WithUser(h.deleteToken)) +} + +// @Summary Generate token +// @Description Generate new access token with specified scopes and ttl +// @Security ApiAuth +// @Security JWTAuth +// @Tags User, Auth +// @Accept json +// @Produce json +// @Param request body smsgateway.TokenRequest true "Request" +// @Success 201 {object} smsgateway.TokenResponse "Token" +// @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" +// @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" +// @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" +// @Router /3rdparty/v1/auth/token [post] +// +// Generate token. +func (h *AuthHandler) postToken(user users.User, c *fiber.Ctx) error { + req := new(smsgateway.TokenRequest) + if err := h.BodyParserValidator(c, req); err != nil { + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + } + + token, err := h.jwtSvc.GenerateToken( + c.Context(), + user.ID, + req.Scopes, + time.Duration(req.TTL)*time.Second, //nolint:gosec // validated in the service + ) + if err != nil { + return fmt.Errorf("failed to generate token: %w", err) + } + + return c.Status(fiber.StatusCreated).JSON(smsgateway.TokenResponse{ + ID: token.ID, + TokenType: "Bearer", + AccessToken: token.AccessToken, + ExpiresAt: token.ExpiresAt, + }) +} + +// @Summary Revoke token +// @Description Revoke access token with specified jti +// @Security ApiAuth +// @Security JWTAuth +// @Tags User, Auth +// @Param jti path string true "JWT ID" +// @Success 204 "No Content" +// @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" +// @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" +// @Router /3rdparty/v1/auth/token/{jti} [delete] +// +// Revoke token. +func (h *AuthHandler) deleteToken(user users.User, c *fiber.Ctx) error { + jti := c.Params("jti") + + if err := h.jwtSvc.RevokeToken(c.Context(), user.ID, jti); err != nil { + return fmt.Errorf("failed to revoke token: %w", err) + } + + return c.SendStatus(fiber.StatusNoContent) +} + +func (h *AuthHandler) errorHandler(c *fiber.Ctx) error { + err := c.Next() + if err == nil { + return nil + } + + switch { + case errors.Is(err, jwt.ErrInvalidParams): + return fiber.NewError(fiber.StatusBadRequest, err.Error()) + + case errors.Is(err, jwt.ErrInitFailed): + fallthrough + case errors.Is(err, jwt.ErrInvalidConfig): + return fiber.NewError( + fiber.StatusInternalServerError, + "token service not configured, contact your administrator", + ) + + case errors.Is(err, jwt.ErrDisabled): + return fiber.NewError(fiber.StatusNotImplemented, "token service disabled, contact your administrator") + } + + return err //nolint:wrapcheck // passed through to fiber's error handler +} diff --git a/internal/sms-gateway/handlers/thirdparty/module.go b/internal/sms-gateway/handlers/thirdparty/module.go new file mode 100644 index 00000000..c45be26b --- /dev/null +++ b/internal/sms-gateway/handlers/thirdparty/module.go @@ -0,0 +1,16 @@ +package thirdparty + +import ( + "github.com/go-core-fx/logger" + "go.uber.org/fx" +) + +func Module() fx.Option { + return fx.Module( + "thirdparty", + logger.WithNamedLogger("3rdparty"), + fx.Provide( + NewAuthHandler, + ), + ) +} diff --git a/internal/sms-gateway/handlers/thirdparty/permissions.go b/internal/sms-gateway/handlers/thirdparty/permissions.go new file mode 100644 index 00000000..507ae4a6 --- /dev/null +++ b/internal/sms-gateway/handlers/thirdparty/permissions.go @@ -0,0 +1,5 @@ +package thirdparty + +const ( + ScopeTokensManage = "tokens:manage" +) diff --git a/internal/sms-gateway/handlers/webhooks/3rdparty.go b/internal/sms-gateway/handlers/webhooks/3rdparty.go index a7152c17..552aecb5 100644 --- a/internal/sms-gateway/handlers/webhooks/3rdparty.go +++ b/internal/sms-gateway/handlers/webhooks/3rdparty.go @@ -5,9 +5,10 @@ import ( "github.com/android-sms-gateway/client-go/smsgateway" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/base" + "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/permissions" "github.com/android-sms-gateway/server/internal/sms-gateway/handlers/middlewares/userauth" - "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/webhooks" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/go-playground/validator/v10" "github.com/gofiber/fiber/v2" "go.uber.org/fx" @@ -42,15 +43,17 @@ func NewThirdPartyController(params thirdPartyControllerParams) *ThirdPartyContr // @Summary List webhooks // @Description Returns list of registered webhooks // @Security ApiAuth +// @Security JWTAuth // @Tags User, Webhooks // @Produce json // @Success 200 {object} []smsgateway.Webhook "Webhook list" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks [get] // // List webhooks. -func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) get(user users.User, c *fiber.Ctx) error { items, err := h.webhooksSvc.Select(user.ID) if err != nil { return fmt.Errorf("failed to select webhooks: %w", err) @@ -62,6 +65,7 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Summary Register webhook // @Description Registers webhook. If webhook with same ID already exists, it will be replaced // @Security ApiAuth +// @Security JWTAuth // @Tags User, Webhooks // @Accept json // @Produce json @@ -69,11 +73,12 @@ func (h *ThirdPartyController) get(user models.User, c *fiber.Ctx) error { // @Success 201 {object} smsgateway.Webhook "Created" // @Failure 400 {object} smsgateway.ErrorResponse "Invalid request" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks [post] // // Register webhook. -func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) post(user users.User, c *fiber.Ctx) error { dto := new(smsgateway.Webhook) if err := h.BodyParserValidator(c, dto); err != nil { @@ -94,16 +99,18 @@ func (h *ThirdPartyController) post(user models.User, c *fiber.Ctx) error { // @Summary Delete webhook // @Description Deletes webhook // @Security ApiAuth +// @Security JWTAuth // @Tags User, Webhooks // @Produce json // @Param id path string true "Webhook ID" // @Success 204 {object} object "Webhook deleted" // @Failure 401 {object} smsgateway.ErrorResponse "Unauthorized" +// @Failure 403 {object} smsgateway.ErrorResponse "Forbidden" // @Failure 500 {object} smsgateway.ErrorResponse "Internal server error" // @Router /3rdparty/v1/webhooks/{id} [delete] // // Delete webhook. -func (h *ThirdPartyController) delete(user models.User, c *fiber.Ctx) error { +func (h *ThirdPartyController) delete(user users.User, c *fiber.Ctx) error { id := c.Params("id") if err := h.webhooksSvc.Delete(user.ID, webhooks.WithExtID(id)); err != nil { @@ -114,7 +121,7 @@ func (h *ThirdPartyController) delete(user models.User, c *fiber.Ctx) error { } func (h *ThirdPartyController) Register(router fiber.Router) { - router.Get("", userauth.WithUser(h.get)) - router.Post("", userauth.WithUser(h.post)) - router.Delete("/:id", userauth.WithUser(h.delete)) + router.Get("", permissions.RequireScope(ScopeList), userauth.WithUser(h.get)) + router.Post("", permissions.RequireScope(ScopeWrite), userauth.WithUser(h.post)) + router.Delete("/:id", permissions.RequireScope(ScopeDelete), userauth.WithUser(h.delete)) } diff --git a/internal/sms-gateway/handlers/webhooks/permissions.go b/internal/sms-gateway/handlers/webhooks/permissions.go new file mode 100644 index 00000000..3b85213b --- /dev/null +++ b/internal/sms-gateway/handlers/webhooks/permissions.go @@ -0,0 +1,7 @@ +package webhooks + +const ( + ScopeList = "webhooks:list" + ScopeWrite = "webhooks:write" + ScopeDelete = "webhooks:delete" +) diff --git a/internal/sms-gateway/jwt/config.go b/internal/sms-gateway/jwt/config.go new file mode 100644 index 00000000..d40a505b --- /dev/null +++ b/internal/sms-gateway/jwt/config.go @@ -0,0 +1,32 @@ +package jwt + +import ( + "fmt" + "time" +) + +const ( + minSecretLength = 32 +) + +type Config struct { + Secret string + TTL time.Duration + Issuer string +} + +func (c Config) Validate() error { + if c.Secret == "" { + return fmt.Errorf("%w: secret is required", ErrInvalidConfig) + } + + if len(c.Secret) < minSecretLength { + return fmt.Errorf("%w: secret must be at least %d bytes", ErrInvalidConfig, minSecretLength) + } + + if c.TTL <= 0 { + return fmt.Errorf("%w: ttl must be positive", ErrInvalidConfig) + } + + return nil +} diff --git a/internal/sms-gateway/jwt/disabled.go b/internal/sms-gateway/jwt/disabled.go new file mode 100644 index 00000000..919ae25c --- /dev/null +++ b/internal/sms-gateway/jwt/disabled.go @@ -0,0 +1,28 @@ +package jwt + +import ( + "context" + "time" +) + +type disabled struct { +} + +func newDisabled() Service { + return &disabled{} +} + +// GenerateToken implements Service. +func (d *disabled) GenerateToken(_ context.Context, _ string, _ []string, _ time.Duration) (*TokenInfo, error) { + return nil, ErrDisabled +} + +// ParseToken implements Service. +func (d *disabled) ParseToken(_ context.Context, _ string) (*Claims, error) { + return nil, ErrDisabled +} + +// RevokeToken implements Service. +func (d *disabled) RevokeToken(_ context.Context, _, _ string) error { + return ErrDisabled +} diff --git a/internal/sms-gateway/jwt/errors.go b/internal/sms-gateway/jwt/errors.go new file mode 100644 index 00000000..6f70c17a --- /dev/null +++ b/internal/sms-gateway/jwt/errors.go @@ -0,0 +1,12 @@ +package jwt + +import "errors" + +var ( + ErrDisabled = errors.New("jwt disabled") + ErrInitFailed = errors.New("failed to initialize jwt") + ErrInvalidConfig = errors.New("invalid config") + ErrInvalidParams = errors.New("invalid params") + ErrInvalidToken = errors.New("invalid token") + ErrTokenRevoked = errors.New("token revoked") +) diff --git a/internal/sms-gateway/jwt/jwt.go b/internal/sms-gateway/jwt/jwt.go new file mode 100644 index 00000000..86acffe5 --- /dev/null +++ b/internal/sms-gateway/jwt/jwt.go @@ -0,0 +1,27 @@ +package jwt + +import ( + "context" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +type Service interface { + GenerateToken(ctx context.Context, userID string, scopes []string, ttl time.Duration) (*TokenInfo, error) + ParseToken(ctx context.Context, token string) (*Claims, error) + RevokeToken(ctx context.Context, userID, jti string) error +} + +type Claims struct { + jwt.RegisteredClaims + + UserID string `json:"user_id"` + Scopes []string `json:"scopes"` +} + +type TokenInfo struct { + ID string + AccessToken string + ExpiresAt time.Time +} diff --git a/internal/sms-gateway/jwt/metrics.go b/internal/sms-gateway/jwt/metrics.go new file mode 100644 index 00000000..95ac5f1f --- /dev/null +++ b/internal/sms-gateway/jwt/metrics.go @@ -0,0 +1,117 @@ +package jwt + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Metric constants. +const ( + MetricTokensIssuedTotal = "jwt_tokens_issued_total" //nolint:gosec // false positive + MetricTokensValidatedTotal = "jwt_tokens_validated_total" //nolint:gosec // false positive + MetricTokensRevokedTotal = "jwt_tokens_revoked_total" //nolint:gosec // false positive + MetricIssuanceDurationSeconds = "jwt_issuance_duration_seconds" + MetricValidationDurationSeconds = "jwt_validation_duration_seconds" + MetricRevocationDurationSeconds = "jwt_revocation_duration_seconds" + + labelStatus = "status" + + StatusSuccess = "success" + StatusError = "error" +) + +// Metrics contains all Prometheus Metrics for the JWT module. +type Metrics struct { + tokensIssuedCounter *prometheus.CounterVec + tokensValidatedCounter *prometheus.CounterVec + tokensRevokedCounter *prometheus.CounterVec + issuanceDurationHistogram prometheus.Histogram + validationDurationHistogram prometheus.Histogram + revocationDurationHistogram prometheus.Histogram +} + +// NewMetrics creates and initializes all JWT metrics. +func NewMetrics() *Metrics { + return &Metrics{ + tokensIssuedCounter: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricTokensIssuedTotal, + Help: "Total number of JWT tokens issued", + }, []string{labelStatus}), + + tokensValidatedCounter: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricTokensValidatedTotal, + Help: "Total number of JWT tokens validated", + }, []string{labelStatus}), + + tokensRevokedCounter: promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricTokensRevokedTotal, + Help: "Total number of JWT tokens revoked", + }, []string{labelStatus}), + + issuanceDurationHistogram: promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricIssuanceDurationSeconds, + Help: "JWT issuance duration in seconds", + Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1}, + }), + + validationDurationHistogram: promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricValidationDurationSeconds, + Help: "JWT validation duration in seconds", + Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1}, + }), + + revocationDurationHistogram: promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: "sms", + Subsystem: "auth", + Name: MetricRevocationDurationSeconds, + Help: "JWT revocation duration in seconds", + Buckets: []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1}, + }), + } +} + +// IncrementTokensIssued increments the tokens issued counter. +func (m *Metrics) IncrementTokensIssued(status string) { + m.tokensIssuedCounter.WithLabelValues(status).Inc() +} + +// IncrementTokensValidated increments the tokens validated counter. +func (m *Metrics) IncrementTokensValidated(status string) { + m.tokensValidatedCounter.WithLabelValues(status).Inc() +} + +// IncrementTokensRevoked increments the tokens revoked counter. +func (m *Metrics) IncrementTokensRevoked(status string) { + m.tokensRevokedCounter.WithLabelValues(status).Inc() +} + +// ObserveIssuance observes issuance duration. +func (m *Metrics) ObserveIssuance(f func()) { + timer := prometheus.NewTimer(m.issuanceDurationHistogram) + defer timer.ObserveDuration() + f() +} + +// ObserveValidation observes validation duration. +func (m *Metrics) ObserveValidation(f func()) { + timer := prometheus.NewTimer(m.validationDurationHistogram) + defer timer.ObserveDuration() + f() +} + +// ObserveRevocation observes revocation duration. +func (m *Metrics) ObserveRevocation(f func()) { + timer := prometheus.NewTimer(m.revocationDurationHistogram) + defer timer.ObserveDuration() + f() +} diff --git a/internal/sms-gateway/jwt/models.go b/internal/sms-gateway/jwt/models.go new file mode 100644 index 00000000..8f565781 --- /dev/null +++ b/internal/sms-gateway/jwt/models.go @@ -0,0 +1,38 @@ +package jwt + +import ( + "fmt" + "time" + + "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "gorm.io/gorm" +) + +type tokenModel struct { + models.TimedModel + + ID string `gorm:"primaryKey;type:char(21)"` + UserID string `gorm:"not null;type:char(21);index:idx_tokens_user_id"` + ExpiresAt time.Time `gorm:"not null;index:idx_tokens_expires_at"` + RevokedAt *time.Time +} + +func (tokenModel) TableName() string { + return "tokens" +} + +func newTokenModel(id, userID string, expiresAt time.Time) *tokenModel { + //nolint:exhaustruct // partial constructor + return &tokenModel{ + ID: id, + UserID: userID, + ExpiresAt: expiresAt, + } +} + +func Migrate(db *gorm.DB) error { + if err := db.AutoMigrate(new(tokenModel)); err != nil { + return fmt.Errorf("tokens migration failed: %w", err) + } + return nil +} diff --git a/internal/sms-gateway/jwt/module.go b/internal/sms-gateway/jwt/module.go new file mode 100644 index 00000000..7b0769ee --- /dev/null +++ b/internal/sms-gateway/jwt/module.go @@ -0,0 +1,27 @@ +package jwt + +import ( + "github.com/capcom6/go-infra-fx/db" + "github.com/go-core-fx/logger" + "go.uber.org/fx" +) + +func Module() fx.Option { + return fx.Module( + "jwt", + logger.WithNamedLogger("jwt"), + fx.Provide(NewMetrics, NewRepository, fx.Private), + fx.Provide(func(config Config, tokens *Repository, metrics *Metrics) (Service, error) { + if config.Secret == "" { + return newDisabled(), nil + } + + return New(config, tokens, metrics) + }), + ) +} + +//nolint:gochecknoinits // framework-specific +func init() { + db.RegisterMigration(Migrate) +} diff --git a/internal/sms-gateway/jwt/repository.go b/internal/sms-gateway/jwt/repository.go new file mode 100644 index 00000000..a4cfb3f5 --- /dev/null +++ b/internal/sms-gateway/jwt/repository.go @@ -0,0 +1,47 @@ +package jwt + +import ( + "context" + "fmt" + + "gorm.io/gorm" +) + +type Repository struct { + db *gorm.DB +} + +func NewRepository(db *gorm.DB) *Repository { + return &Repository{ + db: db, + } +} + +func (r *Repository) Insert(ctx context.Context, token *tokenModel) error { + if err := r.db.WithContext(ctx).Create(token).Error; err != nil { + return fmt.Errorf("can't create token: %w", err) + } + + return nil +} + +func (r *Repository) Revoke(ctx context.Context, jti, userID string) error { + if err := r.db.WithContext(ctx).Model((*tokenModel)(nil)). + Where("id = ? and user_id = ? and revoked_at is null", jti, userID). + Update("revoked_at", gorm.Expr("NOW()")).Error; err != nil { + return fmt.Errorf("can't revoke token: %w", err) + } + + return nil +} + +func (r *Repository) IsRevoked(ctx context.Context, jti string) (bool, error) { + var count int64 + if err := r.db.WithContext(ctx).Model((*tokenModel)(nil)). + Where("id = ? and revoked_at is not null", jti). + Count(&count).Error; err != nil { + return false, fmt.Errorf("can't check if token is revoked: %w", err) + } + + return count > 0, nil +} diff --git a/internal/sms-gateway/jwt/service.go b/internal/sms-gateway/jwt/service.go new file mode 100644 index 00000000..9ce5cef3 --- /dev/null +++ b/internal/sms-gateway/jwt/service.go @@ -0,0 +1,182 @@ +package jwt + +import ( + "context" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/jaevor/go-nanoid" +) + +const jtiLength = 21 + +type service struct { + config Config + + tokens *Repository + + metrics *Metrics + + idFactory func() string +} + +func New(config Config, tokens *Repository, metrics *Metrics) (Service, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + if tokens == nil { + return nil, fmt.Errorf("%w: revoked storage is required", ErrInitFailed) + } + + if metrics == nil { + return nil, fmt.Errorf("%w: metrics is required", ErrInitFailed) + } + + idFactory, err := nanoid.Standard(jtiLength) + if err != nil { + return nil, fmt.Errorf("can't create id factory: %w", err) + } + + return &service{ + config: config, + + tokens: tokens, + + metrics: metrics, + + idFactory: idFactory, + }, nil +} + +func (s *service) GenerateToken( + ctx context.Context, + userID string, + scopes []string, + ttl time.Duration, +) (*TokenInfo, error) { + var tokenInfo *TokenInfo + var err error + + s.metrics.ObserveIssuance(func() { + if userID == "" { + err = fmt.Errorf("%w: user id is required", ErrInvalidParams) + return + } + + if len(scopes) == 0 { + err = fmt.Errorf("%w: scopes are required", ErrInvalidParams) + return + } + + if ttl < 0 { + err = fmt.Errorf("%w: ttl must be non-negative", ErrInvalidParams) + return + } + + if ttl == 0 { + ttl = s.config.TTL + } + + now := time.Now() + claims := &Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: s.idFactory(), + Issuer: s.config.Issuer, + Subject: userID, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(min(ttl, s.config.TTL))), + }, + UserID: userID, + Scopes: scopes, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + signedToken, signErr := token.SignedString([]byte(s.config.Secret)) + if signErr != nil { + err = fmt.Errorf("failed to sign token: %w", signErr) + return + } + + if storeErr := s.tokens.Insert(ctx, newTokenModel(claims.ID, claims.UserID, claims.ExpiresAt.Time)); storeErr != nil { + err = fmt.Errorf("failed to insert token: %w", storeErr) + return + } + + tokenInfo = &TokenInfo{ID: claims.ID, AccessToken: signedToken, ExpiresAt: claims.ExpiresAt.Time} + }) + + if err != nil { + s.metrics.IncrementTokensIssued(StatusError) + } else { + s.metrics.IncrementTokensIssued(StatusSuccess) + } + + return tokenInfo, err +} + +func (s *service) ParseToken(ctx context.Context, token string) (*Claims, error) { + var claims *Claims + var err error + + s.metrics.ObserveValidation(func() { + parsedToken, parseErr := jwt.ParseWithClaims( + token, + new(Claims), + func(_ *jwt.Token) (any, error) { + return []byte(s.config.Secret), nil + }, + jwt.WithExpirationRequired(), + jwt.WithIssuedAt(), + jwt.WithIssuer(s.config.Issuer), + jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}), + ) + if parseErr != nil { + err = fmt.Errorf("failed to parse token: %w", parseErr) + return + } + + parsedClaims, ok := parsedToken.Claims.(*Claims) + if !ok || !parsedToken.Valid { + err = ErrInvalidToken + return + } + + revoked, parseErr := s.tokens.IsRevoked(ctx, parsedClaims.ID) + if parseErr != nil { + err = parseErr + return + } + if revoked { + err = ErrTokenRevoked + return + } + + claims = parsedClaims + }) + + if err != nil { + s.metrics.IncrementTokensValidated(StatusError) + } else { + s.metrics.IncrementTokensValidated(StatusSuccess) + } + + return claims, err +} + +func (s *service) RevokeToken(ctx context.Context, userID, jti string) error { + var err error + + s.metrics.ObserveRevocation(func() { + err = s.tokens.Revoke(ctx, jti, userID) + }) + + if err != nil { + s.metrics.IncrementTokensRevoked(StatusError) + } else { + s.metrics.IncrementTokensRevoked(StatusSuccess) + } + + return err +} diff --git a/internal/sms-gateway/models/migration.go b/internal/sms-gateway/models/migration.go index aee9173a..e4d08a58 100644 --- a/internal/sms-gateway/models/migration.go +++ b/internal/sms-gateway/models/migration.go @@ -11,7 +11,7 @@ import ( var migrations embed.FS func Migrate(db *gorm.DB) error { - if err := db.AutoMigrate(new(User), new(Device)); err != nil { + if err := db.AutoMigrate(new(Device)); err != nil { return fmt.Errorf("models migration failed: %w", err) } return nil diff --git a/internal/sms-gateway/models/migrations/mysql/20251121071748_add_tokens.sql b/internal/sms-gateway/models/migrations/mysql/20251121071748_add_tokens.sql new file mode 100644 index 00000000..fbdb6de6 --- /dev/null +++ b/internal/sms-gateway/models/migrations/mysql/20251121071748_add_tokens.sql @@ -0,0 +1,19 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE `tokens` ( + `id` char(21) NOT NULL PRIMARY KEY, + `user_id` char(21) NOT NULL, + `expires_at` datetime(3) NOT NULL, + `created_at` datetime(3) NOT NULL DEFAULT current_timestamp(3), + `updated_at` datetime(3) NOT NULL DEFAULT current_timestamp(3) ON UPDATE current_timestamp(3), + `revoked_at` datetime(3) NULL, + INDEX `idx_tokens_user_id` (`user_id`), + INDEX `idx_tokens_expires_at` (`expires_at`), + CONSTRAINT `fk_tokens_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE +); +-- +goose StatementEnd +--- +-- +goose Down +-- +goose StatementBegin +DROP TABLE `tokens`; +-- +goose StatementEnd \ No newline at end of file diff --git a/internal/sms-gateway/models/models.go b/internal/sms-gateway/models/models.go index dd28bb42..b719c33c 100644 --- a/internal/sms-gateway/models/models.go +++ b/internal/sms-gateway/models/models.go @@ -15,22 +15,6 @@ type SoftDeletableModel struct { DeletedAt *time.Time `gorm:"<-:update"` } -type User struct { - SoftDeletableModel - - ID string `gorm:"primaryKey;type:varchar(32)"` - PasswordHash string `gorm:"not null;type:varchar(72)"` - Devices []Device `gorm:"-,foreignKey:UserID;constraint:OnDelete:CASCADE"` -} - -func NewUser(id, passwordHash string) *User { - //nolint:exhaustruct // pertial constructor - return &User{ - ID: id, - PasswordHash: passwordHash, - } -} - type Device struct { SoftDeletableModel diff --git a/internal/sms-gateway/modules/auth/cache.go b/internal/sms-gateway/modules/auth/cache.go deleted file mode 100644 index e835deb8..00000000 --- a/internal/sms-gateway/modules/auth/cache.go +++ /dev/null @@ -1,55 +0,0 @@ -package auth - -import ( - "crypto/sha256" - "encoding/hex" - "fmt" - "time" - - "github.com/android-sms-gateway/server/internal/sms-gateway/models" - "github.com/capcom6/go-helpers/cache" -) - -type usersCache struct { - cache *cache.Cache[models.User] -} - -func newUsersCache() *usersCache { - return &usersCache{ - cache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}), - } -} - -func (c *usersCache) makeKey(username, password string) string { - hash := sha256.Sum256([]byte(username + "\x00" + password)) - return hex.EncodeToString(hash[:]) -} - -func (c *usersCache) Get(username, password string) (models.User, error) { - user, err := c.cache.Get(c.makeKey(username, password)) - if err != nil { - return models.User{}, fmt.Errorf("failed to get user from cache: %w", err) - } - - return user, nil -} - -func (c *usersCache) Set(username, password string, user models.User) error { - if err := c.cache.Set(c.makeKey(username, password), user); err != nil { - return fmt.Errorf("failed to cache user: %w", err) - } - - return nil -} - -func (c *usersCache) Delete(username, password string) error { - if err := c.cache.Delete(c.makeKey(username, password)); err != nil { - return fmt.Errorf("failed to delete user from cache: %w", err) - } - - return nil -} - -func (c *usersCache) Cleanup() { - c.cache.Cleanup() -} diff --git a/internal/sms-gateway/modules/auth/module.go b/internal/sms-gateway/modules/auth/module.go index 81108f6c..54cee4b3 100644 --- a/internal/sms-gateway/modules/auth/module.go +++ b/internal/sms-gateway/modules/auth/module.go @@ -14,7 +14,6 @@ func Module() fx.Option { return log.Named("auth") }), fx.Provide(New), - fx.Provide(newRepository, fx.Private), fx.Invoke(func(lc fx.Lifecycle, svc *Service) { ctx, cancel := context.WithCancel(context.Background()) lc.Append(fx.Hook{ diff --git a/internal/sms-gateway/modules/auth/repository.go b/internal/sms-gateway/modules/auth/repository.go deleted file mode 100644 index e1c8a8f2..00000000 --- a/internal/sms-gateway/modules/auth/repository.go +++ /dev/null @@ -1,37 +0,0 @@ -package auth - -import ( - "github.com/android-sms-gateway/server/internal/sms-gateway/models" - "gorm.io/gorm" -) - -type repository struct { - db *gorm.DB -} - -func newRepository(db *gorm.DB) *repository { - return &repository{ - db: db, - } -} - -// GetByID returns a user by their ID. -func (r *repository) GetByID(id string) (*models.User, error) { - user := new(models.User) - - return user, r.db.Where("id = ?", id).Take(user).Error -} - -func (r *repository) GetByLogin(login string) (*models.User, error) { - user := new(models.User) - - return user, r.db.Where("id = ?", login).Take(user).Error -} - -func (r *repository) Insert(user *models.User) error { - return r.db.Create(user).Error -} - -func (r *repository) UpdatePassword(userID string, passwordHash string) error { - return r.db.Model((*models.User)(nil)).Where("id = ?", userID).Update("password_hash", passwordHash).Error -} diff --git a/internal/sms-gateway/modules/auth/service.go b/internal/sms-gateway/modules/auth/service.go index f994d1ad..f56d0a0c 100644 --- a/internal/sms-gateway/modules/auth/service.go +++ b/internal/sms-gateway/modules/auth/service.go @@ -10,10 +10,8 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/devices" "github.com/android-sms-gateway/server/internal/sms-gateway/online" - "github.com/android-sms-gateway/server/pkg/crypto" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/capcom6/go-helpers/cache" - "github.com/jaevor/go-nanoid" - "go.uber.org/fx" "go.uber.org/zap" ) @@ -22,47 +20,35 @@ type Config struct { PrivateToken string } -type Params struct { - fx.In - - Config Config - - Users *repository - DevicesSvc *devices.Service - OnlineSvc online.Service - - Logger *zap.Logger -} - type Service struct { config Config - users *repository - codesCache *cache.Cache[string] - usersCache *usersCache - + usersSvc *users.Service devicesSvc *devices.Service onlineSvc online.Service logger *zap.Logger - idgen func() string + codesCache *cache.Cache[string] } -func New(params Params) *Service { - const idLen = 21 - idgen, _ := nanoid.Standard(idLen) - +func New( + config Config, + usersSvc *users.Service, + devicesSvc *devices.Service, + onlineSvc online.Service, + logger *zap.Logger, +) *Service { return &Service{ - config: params.Config, - users: params.Users, - devicesSvc: params.DevicesSvc, - onlineSvc: params.OnlineSvc, - logger: params.Logger, - idgen: idgen, + config: config, + + usersSvc: usersSvc, + devicesSvc: devicesSvc, + onlineSvc: onlineSvc, + + logger: logger, codesCache: cache.New[string](cache.Config{TTL: codeTTL}), - usersCache: newUsersCache(), } } @@ -96,22 +82,11 @@ func (s *Service) GenerateUserCode(userID string) (OneTimeCode, error) { return OneTimeCode{Code: code, ValidUntil: validUntil}, nil } -func (s *Service) RegisterUser(login, password string) (*models.User, error) { - passwordHash, err := crypto.MakeBCryptHash(password) - if err != nil { - return nil, fmt.Errorf("failed to hash password: %w", err) - } - - user := models.NewUser(login, passwordHash) - if err = s.users.Insert(user); err != nil { - return user, fmt.Errorf("failed to create user: %w", err) - } - - return user, nil -} - -func (s *Service) RegisterDevice(user *models.User, name, pushToken *string) (*models.Device, error) { - device := models.NewDevice(name, pushToken) +func (s *Service) RegisterDevice(user users.User, name, pushToken *string) (*models.Device, error) { + device := models.NewDevice( + name, + pushToken, + ) if err := s.devicesSvc.Insert(user.ID, device); err != nil { return device, fmt.Errorf("failed to create device: %w", err) @@ -154,69 +129,21 @@ func (s *Service) AuthorizeDevice(token string) (models.Device, error) { return device, nil } -func (s *Service) AuthorizeUser(username, password string) (*models.User, error) { - if user, err := s.usersCache.Get(username, password); err == nil { - return &user, nil - } - - user, err := s.users.GetByLogin(username) - if err != nil { - return user, err - } - - if cmpErr := crypto.CompareBCryptHash(user.PasswordHash, password); cmpErr != nil { - return nil, fmt.Errorf("password is incorrect: %w", cmpErr) - } - - if setErr := s.usersCache.Set(username, password, *user); setErr != nil { - s.logger.Error("failed to cache user", zap.Error(setErr)) - } - - return user, nil -} - // AuthorizeUserByCode authorizes a user by one-time code. -func (s *Service) AuthorizeUserByCode(code string) (*models.User, error) { +func (s *Service) AuthorizeUserByCode(code string) (*users.User, error) { userID, err := s.codesCache.GetAndDelete(code) if err != nil { return nil, fmt.Errorf("failed to get user by code: %w", err) } - user, err := s.users.GetByID(userID) + user, err := s.usersSvc.GetByID(userID) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get user: %w", err) } return user, nil } -func (s *Service) ChangePassword(userID string, currentPassword string, newPassword string) error { - user, err := s.users.GetByLogin(userID) - if err != nil { - return fmt.Errorf("failed to get user: %w", err) - } - - if hashErr := crypto.CompareBCryptHash(user.PasswordHash, currentPassword); hashErr != nil { - return fmt.Errorf("current password is incorrect: %w", hashErr) - } - - newHash, err := crypto.MakeBCryptHash(newPassword) - if err != nil { - return fmt.Errorf("failed to hash new password: %w", err) - } - - if updErr := s.users.UpdatePassword(userID, newHash); updErr != nil { - return fmt.Errorf("failed to update password: %w", updErr) - } - - // Invalidate cache - if delErr := s.usersCache.Delete(userID, currentPassword); delErr != nil { - s.logger.Error("failed to invalidate user cache", zap.Error(delErr)) - } - - return nil -} - // Run starts a ticker that triggers the clean function every hour. // It runs indefinitely until the provided context is canceled. func (s *Service) Run(ctx context.Context) { @@ -235,5 +162,4 @@ func (s *Service) Run(ctx context.Context) { func (s *Service) clean(_ context.Context) { s.codesCache.Cleanup() - s.usersCache.Cleanup() } diff --git a/internal/sms-gateway/modules/messages/service.go b/internal/sms-gateway/modules/messages/service.go index 6a4c3653..02bb1d93 100644 --- a/internal/sms-gateway/modules/messages/service.go +++ b/internal/sms-gateway/modules/messages/service.go @@ -12,6 +12,7 @@ import ( "github.com/android-sms-gateway/server/internal/sms-gateway/models" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/db" "github.com/android-sms-gateway/server/internal/sms-gateway/modules/events" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "github.com/capcom6/go-helpers/anys" "github.com/capcom6/go-helpers/slices" "github.com/nyaruka/phonenumbers" @@ -124,7 +125,7 @@ func (s *Service) UpdateState(device *models.Device, message MessageStateIn) err } func (s *Service) SelectStates( - user models.User, + user users.User, filter SelectFilter, options SelectOptions, ) ([]MessageStateOut, int64, error) { @@ -138,7 +139,7 @@ func (s *Service) SelectStates( return slices.Map(messages, modelToMessageState), total, nil } -func (s *Service) GetState(user models.User, id string) (*MessageStateOut, error) { +func (s *Service) GetState(user users.User, id string) (*MessageStateOut, error) { dto, err := s.cache.Get(context.Background(), user.ID, id) if err == nil { s.metrics.IncCache(true) diff --git a/internal/sms-gateway/modules/settings/models.go b/internal/sms-gateway/modules/settings/models.go index c547f64e..9f0d48ff 100644 --- a/internal/sms-gateway/modules/settings/models.go +++ b/internal/sms-gateway/modules/settings/models.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "gorm.io/gorm" ) @@ -13,7 +14,7 @@ type DeviceSettings struct { UserID string `gorm:"primaryKey;not null;type:varchar(32)"` Settings map[string]any `gorm:"not null;type:json;serializer:json"` - User models.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` + User users.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` } func NewDeviceSettings(userID string, settings map[string]any) *DeviceSettings { diff --git a/internal/sms-gateway/modules/webhooks/models.go b/internal/sms-gateway/modules/webhooks/models.go index 806bdd0d..15c7736c 100644 --- a/internal/sms-gateway/modules/webhooks/models.go +++ b/internal/sms-gateway/modules/webhooks/models.go @@ -5,6 +5,7 @@ import ( "github.com/android-sms-gateway/client-go/smsgateway" "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "github.com/android-sms-gateway/server/internal/sms-gateway/users" "gorm.io/gorm" ) @@ -20,7 +21,7 @@ type Webhook struct { URL string `json:"url" validate:"required,http_url" gorm:"not null;type:varchar(256)"` Event smsgateway.WebhookEvent `json:"event" gorm:"not null;type:varchar(32)"` - User models.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` + User users.User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` Device *models.Device `gorm:"foreignKey:DeviceID;constraint:OnDelete:CASCADE"` } diff --git a/internal/sms-gateway/openapi/docs.go b/internal/sms-gateway/openapi/docs.go index 69a3f69c..53d3e56c 100644 --- a/internal/sms-gateway/openapi/docs.go +++ b/internal/sms-gateway/openapi/docs.go @@ -18,11 +18,131 @@ const docTemplate = `{ "host": "{{.Host}}", "basePath": "{{.BasePath}}", "paths": { + "/3rdparty/v1/auth/token": { + "post": { + "security": [ + { + "ApiAuth": [] + }, + { + "JWTAuth": [] + } + ], + "description": "Generate new access token with specified scopes and ttl", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "User", + "Auth" + ], + "summary": "Generate token", + "parameters": [ + { + "description": "Request", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/smsgateway.TokenRequest" + } + } + ], + "responses": { + "201": { + "description": "Token", + "schema": { + "$ref": "#/definitions/smsgateway.TokenResponse" + } + }, + "400": { + "description": "Invalid request", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + } + } + } + }, + "/3rdparty/v1/auth/token/{jti}": { + "delete": { + "security": [ + { + "ApiAuth": [] + }, + { + "JWTAuth": [] + } + ], + "description": "Revoke access token with specified jti", + "tags": [ + "User", + "Auth" + ], + "summary": "Revoke token", + "parameters": [ + { + "type": "string", + "description": "JWT ID", + "name": "jti", + "in": "path", + "required": true + } + ], + "responses": { + "204": { + "description": "No Content" + }, + "401": { + "description": "Unauthorized", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, + "500": { + "description": "Internal server error", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + } + } + } + }, "/3rdparty/v1/devices": { "get": { "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Returns list of registered devices", @@ -70,6 +190,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Removes device", @@ -152,6 +275,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Retrieve a list of log entries within a specified time range.", @@ -215,6 +341,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Retrieves a list of messages with filtering and pagination", @@ -290,6 +419,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -302,6 +437,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Enqueues a message for sending. If ` + "`" + `deviceId` + "`" + ` is set, the specified device is used; otherwise a random registered device is chosen.", @@ -366,6 +504,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "409": { "description": "Message with such ID already exists", "schema": { @@ -386,6 +530,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Initiates process of inbox messages export via webhooks. For each message the ` + "`" + `sms:received` + "`" + ` webhook will be triggered. The webhooks will be triggered without specific order.", @@ -430,6 +577,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -444,6 +597,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Returns message state by ID", @@ -483,6 +639,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -497,6 +659,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Returns settings for a specific user", @@ -533,6 +698,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Replaces settings", @@ -589,6 +757,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Partially updates settings for a specific user", @@ -647,6 +818,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Returns list of registered webhooks", @@ -674,6 +848,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -686,6 +866,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Registers webhook. If webhook with same ID already exists, it will be replaced", @@ -730,6 +913,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -744,6 +933,9 @@ const docTemplate = `{ "security": [ { "ApiAuth": [] + }, + { + "JWTAuth": [] } ], "description": "Deletes webhook", @@ -777,6 +969,12 @@ const docTemplate = `{ "$ref": "#/definitions/smsgateway.ErrorResponse" } }, + "403": { + "description": "Forbidden", + "schema": { + "$ref": "#/definitions/smsgateway.ErrorResponse" + } + }, "500": { "description": "Internal server error", "schema": { @@ -874,7 +1072,7 @@ const docTemplate = `{ ], "properties": { "data": { - "description": "Base64-encoded payload", + "description": "Data is the base64-encoded payload.", "type": "string", "format": "byte", "maxLength": 65535, @@ -882,7 +1080,7 @@ const docTemplate = `{ "example": "SGVsbG8gV29ybGQh" }, "port": { - "description": "Destination port", + "description": "Port is the destination port.", "type": "integer", "maximum": 65535, "minimum": 1, @@ -1215,7 +1413,7 @@ const docTemplate = `{ "example": true }, "message": { - "description": "Message content\nDeprecated: use TextMessage instead", + "description": "Message content (deprecated, use TextMessage instead)", "type": "string", "maxLength": 65535, "example": "Hello World!" @@ -1259,13 +1457,13 @@ const docTemplate = `{ ] }, "ttl": { - "description": "Time to live in seconds (conflicts with ` + "`" + `validUntil` + "`" + `)", + "description": "Time to live in seconds (conflicts with ` + "`" + `ValidUntil` + "`" + `)", "type": "integer", "minimum": 5, "example": 86400 }, "validUntil": { - "description": "Valid until (conflicts with ` + "`" + `ttl` + "`" + `)", + "description": "Valid until (conflicts with ` + "`" + `TTL` + "`" + `)", "type": "string", "example": "2020-01-01T00:00:00Z" }, @@ -1587,7 +1785,7 @@ const docTemplate = `{ ], "properties": { "text": { - "description": "Message text", + "description": "Text is the message text.", "type": "string", "maxLength": 65535, "minLength": 1, @@ -1595,6 +1793,47 @@ const docTemplate = `{ } } }, + "smsgateway.TokenRequest": { + "type": "object", + "required": [ + "scopes" + ], + "properties": { + "scopes": { + "description": "scopes for which the access token is valid", + "type": "array", + "minItems": 1, + "items": { + "type": "string" + } + }, + "ttl": { + "description": "lifetime of the access token in seconds", + "type": "integer" + } + } + }, + "smsgateway.TokenResponse": { + "type": "object", + "properties": { + "access_token": { + "description": "actual access token", + "type": "string" + }, + "expires_at": { + "description": "time at which the access token is no longer valid", + "type": "string" + }, + "id": { + "description": "unique identifier for the access token", + "type": "string" + }, + "token_type": { + "description": "type of the access token", + "type": "string" + } + } + }, "smsgateway.Webhook": { "type": "object", "required": [ @@ -1633,22 +1872,40 @@ const docTemplate = `{ "smsgateway.WebhookEvent": { "type": "string", "enum": [ - "sms:received", + "mms:received", "sms:data-received", - "sms:sent", "sms:delivered", "sms:failed", - "system:ping", - "mms:received" + "sms:received", + "sms:sent", + "system:ping" + ], + "x-enum-comments": { + "WebhookEventMmsReceived": "Triggered when an MMS is received.", + "WebhookEventSmsDataReceived": "Triggered when a data SMS is received.", + "WebhookEventSmsDelivered": "Triggered when an SMS is delivered.", + "WebhookEventSmsFailed": "Triggered when an SMS processing fails.", + "WebhookEventSmsReceived": "Triggered when an SMS is received.", + "WebhookEventSmsSent": "Triggered when an SMS is sent.", + "WebhookEventSystemPing": "Triggered when the device pings the server." + }, + "x-enum-descriptions": [ + "Triggered when an MMS is received.", + "Triggered when a data SMS is received.", + "Triggered when an SMS is delivered.", + "Triggered when an SMS processing fails.", + "Triggered when an SMS is received.", + "Triggered when an SMS is sent.", + "Triggered when the device pings the server." ], "x-enum-varnames": [ - "WebhookEventSmsReceived", + "WebhookEventMmsReceived", "WebhookEventSmsDataReceived", - "WebhookEventSmsSent", "WebhookEventSmsDelivered", "WebhookEventSmsFailed", - "WebhookEventSystemPing", - "WebhookEventMmsReceived" + "WebhookEventSmsReceived", + "WebhookEventSmsSent", + "WebhookEventSystemPing" ] } }, @@ -1656,6 +1913,12 @@ const docTemplate = `{ "ApiAuth": { "type": "basic" }, + "JWTAuth": { + "description": "JWT authentication", + "type": "apiKey", + "name": "Authorization", + "in": "header" + }, "MobileToken": { "description": "Mobile device token", "type": "apiKey", diff --git a/internal/sms-gateway/users/cache.go b/internal/sms-gateway/users/cache.go new file mode 100644 index 00000000..4bd042be --- /dev/null +++ b/internal/sms-gateway/users/cache.go @@ -0,0 +1,90 @@ +package users + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/android-sms-gateway/server/pkg/cache" +) + +const loginCacheTTL = time.Hour + +type loginCacheWrapper struct { + ID string `json:"id"` + + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (w *loginCacheWrapper) Unmarshal(data []byte) error { + if err := json.Unmarshal(data, w); err != nil { + return fmt.Errorf("failed to unmarshal login cache wrapper: %w", err) + } + + return nil +} + +func (w *loginCacheWrapper) Marshal() ([]byte, error) { + data, err := json.Marshal(w) + if err != nil { + return nil, fmt.Errorf("failed to marshal login cache wrapper: %w", err) + } + + return data, nil +} + +type loginCache struct { + storage *cache.Typed[*loginCacheWrapper] +} + +func newLoginCache(storage cache.Cache) *loginCache { + return &loginCache{ + storage: cache.NewTyped[*loginCacheWrapper](storage), + } +} + +func (c *loginCache) makeKey(username, password string) string { + hash := sha256.Sum256([]byte(username + "\x00" + password)) + return hex.EncodeToString(hash[:]) +} + +func (c *loginCache) Get(ctx context.Context, username, password string) (*User, error) { + user, err := c.storage.Get(ctx, c.makeKey(username, password), cache.AndSetTTL(loginCacheTTL)) + if err != nil { + return nil, fmt.Errorf("failed to get user from cache: %w", err) + } + + return &User{ + ID: user.ID, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + }, nil +} + +func (c *loginCache) Set(ctx context.Context, username, password string, user User) error { + wrapper := &loginCacheWrapper{ + ID: user.ID, + CreatedAt: user.CreatedAt, + UpdatedAt: user.UpdatedAt, + } + + if err := c.storage.Set(ctx, c.makeKey(username, password), wrapper, cache.WithTTL(loginCacheTTL)); err != nil { + return fmt.Errorf("failed to cache user: %w", err) + } + + return nil +} + +func (c *loginCache) Delete(ctx context.Context, username, password string) error { + err := c.storage.Delete(ctx, c.makeKey(username, password)) + if err == nil || errors.Is(err, cache.ErrKeyNotFound) || errors.Is(err, cache.ErrKeyExpired) { + return nil + } + + return fmt.Errorf("failed to delete user from cache: %w", err) +} diff --git a/internal/sms-gateway/users/domain.go b/internal/sms-gateway/users/domain.go new file mode 100644 index 00000000..3ca85727 --- /dev/null +++ b/internal/sms-gateway/users/domain.go @@ -0,0 +1,19 @@ +package users + +import "time" + +type User struct { + ID string + + CreatedAt time.Time + UpdatedAt time.Time +} + +func newUser(model *userModel) *User { + return &User{ + ID: model.ID, + + CreatedAt: model.CreatedAt, + UpdatedAt: model.UpdatedAt, + } +} diff --git a/internal/sms-gateway/users/errors.go b/internal/sms-gateway/users/errors.go new file mode 100644 index 00000000..dc8de6a1 --- /dev/null +++ b/internal/sms-gateway/users/errors.go @@ -0,0 +1,8 @@ +package users + +import "errors" + +var ( + ErrNotFound = errors.New("user not found") + ErrExists = errors.New("user already exists") +) diff --git a/internal/sms-gateway/users/models.go b/internal/sms-gateway/users/models.go new file mode 100644 index 00000000..ceb9662e --- /dev/null +++ b/internal/sms-gateway/users/models.go @@ -0,0 +1,34 @@ +package users + +import ( + "fmt" + + "github.com/android-sms-gateway/server/internal/sms-gateway/models" + "gorm.io/gorm" +) + +type userModel struct { + models.SoftDeletableModel + + ID string `gorm:"primaryKey;type:varchar(32)"` + PasswordHash string `gorm:"not null;type:varchar(72)"` +} + +func newUserModel(id string, passwordHash string) *userModel { + //nolint:exhaustruct // partial constructor + return &userModel{ + ID: id, + PasswordHash: passwordHash, + } +} + +func (u *userModel) TableName() string { + return "users" +} + +func Migrate(db *gorm.DB) error { + if err := db.AutoMigrate(new(userModel)); err != nil { + return fmt.Errorf("users migration failed: %w", err) + } + return nil +} diff --git a/internal/sms-gateway/users/module.go b/internal/sms-gateway/users/module.go new file mode 100644 index 00000000..19d518c8 --- /dev/null +++ b/internal/sms-gateway/users/module.go @@ -0,0 +1,32 @@ +package users + +import ( + "fmt" + + "github.com/android-sms-gateway/server/internal/sms-gateway/cache" + "github.com/capcom6/go-infra-fx/db" + "github.com/go-core-fx/logger" + "go.uber.org/fx" +) + +func Module() fx.Option { + return fx.Module( + "users", + logger.WithNamedLogger("users"), + fx.Provide(func(factory cache.Factory) (*loginCache, error) { + storage, err := factory.New("users:login") + if err != nil { + return nil, fmt.Errorf("can't create login cache: %w", err) + } + + return newLoginCache(storage), nil + }, fx.Private), + fx.Provide(newRepository, fx.Private), + fx.Provide(NewService), + ) +} + +//nolint:gochecknoinits // framework-specific +func init() { + db.RegisterMigration(Migrate) +} diff --git a/internal/sms-gateway/users/repository.go b/internal/sms-gateway/users/repository.go new file mode 100644 index 00000000..49b8adb5 --- /dev/null +++ b/internal/sms-gateway/users/repository.go @@ -0,0 +1,67 @@ +package users + +import ( + "errors" + "fmt" + + "github.com/android-sms-gateway/server/pkg/mysql" + "gorm.io/gorm" +) + +type repository struct { + db *gorm.DB +} + +// newRepository creates a new repository instance. +func newRepository(db *gorm.DB) *repository { + return &repository{ + db: db, + } +} + +func (r *repository) Exists(id string) (bool, error) { + var count int64 + if err := r.db.Model((*userModel)(nil)). + Where("id = ?", id). + Count(&count).Error; err != nil { + return false, fmt.Errorf("can't check if user exists: %w", err) + } + + return count > 0, nil +} + +// GetByID retrieves a user by their ID. +func (r *repository) GetByID(id string) (*userModel, error) { + user := new(userModel) + + if err := r.db.Where("id = ?", id).Take(user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrNotFound + } + return nil, fmt.Errorf("can't get user: %w", err) + } + + return user, nil +} + +func (r *repository) Insert(user *userModel) error { + if err := r.db.Create(user).Error; err != nil { + if mysql.IsDuplicateKeyViolation(err) { + return fmt.Errorf("%w: %w", ErrExists, err) + } + + return fmt.Errorf("can't create user: %w", err) + } + + return nil +} + +func (r *repository) UpdatePassword(id string, passwordHash string) error { + if err := r.db.Model((*userModel)(nil)). + Where("id = ?", id). + Update("password_hash", passwordHash).Error; err != nil { + return fmt.Errorf("can't update password: %w", err) + } + + return nil +} diff --git a/internal/sms-gateway/users/service.go b/internal/sms-gateway/users/service.go new file mode 100644 index 00000000..1ff8cd8c --- /dev/null +++ b/internal/sms-gateway/users/service.go @@ -0,0 +1,100 @@ +package users + +import ( + "context" + "errors" + "fmt" + + "github.com/android-sms-gateway/server/pkg/cache" + "github.com/android-sms-gateway/server/pkg/crypto" + "go.uber.org/zap" +) + +type Service struct { + users *repository + + cache *loginCache + + logger *zap.Logger +} + +func NewService( + users *repository, + cache *loginCache, + logger *zap.Logger, +) *Service { + return &Service{ + users: users, + + cache: cache, + + logger: logger, + } +} + +func (s *Service) Create(username, password string) (*User, error) { + passwordHash, err := crypto.MakeBCryptHash(password) + if err != nil { + return nil, fmt.Errorf("failed to hash password: %w", err) + } + + user := newUserModel(username, passwordHash) + + if insErr := s.users.Insert(user); insErr != nil { + return nil, fmt.Errorf("failed to create user: %w", insErr) + } + + return newUser(user), nil +} + +func (s *Service) GetByID(id string) (*User, error) { + user, err := s.users.GetByID(id) + if err != nil { + return nil, err + } + + return newUser(user), nil +} + +func (s *Service) Login(ctx context.Context, username, password string) (*User, error) { + cachedUser, err := s.cache.Get(ctx, username, password) + if err == nil { + return cachedUser, nil + } else if !errors.Is(err, cache.ErrKeyNotFound) { + s.logger.Warn("failed to get user from cache", zap.String("username", username), zap.Error(err)) + } + + user, err := s.users.GetByID(username) + if err != nil { + return nil, err + } + + if compErr := crypto.CompareBCryptHash(user.PasswordHash, password); compErr != nil { + return nil, fmt.Errorf("login failed: %w", compErr) + } + + loggedInUser := newUser(user) + if setErr := s.cache.Set(ctx, username, password, *loggedInUser); setErr != nil { + s.logger.Error("failed to cache user", zap.String("username", username), zap.Error(setErr)) + } + + return loggedInUser, nil +} + +func (s *Service) ChangePassword(ctx context.Context, username, currentPassword, newPassword string) error { + _, err := s.Login(ctx, username, currentPassword) + if err != nil { + return err + } + + if delErr := s.cache.Delete(ctx, username, currentPassword); delErr != nil { + return delErr + } + + passwordHash, err := crypto.MakeBCryptHash(newPassword) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + + return s.users.UpdatePassword(username, passwordHash) +} diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index e50e2898..05e5ca44 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -1,40 +1,283 @@ +// Package cache provides a flexible caching abstraction with multiple backend implementations. +// +// The cache package offers a unified interface for caching operations with support for: +// - In-memory caching (memoryCache) +// - Redis-based caching (redisCache) +// - Type-safe caching through generics (Typed[T]) +// +// Features: +// - Time-to-live (TTL) support for cache entries +// - Atomic operations for consistency +// - Concurrent-safe implementations +// - Configurable expiration policies +// - Support for custom serialization through the Item interface +// +// Basic Usage: +// +// // Create an in-memory cache with 1 hour default TTL +// cache := cache.NewMemory(time.Hour) +// +// // Set a value +// err := cache.Set(ctx, "key", []byte("value")) +// if err != nil { +// log.Fatal(err) +// } +// +// // Get a value +// value, err := cache.Get(ctx, "key") +// if err != nil { +// if errors.Is(err, cache.ErrKeyNotFound) { +// // Handle missing key +// } else if errors.Is(err, cache.ErrKeyExpired) { +// // Handle expired key +// } else { +// log.Fatal(err) +// } +// } +// +// // Set with custom TTL +// err = cache.Set(ctx, "key", []byte("value"), cache.WithTTL(30*time.Minute)) +// +// // Set only if key doesn't exist +// err = cache.SetOrFail(ctx, "key", []byte("value")) +// if errors.Is(err, cache.ErrKeyExists) { +// // Key already exists +// } +// +// // Get and delete in one operation +// value, err = cache.GetAndDelete(ctx, "key") +// +// // Remove expired entries +// err = cache.Cleanup(ctx) +// +// // Get all items and clear the cache +// items, err := cache.Drain(ctx) +// +// // Close the cache when done +// err = cache.Close() +// +// Using Typed Cache: +// +// // Define a type that implements the Item interface +// type MyData struct { +// Field1 string +// Field2 int +// } +// +// func (d *MyData) Marshal() ([]byte, error) { +// return json.Marshal(d) +// } +// +// func (d *MyData) Unmarshal(data []byte) error { +// return json.Unmarshal(data, d) +// } +// +// // Create a typed cache +// storage := cache.NewMemory(time.Hour) +// typedCache := cache.NewTyped[*MyData](storage) +// +// // Set typed value +// data := &MyData{Field1: "test", Field2: 42} +// err := typedCache.Set(ctx, "key", data) +// +// // Get typed value +// retrieved, err := typedCache.Get(ctx, "key") +// +// Using Redis Cache: +// +// // Create a Redis cache +// config := cache.RedisConfig{ +// URL: "redis://localhost:6379", +// Prefix: "myapp:", +// TTL: time.Hour, +// } +// +// redisCache, err := cache.NewRedis(config) +// if err != nil { +// log.Fatal(err) +// } +// defer redisCache.Close() +// +// // Use the same interface as memory cache +// err = redisCache.Set(ctx, "key", []byte("value")) +// value, err := redisCache.Get(ctx, "key") package cache import "context" +// Cache defines the interface for cache implementations. +// +// All cache operations are context-aware and support cancellation and timeouts. +// Implementations must be safe for concurrent use by multiple goroutines. type Cache interface { - // Set sets the value for the given key in the cache. + // Set stores the value for the given key in the cache, overwriting any existing value. + // + // The value will be stored with the default TTL configured for the cache implementation, + // unless overridden by options. If the key already exists, its value and TTL will be updated. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // - key: The key to store the value under + // - value: The value to store as a byte slice + // - opts: Optional configuration for this specific item (e.g., custom TTL) + // + // Returns: + // - error: nil on success, otherwise an error describing the failure + // + // Example: + // // Set with default TTL + // err := cache.Set(ctx, "user:123", []byte("user data")) + // + // // Set with custom TTL + // err := cache.Set(ctx, "session:abc", []byte("session data"), cache.WithTTL(30*time.Minute)) + // + // // Set with specific expiration time + // expiration := time.Now().Add(2 * time.Hour) + // err := cache.Set(ctx, "temp:xyz", []byte("temp data"), cache.WithValidUntil(expiration)) Set(ctx context.Context, key string, value []byte, opts ...Option) error - // SetOrFail is like Set, but returns ErrKeyExists if the key already exists. + // SetOrFail stores the value for the given key only if the key does not already exist. + // + // This is an atomic operation that prevents race conditions when multiple goroutines + // might try to set the same key simultaneously. If the key exists but has expired, + // it will be overwritten. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // - key: The key to store the value under + // - value: The value to store as a byte slice + // - opts: Optional configuration for this specific item (e.g., custom TTL) + // + // Returns: + // - error: nil on success, ErrKeyExists if the key already exists and is not expired, + // otherwise an error describing the failure + // + // Example: + // // Try to set a value only if key doesn't exist + // err := cache.SetOrFail(ctx, "lock:resource", []byte("locked")) + // if errors.Is(err, cache.ErrKeyExists) { + // // Key already exists, handle conflict + // } SetOrFail(ctx context.Context, key string, value []byte, opts ...Option) error - // Get gets the value for the given key from the cache. + // Get retrieves the value for the given key from the cache. + // + // The behavior depends on the key's existence and expiration state: + // - If the key exists and has not expired, returns the value and nil error + // - If the key does not exist, returns nil and ErrKeyNotFound + // - If the key exists but has expired, returns nil and ErrKeyExpired + // + // GetOptions can be used to modify the behavior, such as updating TTL, + // deleting the key after retrieval, or setting a new expiration time. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // - key: The key to retrieve + // - opts: Optional operations to perform during retrieval (e.g., AndDelete, AndSetTTL) + // + // Returns: + // - []byte: The cached value if found and not expired + // - error: nil on success, ErrKeyNotFound if key doesn't exist, + // ErrKeyExpired if key exists but has expired, otherwise an error // - // If the key is not found, it returns ErrKeyNotFound. - // If the key has expired, it returns ErrKeyExpired. - // Otherwise, it returns the value and nil. + // Example: + // // Simple get + // value, err := cache.Get(ctx, "user:123") + // + // // Get and extend TTL by 30 minutes + // value, err := cache.Get(ctx, "session:abc", cache.AndUpdateTTL(30*time.Minute)) + // + // // Get and delete atomically + // value, err := cache.Get(ctx, "temp:xyz", cache.AndDelete()) Get(ctx context.Context, key string, opts ...GetOption) ([]byte, error) - // GetAndDelete is like Get, but also deletes the key from the cache. + // GetAndDelete retrieves the value for the given key and atomically deletes it from the cache. + // + // This is equivalent to calling Get with the AndDelete option, but provides a more + // convenient API for the common pattern of reading and removing a value in one operation. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // - key: The key to retrieve and delete + // + // Returns: + // - []byte: The cached value if found and not expired + // - error: nil on success, ErrKeyNotFound if key doesn't exist, + // ErrKeyExpired if key exists but has expired, otherwise an error + // + // Example: + // // Atomically get and remove a value + // value, err := cache.GetAndDelete(ctx, "queue:item:123") GetAndDelete(ctx context.Context, key string) ([]byte, error) // Delete removes the item associated with the given key from the cache. - // If the key does not exist, it performs no action and returns nil. - // The operation is safe for concurrent use. + // + // If the key does not exist, this operation performs no action and returns nil. + // The operation is safe for concurrent use by multiple goroutines. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // - key: The key to remove from the cache + // + // Returns: + // - error: nil on success, otherwise an error describing the failure + // + // Example: + // // Remove a specific key + // err := cache.Delete(ctx, "user:123") Delete(ctx context.Context, key string) error // Cleanup removes all expired items from the cache. - // The operation is safe for concurrent use. + // + // This operation scans the entire cache and removes any items that have expired. + // The operation is safe for concurrent use by multiple goroutines. + // Note that some cache implementations (like Redis) handle expiration automatically + // and may not require explicit cleanup. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // + // Returns: + // - error: nil on success, otherwise an error describing the failure + // + // Example: + // // Periodically clean up expired items + // err := cache.Cleanup(ctx) Cleanup(ctx context.Context) error - // Drain returns a map of all the non-expired items in the cache. + // Drain returns a map of all non-expired items in the cache and clears the cache. + // // The returned map is a snapshot of the cache at the time of the call. - // The cache is cleared after the call. - // The operation is safe for concurrent use. + // After this operation, the cache will be empty. This is useful for cache migration, + // backup, or when shutting down an application. + // The operation is safe for concurrent use by multiple goroutines. + // + // Parameters: + // - ctx: Context for cancellation and timeouts + // + // Returns: + // - map[string][]byte: A map containing all non-expired key-value pairs + // - error: nil on success, otherwise an error describing the failure + // + // Example: + // // Get all items and clear the cache + // items, err := cache.Drain(ctx) + // for key, value := range items { + // log.Printf("Drained: %s = %s", key, string(value)) + // } Drain(ctx context.Context) (map[string][]byte, error) - // Close closes the cache. - // The operation is safe for concurrent use. + // Close releases any resources held by the cache. + // + // This should be called when the cache is no longer needed. For some implementations + // (like Redis), this may close network connections. The operation is safe for + // concurrent use by multiple goroutines. + // + // Returns: + // - error: nil on success, otherwise an error describing the failure + // + // Example: + // // Properly close the cache when done + // defer cache.Close() Close() error } diff --git a/pkg/cache/errors.go b/pkg/cache/errors.go index 0ca10cf5..dbb6d863 100644 --- a/pkg/cache/errors.go +++ b/pkg/cache/errors.go @@ -6,9 +6,49 @@ var ( // ErrInvalidConfig indicates an invalid configuration. ErrInvalidConfig = errors.New("invalid config") // ErrKeyNotFound indicates no value exists for the given key. + // + // This error is returned by Get operations when the requested key has never been + // set in the cache or has been explicitly deleted. It is also returned by + // GetAndDelete when the key does not exist. + // + // Example: + // _, err := cache.Get(ctx, "nonexistent-key") + // if errors.Is(err, cache.ErrKeyNotFound) { + // // Handle missing key + // } ErrKeyNotFound = errors.New("key not found") + // ErrKeyExpired indicates a value exists but has expired. + // + // This error is returned by Get operations when the requested key exists in the + // cache but its time-to-live (TTL) has elapsed. Expired items may still exist + // in the cache until they are explicitly removed by a Cleanup operation or + // automatically by the cache implementation. + // + // Example: + // // Set a value with 1 second TTL + // cache.Set(ctx, "temp-key", []byte("data"), cache.WithTTL(time.Second)) + // time.Sleep(2 * time.Second) + // _, err := cache.Get(ctx, "temp-key") + // if errors.Is(err, cache.ErrKeyExpired) { + // // Handle expired key + // } ErrKeyExpired = errors.New("key expired") + // ErrKeyExists indicates a conflicting set when the key already exists. + // + // This error is returned by SetOrFail operations when attempting to set a value + // for a key that already exists in the cache and has not expired. This is useful + // for implementing atomic "create if not exists" operations and preventing + // race conditions in concurrent scenarios. + // + // Example: + // // Try to set a value only if key doesn't exist + // err := cache.SetOrFail(ctx, "lock-key", []byte("locked")) + // if errors.Is(err, cache.ErrKeyExists) { + // // Key already exists, handle conflict + // } ErrKeyExists = errors.New("key already exists") + + ErrFailedToCreateZeroValue = errors.New("failed to create zero item") ) diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index 5e2c51d6..a94f0b9b 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -6,6 +6,20 @@ import ( "time" ) +// MemoryCache implements an in-memory cache with TTL support. +// +// This implementation stores all data in a Go map protected by a read-write mutex, +// making it safe for concurrent access by multiple goroutines. Items are automatically +// checked for expiration on access, but expired items remain in memory until +// explicitly removed by a Cleanup operation or overwritten. +// +// The memory cache is suitable for: +// - Single-process applications +// - Caching small to medium amounts of data +// - Scenarios where low latency is critical +// - Temporary caching that doesn't need persistence +// +// For distributed caching or persistence, consider using the Redis implementation instead. type MemoryCache struct { items map[string]*memoryItem ttl time.Duration @@ -13,6 +27,25 @@ type MemoryCache struct { mux sync.RWMutex } +// NewMemory creates a new in-memory cache with the specified default TTL. +// +// The TTL parameter sets the default time-to-live for items stored in the cache. +// A TTL of zero means items do not expire by default, but individual items +// can still have their own TTL set via options. +// +// Parameters: +// - ttl: The default TTL for cache items. Zero means no expiration. +// +// Returns: +// - *MemoryCache: A new in-memory cache instance +// +// Example: +// +// // Create a cache with 1 hour default TTL +// cache := cache.NewMemory(time.Hour) +// +// // Create a cache with no expiration +// cache := cache.NewMemory(0) func NewMemory(ttl time.Duration) *MemoryCache { return &MemoryCache{ items: make(map[string]*memoryItem), @@ -22,12 +55,14 @@ func NewMemory(ttl time.Duration) *MemoryCache { } } +// memoryItem represents a single item in the memory cache. type memoryItem struct { value []byte validUntil time.Time } -func newItem(value []byte, opts options) *memoryItem { +// newMemoryItem creates a new memory item with the specified value and options. +func newMemoryItem(value []byte, opts options) *memoryItem { item := &memoryItem{ value: value, validUntil: opts.validUntil, @@ -36,6 +71,17 @@ func newItem(value []byte, opts options) *memoryItem { return item } +// isExpired checks if the item has expired at the given time. +// +// An item is considered expired if: +// - The item is nil (safety check) +// - The item has a non-zero validUntil time and the current time is after validUntil +// +// Parameters: +// - now: The time to check expiration against +// +// Returns: +// - bool: True if the item has expired, false otherwise func (i *memoryItem) isExpired(now time.Time) bool { if i == nil { return true @@ -44,14 +90,45 @@ func (i *memoryItem) isExpired(now time.Time) bool { return !i.validUntil.IsZero() && now.After(i.validUntil) } -// Cleanup implements Cache. +// Cleanup removes all expired items from the memory cache. +// +// This method scans through all items in the cache and removes any that have expired. +// The operation is performed atomically with a write lock to ensure consistency. +// Note that this is a manual cleanup operation - expired items are also checked +// during normal Get operations, but this method explicitly removes them from memory. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// +// Returns: +// - error: Always nil for memory cache +// +// Example: +// +// // Periodically clean up expired items +// err := cache.Cleanup(ctx) func (m *MemoryCache) Cleanup(_ context.Context) error { m.cleanup(func() {}) return nil } -// Delete implements Cache. +// Delete removes the item associated with the given key from the memory cache. +// +// If the key does not exist, this operation performs no action and returns nil. +// The operation is safe for concurrent use by multiple goroutines. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// - key: The key to remove from the cache +// +// Returns: +// - error: Always nil for memory cache +// +// Example: +// +// // Remove a specific key +// err := cache.Delete(ctx, "user:123") func (m *MemoryCache) Delete(_ context.Context, key string) error { m.mux.Lock() delete(m.items, key) @@ -60,7 +137,27 @@ func (m *MemoryCache) Delete(_ context.Context, key string) error { return nil } -// Drain implements Cache. +// Drain returns a map of all non-expired items in the memory cache and clears the cache. +// +// The returned map is a snapshot of the cache at the time of the call. +// After this operation, the cache will be empty. This is useful for cache migration, +// backup, or when shutting down an application. The operation is performed atomically +// with a write lock to ensure consistency. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// +// Returns: +// - map[string][]byte: A map containing all non-expired key-value pairs +// - error: Always nil for memory cache +// +// Example: +// +// // Get all items and clear the cache +// items, err := cache.Drain(ctx) +// for key, value := range items { +// log.Printf("Drained: %s = %s", key, string(value)) +// } func (m *MemoryCache) Drain(_ context.Context) (map[string][]byte, error) { var cpy map[string]*memoryItem @@ -77,7 +174,36 @@ func (m *MemoryCache) Drain(_ context.Context) (map[string][]byte, error) { return items, nil } -// Get implements Cache. +// Get retrieves the value for the given key from the memory cache. +// +// The behavior depends on the key's existence and expiration state: +// - If the key exists and has not expired, returns the value and nil error +// - If the key does not exist, returns nil and ErrKeyNotFound +// - If the key exists but has expired, returns nil and ErrKeyExpired +// +// GetOptions can be used to modify the behavior, such as updating TTL, +// deleting the key after retrieval, or setting a new expiration time. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// - key: The key to retrieve +// - opts: Optional operations to perform during retrieval (e.g., AndDelete, AndSetTTL) +// +// Returns: +// - []byte: The cached value if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Simple get +// value, err := cache.Get(ctx, "user:123") +// +// // Get and extend TTL by 30 minutes +// value, err := cache.Get(ctx, "session:abc", cache.AndUpdateTTL(30*time.Minute)) +// +// // Get and delete atomically +// value, err := cache.Get(ctx, "temp:xyz", cache.AndDelete()) func (m *MemoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]byte, error) { return m.getValue(func() (*memoryItem, bool) { if len(opts) == 0 { @@ -92,34 +218,84 @@ func (m *MemoryCache) Get(_ context.Context, key string, opts ...GetOption) ([]b o.apply(opts...) m.mux.Lock() + defer m.mux.Unlock() item, ok := m.items[key] - if ok && o.delete { + if !ok || item.isExpired(time.Now()) { + return item, ok + } + + if o.delete { delete(m.items, key) - } else if ok && !item.isExpired(time.Now()) { - switch { - case o.validUntil != nil: - item.validUntil = *o.validUntil - case o.setTTL != nil: - item.validUntil = time.Now().Add(*o.setTTL) - case o.updateTTL != nil: + return item, ok + } + + switch { + case o.validUntil != nil: + item.validUntil = *o.validUntil + case o.setTTL != nil: + item.validUntil = time.Now().Add(*o.setTTL) + case o.updateTTL != nil: + if item.validUntil.IsZero() { + item.validUntil = time.Now().Add(*o.updateTTL) + } else { item.validUntil = item.validUntil.Add(*o.updateTTL) - case o.defaultTTL: + } + case o.defaultTTL: + if m.ttl > 0 { item.validUntil = time.Now().Add(m.ttl) + } else { + item.validUntil = time.Time{} } } - m.mux.Unlock() return item, ok }) } -// GetAndDelete implements Cache. +// GetAndDelete retrieves the value for the given key and atomically deletes it from the memory cache. +// +// This is equivalent to calling Get with the AndDelete option, but provides a more +// convenient API for the common pattern of reading and removing a value in one operation. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to retrieve and delete +// +// Returns: +// - []byte: The cached value if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Atomically get and remove a value +// value, err := cache.GetAndDelete(ctx, "queue:item:123") func (m *MemoryCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { return m.Get(ctx, key, AndDelete()) } -// Set implements Cache. +// Set stores the value for the given key in the memory cache, overwriting any existing value. +// +// The value will be stored with the default TTL configured for the cache, +// unless overridden by options. If the key already exists, its value and TTL will be updated. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// - key: The key to store the value under +// - value: The value to store as a byte slice +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: Always nil for memory cache +// +// Example: +// +// // Set with default TTL +// err := cache.Set(ctx, "user:123", []byte("user data")) +// +// // Set with custom TTL +// err := cache.Set(ctx, "session:abc", []byte("session data"), cache.WithTTL(30*time.Minute)) func (m *MemoryCache) Set(_ context.Context, key string, value []byte, opts ...Option) error { m.mux.Lock() m.items[key] = m.newItem(value, opts...) @@ -128,7 +304,29 @@ func (m *MemoryCache) Set(_ context.Context, key string, value []byte, opts ...O return nil } -// SetOrFail implements Cache. +// SetOrFail stores the value for the given key only if the key does not already exist. +// +// This is an atomic operation that prevents race conditions when multiple goroutines +// might try to set the same key simultaneously. If the key exists but has expired, +// it will be overwritten. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// - key: The key to store the value under +// - value: The value to store as a byte slice +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: nil on success, ErrKeyExists if the key already exists and is not expired, +// otherwise an error +// +// Example: +// +// // Try to set a value only if key doesn't exist +// err := cache.SetOrFail(ctx, "lock:resource", []byte("locked")) +// if errors.Is(err, cache.ErrKeyExists) { +// // Key already exists, handle conflict +// } func (m *MemoryCache) SetOrFail(_ context.Context, key string, value []byte, opts ...Option) error { m.mux.Lock() defer m.mux.Unlock() @@ -143,6 +341,17 @@ func (m *MemoryCache) SetOrFail(_ context.Context, key string, value []byte, opt return nil } +// newItem creates a new memory item with the specified value and options. +// +// This method applies the cache's default TTL if no explicit expiration is set +// in the options, and then creates a new memoryItem with the combined configuration. +// +// Parameters: +// - value: The value to store in the item +// - opts: Optional configuration for the item (e.g., custom TTL) +// +// Returns: +// - *memoryItem: A new memory item with the specified configuration func (m *MemoryCache) newItem(value []byte, opts ...Option) *memoryItem { o := options{ validUntil: time.Time{}, @@ -152,9 +361,21 @@ func (m *MemoryCache) newItem(value []byte, opts ...Option) *memoryItem { } o.apply(opts...) - return newItem(value, o) + return newMemoryItem(value, o) } +// getItem retrieves a memory item using the provided getter function and checks for expiration. +// +// This is a helper method that wraps the getter function and adds expiration checking. +// It returns ErrKeyNotFound if the item doesn't exist, or ErrKeyExpired if it exists but has expired. +// +// Parameters: +// - getter: A function that returns a memory item and a boolean indicating if it was found +// +// Returns: +// - *memoryItem: The memory item if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired func (m *MemoryCache) getItem(getter func() (*memoryItem, bool)) (*memoryItem, error) { item, ok := getter() @@ -169,6 +390,18 @@ func (m *MemoryCache) getItem(getter func() (*memoryItem, bool)) (*memoryItem, e return item, nil } +// getValue retrieves the value of a memory item using the provided getter function. +// +// This is a helper method that uses getItem to get the memory item and then +// extracts the byte slice value from it. +// +// Parameters: +// - getter: A function that returns a memory item and a boolean indicating if it was found +// +// Returns: +// - []byte: The cached value if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired func (m *MemoryCache) getValue(getter func() (*memoryItem, bool)) ([]byte, error) { item, err := m.getItem(getter) if err != nil { @@ -178,6 +411,14 @@ func (m *MemoryCache) getValue(getter func() (*memoryItem, bool)) ([]byte, error return item.value, nil } +// cleanup removes all expired items from the memory cache and executes the provided callback. +// +// This method scans through all items in the cache and removes any that have expired. +// The callback function is executed while holding the write lock, allowing for +// atomic operations during cleanup. +// +// Parameters: +// - cb: A callback function to execute during cleanup (while holding the write lock) func (m *MemoryCache) cleanup(cb func()) { t := time.Now() @@ -192,8 +433,21 @@ func (m *MemoryCache) cleanup(cb func()) { m.mux.Unlock() } +// Close releases any resources held by the memory cache. +// +// For the memory cache implementation, this is a no-op since there are no +// external resources to clean up. The method is included for interface compatibility. +// +// Returns: +// - error: Always nil for memory cache +// +// Example: +// +// // Properly close the cache when done +// defer cache.Close() func (m *MemoryCache) Close() error { return nil } +// Compile-time check to ensure memoryCache implements the Cache interface. var _ Cache = (*MemoryCache)(nil) diff --git a/pkg/cache/options.go b/pkg/cache/options.go index 20417f50..4f569051 100644 --- a/pkg/cache/options.go +++ b/pkg/cache/options.go @@ -3,20 +3,39 @@ package cache import "time" // Option configures per-item cache behavior (e.g., expiry). +// +// Options are used with Set and SetOrFail operations to customize how individual +// cache items are stored, including their expiration time and TTL. type Option func(*options) +// options holds the configuration for a cache item. type options struct { validUntil time.Time } +// apply applies the given options to this options struct. func (o *options) apply(opts ...Option) { for _, opt := range opts { opt(o) } } -// WithTTL is an Option that sets the TTL (time to live) for an item, i.e. the -// item will expire after the given duration from the time of insertion. +// WithTTL sets the TTL (time to live) for an item. +// +// The item will expire after the given duration from the time of insertion. +// A TTL of zero means the item will not expire. +// A negative TTL means the item expires immediately. +// +// Parameters: +// - ttl: The duration after which the item will expire +// +// Returns: +// - Option: An option that sets the TTL when passed to Set or SetOrFail +// +// Example: +// +// // Set a value that expires in 30 minutes +// err := cache.Set(ctx, "session:abc", []byte("data"), cache.WithTTL(30*time.Minute)) func WithTTL(ttl time.Duration) Option { return func(o *options) { switch { @@ -30,14 +49,31 @@ func WithTTL(ttl time.Duration) Option { } } -// WithValidUntil is an Option that sets the valid until time for an item, i.e. -// the item will expire at the given time. +// WithValidUntil sets the exact expiration time for an item. +// +// The item will expire at the given time, regardless of when it was inserted. +// This is useful when you need precise control over when an item expires, +// such as at midnight or the end of a billing period. +// +// Parameters: +// - validUntil: The exact time when the item should expire +// +// Returns: +// - Option: An option that sets the expiration time when passed to Set or SetOrFail +// +// Example: +// +// // Set a value that expires at midnight +// midnight := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Local) +// midnight = midnight.Add(24 * time.Hour) // Next midnight +// err := cache.Set(ctx, "daily:report", []byte("data"), cache.WithValidUntil(midnight)) func WithValidUntil(validUntil time.Time) Option { return func(o *options) { o.validUntil = validUntil } } +// getOptions holds the configuration for Get operations. type getOptions struct { validUntil *time.Time setTTL *time.Duration @@ -46,14 +82,21 @@ type getOptions struct { delete bool } +// GetOption configures the behavior of Get operations. +// +// GetOptions allow you to perform additional operations during a Get call, +// such as updating the item's TTL, setting a new expiration time, or +// deleting the item after retrieval. type GetOption func(*getOptions) +// apply applies the given GetOptions to this getOptions struct. func (o *getOptions) apply(opts ...GetOption) { for _, opt := range opts { opt(o) } } +// isEmpty returns true if no GetOptions are set. func (o *getOptions) isEmpty() bool { return o.validUntil == nil && o.setTTL == nil && @@ -62,30 +105,103 @@ func (o *getOptions) isEmpty() bool { !o.delete } +// AndSetTTL sets a new TTL for the item during a Get operation. +// +// This option is useful for extending the lifetime of an item when it's accessed, +// implementing a "touch" behavior where frequently accessed items remain in cache longer. +// +// Parameters: +// - ttl: The new TTL duration to set for the item +// +// Returns: +// - GetOption: An option that sets the TTL when passed to Get +// +// Example: +// +// // Get a value and extend its TTL to 30 minutes from now +// value, err := cache.Get(ctx, "session:abc", cache.AndSetTTL(30*time.Minute)) func AndSetTTL(ttl time.Duration) GetOption { return func(o *getOptions) { o.setTTL = &ttl } } +// AndUpdateTTL extends the current TTL of an item by the given duration. +// +// Unlike AndSetTTL which sets an absolute TTL from the current time, +// AndUpdateTTL adds the specified duration to the item's existing TTL. +// This is useful for incrementally extending the lifetime of an item. +// +// Parameters: +// - ttl: The duration to add to the item's current TTL +// +// Returns: +// - GetOption: An option that extends the TTL when passed to Get +// +// Example: +// +// // Get a value and extend its TTL by 15 minutes +// value, err := cache.Get(ctx, "session:abc", cache.AndUpdateTTL(15*time.Minute)) func AndUpdateTTL(ttl time.Duration) GetOption { return func(o *getOptions) { o.updateTTL = &ttl } } +// AndSetValidUntil sets a new exact expiration time for the item during a Get operation. +// +// This option allows you to set a precise expiration time for an item when it's accessed, +// which can be useful for implementing time-based access patterns. +// +// Parameters: +// - validUntil: The new exact expiration time for the item +// +// Returns: +// - GetOption: An option that sets the expiration time when passed to Get +// +// Example: +// +// // Get a value and set it to expire at midnight +// midnight := time.Date(time.Now().Year(), time.Now().Month(), time.Now().Day(), 0, 0, 0, 0, time.Local) +// midnight = midnight.Add(24 * time.Hour) // Next midnight +// value, err := cache.Get(ctx, "daily:report", cache.AndSetValidUntil(midnight)) func AndSetValidUntil(validUntil time.Time) GetOption { return func(o *getOptions) { o.validUntil = &validUntil } } +// AndDefaultTTL resets the item's TTL to the cache's default TTL during a Get operation. +// +// This option is useful when you want to restore an item to the default expiration +// policy of the cache, regardless of its current TTL. +// +// Returns: +// - GetOption: An option that resets the TTL to the cache default when passed to Get +// +// Example: +// +// // Get a value and reset its TTL to the cache default +// value, err := cache.Get(ctx, "session:abc", cache.AndDefaultTTL()) func AndDefaultTTL() GetOption { return func(o *getOptions) { o.defaultTTL = true } } +// AndDelete deletes the item from the cache during a Get operation. +// +// This option provides an atomic "get and delete" operation, which is useful +// for implementing queue-like behavior or ensuring that an item is only +// processed once. This is equivalent to calling GetAndDelete. +// +// Returns: +// - GetOption: An option that deletes the item when passed to Get +// +// Example: +// +// // Get a value and remove it from the cache +// value, err := cache.Get(ctx, "queue:item:123", cache.AndDelete()) func AndDelete() GetOption { return func(o *getOptions) { o.delete = true diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index f441d878..eefdd4d8 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -38,14 +38,14 @@ if deleteFlag then end if ttlTs > 0 then - redis.call('HExpireAt', KEYS[1], ttlTs, field) + redis.call('HExpireAt', KEYS[1], ttlTs, 'FIELDS', '1', field) elseif ttlDelta > 0 then - local ttl = redis.call('HTTL', KEYS[1], field) + local ttl = redis.call('HTTL', KEYS[1], 'FIELDS', '1', field) if ttl < 0 then ttl = 0 end local newTtl = ttl + ttlDelta - redis.call('HExpire', KEYS[1], newTtl, field) + redis.call('HExpire', KEYS[1], newTtl, 'FIELDS', '1', field) end return value @@ -53,6 +53,10 @@ return value ) // RedisConfig configures the Redis cache backend. +// +// This struct provides configuration options for creating a Redis-based cache +// implementation. You can either provide an existing Redis client or let the +// cache create one from a URL. type RedisConfig struct { // Client is the Redis client to use. // If nil, a client is created from the URL. @@ -63,12 +67,28 @@ type RedisConfig struct { URL string // Prefix is the prefix to use for all keys in the Redis cache. + // This helps avoid key collisions when multiple applications use the same Redis instance. Prefix string // TTL is the time-to-live for all cache entries. + // This is the default TTL used when no explicit TTL is provided. TTL time.Duration } +// RedisCache implements the Cache interface using Redis as the backend. +// +// This implementation stores all data in a Redis hash, with each cache item +// being a field in the hash. It uses Redis's built-in TTL functionality for +// expiration and Lua scripts for atomic operations. +// +// The Redis cache is suitable for: +// - Distributed applications where multiple processes need access to the same cache +// - Caching large amounts of data that don't fit in memory +// - Scenarios where cache persistence is required +// - High-availability caching with Redis clustering +// +// For single-process applications or when low latency is critical, consider using +// the in-memory implementation instead. type RedisCache struct { client *redis.Client ownedClient bool @@ -78,6 +98,41 @@ type RedisCache struct { ttl time.Duration } +// NewRedis creates a new Redis cache with the specified configuration. +// +// This function validates the configuration and creates a Redis client if one +// is not provided. The key used for storing cache items in Redis is constructed +// from the prefix and a constant suffix. +// +// Parameters: +// - config: Configuration for the Redis cache +// +// Returns: +// - *redisCache: A new Redis cache instance +// - error: An error if the configuration is invalid or the Redis client cannot be created +// +// Example: +// +// // Create a Redis cache with a new client +// config := cache.RedisConfig{ +// URL: "redis://localhost:6379", +// Prefix: "myapp:", +// TTL: time.Hour, +// } +// redisCache, err := cache.NewRedis(config) +// if err != nil { +// log.Fatal(err) +// } +// defer redisCache.Close() +// +// // Create a Redis cache with an existing client +// client := redis.NewClient(&redis.Options{Addr: "localhost:6379"}) +// config := cache.RedisConfig{ +// Client: client, +// Prefix: "myapp:", +// TTL: time.Hour, +// } +// redisCache, err := cache.NewRedis(config) func NewRedis(config RedisConfig) (*RedisCache, error) { if config.Prefix != "" && !strings.HasSuffix(config.Prefix, ":") { config.Prefix += ":" @@ -107,12 +162,42 @@ func NewRedis(config RedisConfig) (*RedisCache, error) { }, nil } -// Cleanup implements Cache. +// Cleanup removes all expired items from the Redis cache. +// +// For Redis cache implementation, this is a no-op because Redis handles +// expiration automatically. Expired items are automatically removed by Redis +// based on their TTL settings. +// +// Parameters: +// - ctx: Context for cancellation and timeouts (currently unused but kept for interface compatibility) +// +// Returns: +// - error: Always nil for Redis cache +// +// Example: +// +// // No-op for Redis cache, but included for interface compatibility +// err := redisCache.Cleanup(ctx) func (r *RedisCache) Cleanup(_ context.Context) error { return nil } -// Delete implements Cache. +// Delete removes the item associated with the given key from the Redis cache. +// +// If the key does not exist, this operation performs no action and returns nil. +// The operation uses Redis's HDEL command to remove the field from the hash. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to remove from the cache +// +// Returns: +// - error: nil on success, or an error if the Redis operation fails +// +// Example: +// +// // Remove a specific key +// err := redisCache.Delete(ctx, "user:123") func (r *RedisCache) Delete(ctx context.Context, key string) error { if err := r.client.HDel(ctx, r.key, key).Err(); err != nil { return fmt.Errorf("failed to delete cache item: %w", err) @@ -121,7 +206,26 @@ func (r *RedisCache) Delete(ctx context.Context, key string) error { return nil } -// Drain implements Cache. +// Drain returns a map of all non-expired items in the Redis cache and clears the cache. +// +// This operation uses a Lua script to atomically get all fields from the hash +// and then delete the entire hash. This ensures that the operation is atomic +// and no items are lost between the get and delete operations. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// +// Returns: +// - map[string][]byte: A map containing all non-expired key-value pairs +// - error: nil on success, or an error if the Redis operation fails +// +// Example: +// +// // Get all items and clear the cache +// items, err := redisCache.Drain(ctx) +// for key, value := range items { +// log.Printf("Drained: %s = %s", key, string(value)) +// } func (r *RedisCache) Drain(ctx context.Context) (map[string][]byte, error) { res, err := r.client.Eval(ctx, hgetallAndDeleteScript, []string{r.key}).Result() if err != nil { @@ -144,7 +248,37 @@ func (r *RedisCache) Drain(ctx context.Context) (map[string][]byte, error) { return out, nil } -// Get implements Cache. +// Get retrieves the value for the given key from the Redis cache. +// +// The behavior depends on the key's existence and expiration state: +// - If the key exists and has not expired, returns the value and nil error +// - If the key does not exist, returns nil and ErrKeyNotFound +// - If the key exists but has expired, returns nil and ErrKeyExpired +// +// GetOptions can be used to modify the behavior, such as updating TTL, +// deleting the key after retrieval, or setting a new expiration time. +// When GetOptions are provided, a Lua script is used for atomic operations. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to retrieve +// - opts: Optional operations to perform during retrieval (e.g., AndDelete, AndSetTTL) +// +// Returns: +// - []byte: The cached value if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Simple get +// value, err := redisCache.Get(ctx, "user:123") +// +// // Get and extend TTL by 30 minutes +// value, err := redisCache.Get(ctx, "session:abc", cache.AndUpdateTTL(30*time.Minute)) +// +// // Get and delete atomically +// value, err := redisCache.Get(ctx, "temp:xyz", cache.AndDelete()) func (r *RedisCache) Get(ctx context.Context, key string, opts ...GetOption) ([]byte, error) { o := new(getOptions) o.apply(opts...) @@ -195,12 +329,50 @@ func (r *RedisCache) Get(ctx context.Context, key string, opts ...GetOption) ([] return nil, ErrKeyNotFound } -// GetAndDelete implements Cache. +// GetAndDelete retrieves the value for the given key and atomically deletes it from the Redis cache. +// +// This is equivalent to calling Get with the AndDelete option, but provides a more +// convenient API for the common pattern of reading and removing a value in one operation. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to retrieve and delete +// +// Returns: +// - []byte: The cached value if found and not expired +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Atomically get and remove a value +// value, err := redisCache.GetAndDelete(ctx, "queue:item:123") func (r *RedisCache) GetAndDelete(ctx context.Context, key string) ([]byte, error) { return r.Get(ctx, key, AndDelete()) } -// Set implements Cache. +// Set stores the value for the given key in the Redis cache, overwriting any existing value. +// +// The value will be stored with the default TTL configured for the cache, +// unless overridden by options. If the key already exists, its value and TTL will be updated. +// This method uses Redis pipelining to ensure that the set and TTL operations are atomic. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to store the value under +// - value: The value to store as a byte slice +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: nil on success, or an error if the Redis operation fails +// +// Example: +// +// // Set with default TTL +// err := redisCache.Set(ctx, "user:123", []byte("user data")) +// +// // Set with custom TTL +// err := redisCache.Set(ctx, "session:abc", []byte("session data"), cache.WithTTL(30*time.Minute)) func (r *RedisCache) Set(ctx context.Context, key string, value []byte, opts ...Option) error { options := new(options) if r.ttl > 0 { @@ -222,7 +394,29 @@ func (r *RedisCache) Set(ctx context.Context, key string, value []byte, opts ... return nil } -// SetOrFail implements Cache. +// SetOrFail stores the value for the given key only if the key does not already exist. +// +// This is an atomic operation that prevents race conditions when multiple goroutines +// might try to set the same key simultaneously. It uses Redis's HSetNX command +// which only sets the field if it doesn't already exist. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to store the value under +// - value: The value to store as a byte slice +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: nil on success, ErrKeyExists if the key already exists and is not expired, +// otherwise an error +// +// Example: +// +// // Try to set a value only if key doesn't exist +// err := redisCache.SetOrFail(ctx, "lock:resource", []byte("locked")) +// if errors.Is(err, cache.ErrKeyExists) { +// // Key already exists, handle conflict +// } func (r *RedisCache) SetOrFail(ctx context.Context, key string, value []byte, opts ...Option) error { val, err := r.client.HSetNX(ctx, r.key, key, value).Result() if err != nil { @@ -248,6 +442,19 @@ func (r *RedisCache) SetOrFail(ctx context.Context, key string, value []byte, op return nil } +// Close releases any resources held by the Redis cache. +// +// If the Redis cache was created with a client (rather than using an existing client), +// this method will close the Redis client connection. If an existing client was provided, +// this method is a no-op to avoid closing a client that might be used elsewhere. +// +// Returns: +// - error: nil on success, or an error if closing the Redis client fails +// +// Example: +// +// // Properly close the cache when done +// defer redisCache.Close() func (r *RedisCache) Close() error { if r.ownedClient { if err := r.client.Close(); err != nil { @@ -258,4 +465,5 @@ func (r *RedisCache) Close() error { return nil } +// Compile-time check to ensure redisCache implements the Cache interface. var _ Cache = (*RedisCache)(nil) diff --git a/pkg/cache/typed.go b/pkg/cache/typed.go new file mode 100644 index 00000000..0c86601c --- /dev/null +++ b/pkg/cache/typed.go @@ -0,0 +1,407 @@ +package cache + +import ( + "context" + "fmt" + "reflect" +) + +// Item defines the interface that types must implement to be used with Typed cache. +// +// Types that implement this interface can be automatically serialized and +// deserialized when stored in or retrieved from the cache. This allows +// for type-safe caching of complex data structures. +// +// Example implementation: +// +// type User struct { +// ID string +// Name string +// } +// +// func (u *User) Marshal() ([]byte, error) { +// return json.Marshal(u) +// } +// +// func (u *User) Unmarshal(data []byte) error { +// return json.Unmarshal(data, u) +// } +type Item interface { + // Marshal converts the item to a byte slice for storage in the cache. + // + // This method is called when the item is stored in the cache. + // Common implementations include JSON, protobuf, or other serialization formats. + // + // Returns: + // - []byte: The serialized representation of the item + // - error: An error if serialization fails + Marshal() ([]byte, error) + + // Unmarshal populates the item from a byte slice retrieved from the cache. + // + // This method is called when the item is retrieved from the cache. + // It should restore the item's state from the serialized data. + // + // Parameters: + // - data: The serialized representation of the item + // + // Returns: + // - error: An error if deserialization fails + Unmarshal(data []byte) error +} + +// Typed provides a type-safe wrapper around a Cache implementation. +// +// This generic wrapper allows caching of specific types that implement the Item +// interface, providing compile-time type safety and eliminating the need for +// manual type assertions when working with cached values. +// +// The Typed wrapper handles serialization and deserialization automatically, +// making it easy to cache complex data structures while maintaining type safety. +// +// Example usage: +// +// // Define a type that implements Item +// type User struct { +// ID string +// Name string +// } +// +// func (u *User) Marshal() ([]byte, error) { +// return json.Marshal(u) +// } +// +// func (u *User) Unmarshal(data []byte) error { +// return json.Unmarshal(data, u) +// } +// +// // Create a typed cache +// storage := cache.NewMemory(time.Hour) +// userCache := cache.NewTyped[*User](storage) +// +// // Set a typed value +// user := &User{ID: "123", Name: "Alice"} +// err := userCache.Set(ctx, "user:123", user) +// +// // Get a typed value +// retrieved, err := userCache.Get(ctx, "user:123") +// // retrieved is of type *User, no type assertion needed +type Typed[T Item] struct { + storage Cache +} + +// NewTyped creates a new typed cache wrapper around the provided storage. +// +// The typed cache uses the underlying storage for all operations but adds +// automatic serialization and deserialization of values that implement the Item interface. +// +// Parameters: +// - storage: The underlying cache implementation to wrap +// +// Returns: +// - *Typed[T]: A new typed cache wrapper +// +// Example: +// +// // Create a typed cache with in-memory storage +// storage := cache.NewMemory(time.Hour) +// userCache := cache.NewTyped[*User](storage) +// +// // Create a typed cache with Redis storage +// config := cache.RedisConfig{URL: "redis://localhost:6379"} +// redisCache, err := cache.NewRedis(config) +// if err != nil { +// log.Fatal(err) +// } +// defer redisCache.Close() +// userCache := cache.NewTyped[*User](redisCache) +func NewTyped[T Item](storage Cache) *Typed[T] { + return &Typed[T]{ + storage: storage, + } +} + +// Set stores the typed value for the given key in the cache, overwriting any existing value. +// +// The value will be automatically marshaled to bytes before storage using its +// Marshal method. The value will be stored with the default TTL configured for +// the cache implementation, unless overridden by options. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to store the value under +// - value: The typed value to store +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: nil on success, or an error if marshaling fails or the cache operation fails +// +// Example: +// +// // Set with default TTL +// user := &User{ID: "123", Name: "Alice"} +// err := userCache.Set(ctx, "user:123", user) +// +// // Set with custom TTL +// err := userCache.Set(ctx, "session:abc", user, cache.WithTTL(30*time.Minute)) +func (c *Typed[T]) Set(ctx context.Context, key string, value T, opts ...Option) error { + data, err := value.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal value: %w", err) + } + + if setErr := c.storage.Set(ctx, key, data, opts...); setErr != nil { + return fmt.Errorf("failed to set value in cache: %w", setErr) + } + + return nil +} + +// SetOrFail stores the typed value for the given key only if the key does not already exist. +// +// This is an atomic operation that prevents race conditions when multiple goroutines +// might try to set the same key simultaneously. The value will be automatically +// marshaled to bytes before storage using its Marshal method. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to store the value under +// - value: The typed value to store +// - opts: Optional configuration for this specific item (e.g., custom TTL) +// +// Returns: +// - error: nil on success, ErrKeyExists if the key already exists and is not expired, +// otherwise an error if marshaling fails or the cache operation fails +// +// Example: +// +// // Try to set a value only if key doesn't exist +// user := &User{ID: "123", Name: "Alice"} +// err := userCache.SetOrFail(ctx, "user:123", user) +// if errors.Is(err, cache.ErrKeyExists) { +// // Key already exists, handle conflict +// } +func (c *Typed[T]) SetOrFail(ctx context.Context, key string, value T, opts ...Option) error { + data, err := value.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal value: %w", err) + } + + if setErr := c.storage.SetOrFail(ctx, key, data, opts...); setErr != nil { + return fmt.Errorf("failed to set value in cache: %w", setErr) + } + + return nil +} + +// Get retrieves the typed value for the given key from the cache. +// +// The behavior depends on the key's existence and expiration state: +// - If the key exists and has not expired, returns the typed value and nil error +// - If the key does not exist, returns a zero value and ErrKeyNotFound +// - If the key exists but has expired, returns a zero value and ErrKeyExpired +// +// The retrieved bytes are automatically unmarshaled to the correct type using +// the Unmarshal method. GetOptions can be used to modify the behavior, such as +// updating TTL, deleting the key after retrieval, or setting a new expiration time. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to retrieve +// - opts: Optional operations to perform during retrieval (e.g., AndDelete, AndSetTTL) +// +// Returns: +// - T: The cached typed value if found and not expired, otherwise a zero value +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Simple get +// user, err := userCache.Get(ctx, "user:123") +// +// // Get and extend TTL by 30 minutes +// user, err := userCache.Get(ctx, "session:abc", cache.AndUpdateTTL(30*time.Minute)) +// +// // Get and delete atomically +// user, err := userCache.Get(ctx, "temp:xyz", cache.AndDelete()) +func (c *Typed[T]) Get(ctx context.Context, key string, opts ...GetOption) (T, error) { + data, err := c.storage.Get(ctx, key, opts...) + var zero T + if err != nil { + return zero, fmt.Errorf("failed to get value from cache: %w", err) + } + + value, err := newItem[T]() + if err != nil { + return zero, err + } + if unmarshalErr := value.Unmarshal(data); unmarshalErr != nil { + return zero, fmt.Errorf("failed to unmarshal value from cache: %w", unmarshalErr) + } + + return value, nil +} + +// GetAndDelete retrieves the typed value for the given key and atomically deletes it from the cache. +// +// This is equivalent to calling Get with the AndDelete option, but provides a more +// convenient API for the common pattern of reading and removing a value in one operation. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to retrieve and delete +// +// Returns: +// - T: The cached typed value if found and not expired, otherwise a zero value +// - error: nil on success, ErrKeyNotFound if key doesn't exist, +// ErrKeyExpired if key exists but has expired, otherwise an error +// +// Example: +// +// // Atomically get and remove a value +// user, err := userCache.GetAndDelete(ctx, "queue:item:123") +func (c *Typed[T]) GetAndDelete(ctx context.Context, key string) (T, error) { + return c.Get(ctx, key, AndDelete()) +} + +// Delete removes the item associated with the given key from the cache. +// +// If the key does not exist, this operation performs no action and returns nil. +// The operation is safe for concurrent use by multiple goroutines. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// - key: The key to remove from the cache +// +// Returns: +// - error: nil on success, or an error if the cache operation fails +// +// Example: +// +// // Remove a specific key +// err := userCache.Delete(ctx, "user:123") +func (c *Typed[T]) Delete(ctx context.Context, key string) error { + if err := c.storage.Delete(ctx, key); err != nil { + return fmt.Errorf("failed to delete value from cache: %w", err) + } + + return nil +} + +// Cleanup removes all expired items from the cache. +// +// This operation scans the entire cache and removes any items that have expired. +// The operation is safe for concurrent use by multiple goroutines. +// Note that some cache implementations (like Redis) handle expiration automatically +// and may not require explicit cleanup. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// +// Returns: +// - error: nil on success, or an error if the cache operation fails +// +// Example: +// +// // Periodically clean up expired items +// err := userCache.Cleanup(ctx) +func (c *Typed[T]) Cleanup(ctx context.Context) error { + if err := c.storage.Cleanup(ctx); err != nil { + return fmt.Errorf("failed to cleanup cache: %w", err) + } + + return nil +} + +// Drain returns a map of all non-expired typed items in the cache and clears the cache. +// +// The returned map is a snapshot of the cache at the time of the call. +// After this operation, the cache will be empty. This is useful for cache migration, +// backup, or when shutting down an application. The operation is safe for +// concurrent use by multiple goroutines. +// +// Parameters: +// - ctx: Context for cancellation and timeouts +// +// Returns: +// - map[string]T: A map containing all non-expired key-typed value pairs +// - error: nil on success, or an error if the cache operation fails +// +// Example: +// +// // Get all items and clear the cache +// users, err := userCache.Drain(ctx) +// for key, user := range users { +// log.Printf("Drained: %s = %+v", key, user) +// } +func (c *Typed[T]) Drain(ctx context.Context) (map[string]T, error) { + data, err := c.storage.Drain(ctx) + if err != nil { + return nil, fmt.Errorf("failed to drain cache: %w", err) + } + + items := make(map[string]T, len(data)) + for key, raw := range data { + item, newErr := newItem[T]() + if newErr != nil { + return nil, newErr + } + if unmarshalErr := item.Unmarshal(raw); unmarshalErr != nil { + return nil, fmt.Errorf("failed to unmarshal value from cache: %w", unmarshalErr) + } + + items[key] = item + } + + return items, nil +} + +// Close releases any resources held by the cache. +// +// This should be called when the cache is no longer needed. For some implementations +// (like Redis), this may close network connections. The operation is safe for +// concurrent use by multiple goroutines. +// +// Returns: +// - error: nil on success, or an error if the cache operation fails +// +// Example: +// +// // Properly close the cache when done +// defer userCache.Close() +func (c *Typed[T]) Close() error { + if err := c.storage.Close(); err != nil { + return fmt.Errorf("failed to close cache: %w", err) + } + + return nil +} + +// newItem creates a new instance of type T using reflection. +// +// This is a helper function that creates a new instance of the generic type T, +// which must be a pointer type that implements the Item interface. It uses +// reflection to instantiate the type and verifies that it implements the interface. +// +// Returns: +// - T: A new instance of type T +// - error: An error if T is not a pointer type or cannot be instantiated +// +// Note: This is an internal helper function and is not intended for direct use. +func newItem[T Item]() (T, error) { + var zero T + + t := reflect.TypeOf((*T)(nil)).Elem() + if t.Kind() != reflect.Ptr { + return zero, fmt.Errorf("%w: type %s must be a pointer", ErrFailedToCreateZeroValue, t.String()) + } + + v := reflect.New(t.Elem()) + item, ok := v.Interface().(T) + if !ok { + return zero, fmt.Errorf("%w: cannot create value of type %s", ErrFailedToCreateZeroValue, t.String()) + } + + return item, nil +}