Files
backend/auth.go
2026-02-28 23:00:57 +08:00

202 lines
4.5 KiB
Go

package main
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"super-frpc/postLog"
)
// TokenInfo records a session token issued to a user, as stored in tokenMap.
type TokenInfo struct {
	// Token is the token string returned to the client (URL-safe base64 text
	// produced by GenerateToken / RefreshToken).
	Token string
	// CreatedAt is when the token was issued; the token is treated as expired
	// once more than tokenTTL has elapsed since this instant.
	CreatedAt time.Time
	// UserID is the ID of the user the token belongs to (also the tokenMap key).
	UserID int
}
// In-memory session token store.
var (
	// tokenMap maps a user ID to that user's current token. Issuing a new
	// token overwrites the previous entry, so each user has at most one.
	tokenMap = map[int]*TokenInfo{}

	// tokenMux guards tokenMap: readers take RLock, mutators take Lock.
	tokenMux sync.RWMutex

	// tokenTTL is how long a token stays valid after its CreatedAt time.
	tokenTTL = 60 * time.Minute
)
// GenerateToken creates a new session token for userID, stores it in tokenMap
// (replacing any token the user already had) and returns it.
//
// The token is the URL-safe base64 encoding of SHA-256 over 32 bytes from
// crypto/rand followed by the decimal userID. An error is returned only when
// the random source fails; in that case tokenMap is left unchanged.
func GenerateToken(userID int) (string, error) {
	randomBytes := make([]byte, 32)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", err)
	}
	hash := sha256.Sum256(append(randomBytes, strconv.Itoa(userID)...))
	token := base64.URLEncoding.EncodeToString(hash[:])

	tokenMux.Lock()
	defer tokenMux.Unlock()
	tokenMap[userID] = &TokenInfo{
		Token:     token,
		CreatedAt: time.Now(),
		UserID:    userID,
	}
	// Never log the token itself: it is a live bearer credential, and anyone
	// with access to the logs could replay it until it expires.
	postLog.Debug(fmt.Sprintf("[GenerateToken] Generated token for userID %d", userID))
	return token, nil
}
// ValidateToken checks that token is the token currently stored for userID and
// that it has not outlived tokenTTL. It returns nil when the token is valid,
// otherwise an error saying whether the token was missing, wrong or expired.
//
// The comparison is constant-time, so response timing does not reveal how
// many leading characters of a guessed token were correct. Error messages
// leave out the token value because errors may end up in logs.
func ValidateToken(userID int, token string) error {
	tokenMux.RLock()
	defer tokenMux.RUnlock()
	tokenInfo, exists := tokenMap[userID]
	if !exists {
		return fmt.Errorf("Token not found for userID %d", userID)
	}
	if subtle.ConstantTimeCompare([]byte(tokenInfo.Token), []byte(token)) != 1 {
		return fmt.Errorf("Invalid token for userID %d", userID)
	}
	if time.Since(tokenInfo.CreatedAt) > tokenTTL {
		return fmt.Errorf("Token expired for userID %d", userID)
	}
	return nil
}
// RefreshToken replaces userID's token with a newly generated one and returns
// it. It does the same thing whether or not the user already had a token.
//
// The token format matches GenerateToken: URL-safe base64 of SHA-256 over 32
// bytes from crypto/rand followed by the decimal userID. If the random source
// fails, an error is returned and the existing token is left untouched.
func RefreshToken(userID int) (string, error) {
	// Build the token before taking the write lock so the CSPRNG read and the
	// hashing do not stall concurrent ValidateToken / lookup calls.
	randomBytes := make([]byte, 32)
	if _, err := rand.Read(randomBytes); err != nil {
		return "", fmt.Errorf("Failed to generate random bytes: %w", err)
	}
	hash := sha256.Sum256(append(randomBytes, strconv.Itoa(userID)...))
	token := base64.URLEncoding.EncodeToString(hash[:])

	tokenMux.Lock()
	defer tokenMux.Unlock()
	tokenMap[userID] = &TokenInfo{
		Token:     token,
		CreatedAt: time.Now(),
		UserID:    userID,
	}
	// The token value is deliberately not logged: it is a live credential.
	postLog.Debug(fmt.Sprintf("[RefreshToken] Refreshed token for userID %d", userID))
	return token, nil
}
// RemoveToken deletes the token stored for userID, if there is one. Calling it
// for a user without a token is a no-op apart from a debug log line.
func RemoveToken(userID int) {
	tokenMux.Lock()
	defer tokenMux.Unlock()
	// Look the entry up before deleting it. Reading tokenMap[userID].Token
	// after delete dereferences a nil *TokenInfo and panics. The token value
	// is also kept out of the log because it is a credential.
	if _, exists := tokenMap[userID]; !exists {
		postLog.Debug(fmt.Sprintf("[RemoveToken] No token to remove for userID %d", userID))
		return
	}
	delete(tokenMap, userID)
	postLog.Debug(fmt.Sprintf("[RemoveToken] Removed token for userID %d", userID))
}
// GetTokenInfo looks up the TokenInfo stored for userID. It returns an error
// when the user has no token. Expiry is not checked, and the returned pointer
// is the entry held in tokenMap, not a copy.
func GetTokenInfo(userID int) (*TokenInfo, error) {
	tokenMux.RLock()
	info, found := tokenMap[userID]
	tokenMux.RUnlock()

	if found {
		return info, nil
	}
	return nil, fmt.Errorf("Token not found for userID %d", userID)
}
// extractUserIDFromToken returns the ID of the user whose stored token equals
// token, or an error if no stored token matches. It does not check expiry;
// callers that need that should also call ValidateToken.
// The lookup scans every entry in tokenMap.
func extractUserIDFromToken(token string) (int, error) {
	tokenMux.RLock()
	defer tokenMux.RUnlock()
	candidate := []byte(token)
	for userID, tokenInfo := range tokenMap {
		// Constant-time compare so timing does not reveal how much of a
		// guessed token matched a stored one.
		if subtle.ConstantTimeCompare([]byte(tokenInfo.Token), candidate) == 1 {
			return userID, nil
		}
	}
	// The token is left out of the message because errors may be logged.
	return 0, fmt.Errorf("Invalid token")
}
// CleanupExpiredTokens deletes every stored token whose age exceeds tokenTTL.
// Each removal is logged by user ID only; the token value is never logged
// because it is a credential.
func CleanupExpiredTokens() {
	tokenMux.Lock()
	defer tokenMux.Unlock()
	// Use a single reference time so the whole sweep applies one cutoff.
	now := time.Now()
	for userID, tokenInfo := range tokenMap {
		if now.Sub(tokenInfo.CreatedAt) > tokenTTL {
			// Deleting the current key while ranging over a map is allowed in Go.
			delete(tokenMap, userID)
			postLog.Debug(fmt.Sprintf("[CleanupExpiredTokens] Removed expired token for userID %d", userID))
		}
	}
}
// hashPassword returns the lowercase hex-encoded SHA-256 digest of password.
// The error result is currently always nil.
//
// NOTE(review): a single unsalted SHA-256 is a weak password hash. It is fast
// to brute-force, and equal passwords produce equal digests. Switching to a
// salted, deliberately slow KDF (bcrypt/scrypt/argon2/PBKDF2) needs a
// migration plan for hashes that are already stored.
func hashPassword(password string) (string, error) {
	digest := sha256.New()
	digest.Write([]byte(password)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil)), nil
}
// verifyPassword reports whether hashing password with hashPassword yields
// hashedPassword. If hashing fails, the error is logged and the result is
// false.
//
// The digests are compared in constant time so the response time does not
// reveal how many leading characters of the stored hash were matched.
func verifyPassword(password, hashedPassword string) bool {
	hash, err := hashPassword(password)
	if err != nil {
		postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err))
		return false
	}
	return subtle.ConstantTimeCompare([]byte(hash), []byte(hashedPassword)) == 1
}
// isValidPassword reports whether password satisfies the complexity policy:
// at least 8 bytes long, and containing at least one ASCII uppercase letter,
// one ASCII lowercase letter, one ASCII digit, and one character from
// specialChars below. Only ASCII letters and digits count toward those
// classes. The function only validates; it does not hash anything.
func isValidPassword(password string) bool {
	// len counts bytes, not runes.
	if len(password) < 8 {
		return false
	}

	const specialChars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
	const (
		sawUpper = 1 << iota
		sawLower
		sawDigit
		sawSpecial
		sawAll = sawUpper | sawLower | sawDigit | sawSpecial
	)

	seen := 0
	for _, r := range password {
		switch {
		case 'A' <= r && r <= 'Z':
			seen |= sawUpper
		case 'a' <= r && r <= 'z':
			seen |= sawLower
		case '0' <= r && r <= '9':
			seen |= sawDigit
		case strings.ContainsRune(specialChars, r):
			seen |= sawSpecial
		}
	}
	return seen == sawAll
}
// ValidateTimeStamp reports whether the request's X-Timestamp header holds a
// base-10 Unix time in milliseconds within ±3000 ms of the server clock.
// A missing or unparsable header is rejected. When isDebug is true, every
// request is accepted.
func ValidateTimeStamp(header http.Header) bool {
	if isDebug {
		return true
	}
	const maxSkewMillis = 3000

	// ParseInt rejects the empty string, which covers a missing header.
	timeStamp, err := strconv.ParseInt(header.Get("X-Timestamp"), 10, 64)
	if err != nil {
		return false
	}

	// Check a single signed difference against both bounds. Computing both
	// cur-ts and ts-cur is unsafe: for ts == cur+math.MinInt64, both overflow
	// to math.MinInt64 and slip under a "> max" test, accepting a timestamp
	// ~292 million years off. With cur far from the int64 limits, a wrapped
	// cur-ts lands far outside [-max, max], so this check cannot be fooled.
	skew := time.Now().UnixMilli() - timeStamp
	return skew >= -maxSkewMillis && skew <= maxSkewMillis
}