Skip to content

Commit 69be423

Browse files
committed
[auth] separate users cache from the service
1 parent 94ff1c3 commit 69be423

File tree

6 files changed

+142
-84
lines changed

6 files changed

+142
-84
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package auth
2+
3+
import (
4+
"crypto/sha256"
5+
"encoding/hex"
6+
"fmt"
7+
"time"
8+
9+
"github.com/android-sms-gateway/server/internal/sms-gateway/models"
10+
"github.com/capcom6/go-helpers/cache"
11+
)
12+
13+
type usersCache struct {
14+
cache *cache.Cache[models.User]
15+
}
16+
17+
func newUsersCache() *usersCache {
18+
return &usersCache{
19+
cache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}),
20+
}
21+
}
22+
23+
func (c *usersCache) makeKey(username, password string) string {
24+
hash := sha256.Sum256([]byte(username + "\x00" + password))
25+
return hex.EncodeToString(hash[:])
26+
}
27+
28+
func (c *usersCache) Get(username, password string) (models.User, error) {
29+
user, err := c.cache.Get(c.makeKey(username, password))
30+
if err != nil {
31+
return models.User{}, fmt.Errorf("failed to get user from cache: %w", err)
32+
}
33+
34+
return user, nil
35+
}
36+
37+
func (c *usersCache) Set(username, password string, user models.User) error {
38+
if err := c.cache.Set(c.makeKey(username, password), user); err != nil {
39+
return fmt.Errorf("failed to cache user: %w", err)
40+
}
41+
42+
return nil
43+
}
44+
45+
func (c *usersCache) Delete(username, password string) error {
46+
if err := c.cache.Delete(c.makeKey(username, password)); err != nil {
47+
return fmt.Errorf("failed to delete user from cache: %w", err)
48+
}
49+
50+
return nil
51+
}
52+
53+
func (c *usersCache) Cleanup() {
54+
c.cache.Cleanup()
55+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package auth
2+
3+
import "errors"
4+
5+
var (
6+
ErrAuthorizationFailed = errors.New("authorization failed")
7+
)

internal/sms-gateway/modules/auth/module.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,26 @@ import (
77
"go.uber.org/zap"
88
)
99

10-
var Module = fx.Module(
11-
"auth",
12-
fx.Decorate(func(log *zap.Logger) *zap.Logger {
13-
return log.Named("auth")
14-
}),
15-
fx.Provide(New),
16-
fx.Provide(newRepository, fx.Private),
17-
fx.Invoke(func(lc fx.Lifecycle, svc *Service) {
18-
ctx, cancel := context.WithCancel(context.Background())
19-
lc.Append(fx.Hook{
20-
OnStart: func(_ context.Context) error {
21-
go svc.Run(ctx)
22-
return nil
23-
},
24-
OnStop: func(_ context.Context) error {
25-
cancel()
26-
return nil
27-
},
28-
})
29-
}),
30-
)
10+
func Module() fx.Option {
11+
return fx.Module(
12+
"auth",
13+
fx.Decorate(func(log *zap.Logger) *zap.Logger {
14+
return log.Named("auth")
15+
}),
16+
fx.Provide(New),
17+
fx.Provide(newRepository, fx.Private),
18+
fx.Invoke(func(lc fx.Lifecycle, svc *Service) {
19+
ctx, cancel := context.WithCancel(context.Background())
20+
lc.Append(fx.Hook{
21+
OnStart: func(_ context.Context) error {
22+
go svc.Run(ctx)
23+
return nil
24+
},
25+
OnStop: func(_ context.Context) error {
26+
cancel()
27+
return nil
28+
},
29+
})
30+
}),
31+
)
32+
}

internal/sms-gateway/modules/auth/repository.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,22 @@ func newRepository(db *gorm.DB) *repository {
1616
}
1717

1818
// GetByID returns a user by their ID.
19-
func (r *repository) GetByID(id string) (models.User, error) {
20-
user := models.User{}
19+
func (r *repository) GetByID(id string) (*models.User, error) {
20+
user := new(models.User)
2121

22-
return user, r.db.Where("id = ?", id).Take(&user).Error
22+
return user, r.db.Where("id = ?", id).Take(user).Error
2323
}
2424

25-
func (r *repository) GetByLogin(login string) (models.User, error) {
26-
user := models.User{}
25+
func (r *repository) GetByLogin(login string) (*models.User, error) {
26+
user := new(models.User)
2727

28-
return user, r.db.Where("id = ?", login).Take(&user).Error
28+
return user, r.db.Where("id = ?", login).Take(user).Error
2929
}
3030

3131
func (r *repository) Insert(user *models.User) error {
3232
return r.db.Create(user).Error
3333
}
3434

3535
func (r *repository) UpdatePassword(userID string, passwordHash string) error {
36-
return r.db.Model(&models.User{}).Where("id = ?", userID).Update("password_hash", passwordHash).Error
36+
return r.db.Model((*models.User)(nil)).Where("id = ?", userID).Update("password_hash", passwordHash).Error
3737
}

internal/sms-gateway/modules/auth/service.go

Lines changed: 48 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@ package auth
33
import (
44
"context"
55
"crypto/rand"
6-
"crypto/sha256"
76
"crypto/subtle"
8-
"encoding/hex"
97
"fmt"
108
"time"
119

@@ -41,7 +39,7 @@ type Service struct {
4139

4240
users *repository
4341
codesCache *cache.Cache[string]
44-
usersCache *cache.Cache[models.User]
42+
usersCache *usersCache
4543

4644
devicesSvc *devices.Service
4745
onlineSvc online.Service
@@ -52,7 +50,8 @@ type Service struct {
5250
}
5351

5452
func New(params Params) *Service {
55-
idgen, _ := nanoid.Standard(21)
53+
const idLen = 21
54+
idgen, _ := nanoid.Standard(idLen)
5655

5756
return &Service{
5857
config: params.Config,
@@ -62,24 +61,26 @@ func New(params Params) *Service {
6261
logger: params.Logger,
6362
idgen: idgen,
6463

65-
codesCache: cache.New[string](cache.Config{}),
66-
usersCache: cache.New[models.User](cache.Config{TTL: 1 * time.Hour}),
64+
codesCache: cache.New[string](cache.Config{TTL: codeTTL}),
65+
usersCache: newUsersCache(),
6766
}
6867
}
6968

70-
// GenerateUserCode generates a unique one-time user authorization code
71-
func (s *Service) GenerateUserCode(userID string) (AuthCode, error) {
69+
// GenerateUserCode generates a unique one-time user authorization code.
70+
func (s *Service) GenerateUserCode(userID string) (OneTimeCode, error) {
7271
var code string
7372
var err error
7473

75-
b := make([]byte, 3)
74+
const bytesLen = 3
75+
const maxCode = 1000000
76+
b := make([]byte, bytesLen)
7677
validUntil := time.Now().Add(codeTTL)
7778
for range 3 {
7879
if _, err = rand.Read(b); err != nil {
7980
continue
8081
}
81-
num := (int(b[0]) << 16) | (int(b[1]) << 8) | int(b[2])
82-
code = fmt.Sprintf("%06d", num%1000000)
82+
num := (int(b[0]) << 16) | (int(b[1]) << 8) | int(b[2]) //nolint:mnd //bitshift
83+
code = fmt.Sprintf("%06d", num%maxCode)
8384

8485
if err = s.codesCache.SetOrFail(code, userID, cache.WithValidUntil(validUntil)); err != nil {
8586
continue
@@ -89,36 +90,34 @@ func (s *Service) GenerateUserCode(userID string) (AuthCode, error) {
8990
}
9091

9192
if err != nil {
92-
return AuthCode{}, fmt.Errorf("can't generate code: %w", err)
93+
return OneTimeCode{}, fmt.Errorf("failed to generate code: %w", err)
9394
}
9495

95-
return AuthCode{Code: code, ValidUntil: validUntil}, nil
96+
return OneTimeCode{Code: code, ValidUntil: validUntil}, nil
9697
}
9798

98-
func (s *Service) RegisterUser(login, password string) (models.User, error) {
99-
user := models.User{
100-
ID: login,
101-
}
102-
103-
var err error
104-
if user.PasswordHash, err = crypto.MakeBCryptHash(password); err != nil {
105-
return user, fmt.Errorf("can't hash password: %w", err)
99+
func (s *Service) RegisterUser(login, password string) (*models.User, error) {
100+
passwordHash, err := crypto.MakeBCryptHash(password)
101+
if err != nil {
102+
return nil, fmt.Errorf("failed to hash password: %w", err)
106103
}
107104

108-
if err = s.users.Insert(&user); err != nil {
109-
return user, fmt.Errorf("can't create user")
105+
user := models.NewUser(login, passwordHash)
106+
if err = s.users.Insert(user); err != nil {
107+
return user, fmt.Errorf("failed to create user: %w", err)
110108
}
111109

112110
return user, nil
113111
}
114112

115-
func (s *Service) RegisterDevice(user models.User, name, pushToken *string) (models.Device, error) {
116-
device := models.Device{
117-
Name: name,
118-
PushToken: pushToken,
113+
func (s *Service) RegisterDevice(user *models.User, name, pushToken *string) (*models.Device, error) {
114+
device := models.NewDevice(name, pushToken)
115+
116+
if err := s.devicesSvc.Insert(user.ID, device); err != nil {
117+
return device, fmt.Errorf("failed to create device: %w", err)
119118
}
120119

121-
return device, s.devicesSvc.Insert(user.ID, &device)
120+
return device, nil
122121
}
123122

124123
func (s *Service) IsPublic() bool {
@@ -134,17 +133,18 @@ func (s *Service) AuthorizeRegistration(token string) error {
134133
return nil
135134
}
136135

137-
return fmt.Errorf("invalid token")
136+
return ErrAuthorizationFailed
138137
}
139138

140139
func (s *Service) AuthorizeDevice(token string) (models.Device, error) {
141140
device, err := s.devicesSvc.GetByToken(token)
142141
if err != nil {
143-
return device, err
142+
return device, fmt.Errorf("%w: %w", ErrAuthorizationFailed, err)
144143
}
145144

146145
go func(id string) {
147-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
146+
const timeout = 5 * time.Second
147+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
148148
defer cancel()
149149
s.onlineSvc.SetOnline(ctx, id)
150150
}(device.ID)
@@ -154,41 +154,37 @@ func (s *Service) AuthorizeDevice(token string) (models.Device, error) {
154154
return device, nil
155155
}
156156

157-
func (s *Service) AuthorizeUser(username, password string) (models.User, error) {
158-
hash := sha256.Sum256([]byte(username + password))
159-
cacheKey := hex.EncodeToString(hash[:])
160-
161-
user, err := s.usersCache.Get(cacheKey)
162-
if err == nil {
163-
return user, nil
157+
func (s *Service) AuthorizeUser(username, password string) (*models.User, error) {
158+
if user, err := s.usersCache.Get(username, password); err == nil {
159+
return &user, nil
164160
}
165161

166-
user, err = s.users.GetByLogin(username)
162+
user, err := s.users.GetByLogin(username)
167163
if err != nil {
168164
return user, err
169165
}
170166

171-
if err := crypto.CompareBCryptHash(user.PasswordHash, password); err != nil {
172-
return models.User{}, err
167+
if cmpErr := crypto.CompareBCryptHash(user.PasswordHash, password); cmpErr != nil {
168+
return nil, fmt.Errorf("password is incorrect: %w", cmpErr)
173169
}
174170

175-
if err := s.usersCache.Set(cacheKey, user); err != nil {
176-
s.logger.Error("can't cache user", zap.Error(err))
171+
if setErr := s.usersCache.Set(username, password, *user); setErr != nil {
172+
s.logger.Error("failed to cache user", zap.Error(setErr))
177173
}
178174

179175
return user, nil
180176
}
181177

182178
// AuthorizeUserByCode authorizes a user by one-time code.
183-
func (s *Service) AuthorizeUserByCode(code string) (models.User, error) {
179+
func (s *Service) AuthorizeUserByCode(code string) (*models.User, error) {
184180
userID, err := s.codesCache.GetAndDelete(code)
185181
if err != nil {
186-
return models.User{}, err
182+
return nil, fmt.Errorf("failed to get user by code: %w", err)
187183
}
188184

189185
user, err := s.users.GetByID(userID)
190186
if err != nil {
191-
return models.User{}, err
187+
return nil, err
192188
}
193189

194190
return user, nil
@@ -200,24 +196,22 @@ func (s *Service) ChangePassword(userID string, currentPassword string, newPassw
200196
return fmt.Errorf("failed to get user: %w", err)
201197
}
202198

203-
if err := crypto.CompareBCryptHash(user.PasswordHash, currentPassword); err != nil {
204-
return fmt.Errorf("current password is incorrect: %w", err)
199+
if hashErr := crypto.CompareBCryptHash(user.PasswordHash, currentPassword); hashErr != nil {
200+
return fmt.Errorf("current password is incorrect: %w", hashErr)
205201
}
206202

207203
newHash, err := crypto.MakeBCryptHash(newPassword)
208204
if err != nil {
209205
return fmt.Errorf("failed to hash new password: %w", err)
210206
}
211207

212-
if err := s.users.UpdatePassword(userID, newHash); err != nil {
213-
return fmt.Errorf("failed to update password: %w", err)
208+
if updErr := s.users.UpdatePassword(userID, newHash); updErr != nil {
209+
return fmt.Errorf("failed to update password: %w", updErr)
214210
}
215211

216212
// Invalidate cache
217-
hash := sha256.Sum256([]byte(userID + currentPassword))
218-
cacheKey := hex.EncodeToString(hash[:])
219-
if err := s.usersCache.Delete(cacheKey); err != nil {
220-
s.logger.Error("can't invalidate user cache", zap.Error(err))
213+
if delErr := s.usersCache.Delete(userID, currentPassword); delErr != nil {
214+
s.logger.Error("failed to invalidate user cache", zap.Error(delErr))
221215
}
222216

223217
return nil

internal/sms-gateway/modules/auth/types.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ const (
1111
ModePrivate Mode = "private"
1212
)
1313

14-
// AuthCode is a one-time user authorization code
15-
type AuthCode struct {
14+
// OneTimeCode is a one-time user authorization code.
15+
type OneTimeCode struct {
1616
Code string
1717
ValidUntil time.Time
1818
}

0 commit comments

Comments
 (0)