Skip to content

Commit

Permalink
fix(queries/user.go): thread safe invitation code creation
Browse files Browse the repository at this point in the history
  • Loading branch information
AstatineAi committed May 23, 2024
1 parent b0b35e4 commit bccfdda
Showing 1 changed file with 46 additions and 29 deletions.
75 changes: 46 additions & 29 deletions pkg/queries/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"fmt"
"math/rand"
"strings"
"sync"
"time"
"unicode"

Expand Down Expand Up @@ -130,7 +131,7 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
u.InvitedByUserID = inviter.ID
// TODO: only once for the inviter?
inviter.Reward += 100
db.Save(inviter)
db.Select("reward").Save(inviter)
}

// 检查邮箱是否已存在
Expand All @@ -154,15 +155,19 @@ func Register(db *gorm.DB, u *models.User, invitation_code string) error {
u.IsActive = false
u.IsAdmin = false

code, err := createInvitationCode(db)
err = db.Transaction(func(tx *gorm.DB) error {
if err := tx.Create(u).Error; err != nil {
return errors.Wrap(err, errors.DatabaseError)
}
if err := createInvitationCode(tx, u); err != nil {
return err
}
return nil
})

if err != nil {
return err
}
u.InvitationCode = code

if err = db.Create(u).Error; err != nil {
return errors.Wrap(err, errors.DatabaseError)
}

body := fmt.Sprintf(`<html><body>
<h1>欢迎注册%s</h1> <p>我们已经接收到您的电子邮箱验证申请,请点击以下链接完成注册。</p>
Expand Down Expand Up @@ -254,13 +259,7 @@ func Login(db *gorm.DB, email, password string) (*models.User, error) {
}

if user.InvitationCode == "" {
code, err := createInvitationCode(db)
if err != nil {
return nil, err
}

user.InvitationCode = code
err = db.Select("invitation_code").Save(user).Error
err := createInvitationCode(db, user)
if err != nil {
return nil, errors.Wrap(err, errors.DatabaseError)
}
Expand Down Expand Up @@ -468,25 +467,43 @@ func CheckInvitationCode(code string) bool {
return true
}

func createInvitationCode(db *gorm.DB) (string, error) {
// try a few times before giving up
for i := 0; i < 5; i++ {
codeRunes := make([]rune, 0, 5)
for i := 0; i < 5; i++ {
codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)])
}
code := string(codeRunes)
var invitationCodeMutex sync.Mutex

_, err := GetUserByInvitationCode(db, code)
if err != nil {
if errors.Is(err, errors.UserNotExists) {
return code, nil
func createInvitationCode(db *gorm.DB, user *models.User) error {
if db == nil {
db = database.GetDB()
}
if user.InvitationCode != "" {
return nil
}
// try a few times to generate an unique code
return db.Transaction(func(tx *gorm.DB) error {
invitationCodeMutex.Lock()
defer invitationCodeMutex.Unlock()

for i := 0; i < 10; i++ {
code := generateInvitationCode()
_, err := GetUserByInvitationCode(tx, code)
if err != nil {
if errors.Is(err, errors.UserNotExists) {
user.InvitationCode = code
return tx.Select("invitation_code").Save(user).Error
}
return err
}
return "", err
}
}

return "", errors.New(errors.InternalServerError)
return errors.New(errors.InternalServerError)
})
}

func generateInvitationCode() string {
// genetate random code with length 5, only contains [A-Za-z0-9]
codeRunes := make([]rune, 0, 5)
for i := 0; i < 5; i++ {
codeRunes = append(codeRunes, []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")[rand.Intn(62)])
}
return string(codeRunes)
}

func GetUserByInvitationCode(db *gorm.DB, code string) (*models.User, error) {
Expand Down

0 comments on commit bccfdda

Please sign in to comment.