- Move token-related functions from auth.go to new session.go
- Add session tracking with expiration and cleanup
- Implement session list API endpoint
- Update login/logout handlers to use session system
- Add hourly cleanup of expired tokens and sessions
320 lines
7.2 KiB
Go
320 lines
7.2 KiB
Go
package main
|
|
|
|
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"super-frpc/postLog"
)
|
|
|
|
// TokenInfo records one issued authentication token together with the user it
// was issued for and the moment it was created.
type TokenInfo struct {
	Token     string    // encoded token string, as returned by GenerateToken/RefreshToken
	CreatedAt time.Time // issue time; compared against tokenTTL to decide expiry
	UserID    int       // ID of the user the token was issued for
}
|
|
|
|
var (
|
|
tokenMap = make(map[int]*TokenInfo)
|
|
tokenMux sync.RWMutex
|
|
tokenTTL = time.Hour
|
|
)
|
|
|
|
// Session describes one tracked login session of a user.
type Session struct {
	ID       string    // session identifier; also the key under which it is stored in sessionMap
	UserID   int       // ID of the user that owns the session
	Username string    // user name captured when the session was created
	ExpireAt time.Time // instant after which the session counts as expired
}
|
|
|
|
var (
|
|
sessionMap = make(map[string]*Session)
|
|
sessionMux sync.RWMutex
|
|
sessionTTL = time.Hour
|
|
)
|
|
|
|
func GenerateToken(userID int) (string, error) {
|
|
randomBytes := make([]byte, 32)
|
|
_, err := rand.Read(randomBytes)
|
|
if err != nil {
|
|
return "", fmt.Errorf("Failed to generate random bytes: %w", err)
|
|
}
|
|
|
|
hash := sha256.Sum256(append(randomBytes, []byte(fmt.Sprintf("%d", userID))...))
|
|
token := base64.URLEncoding.EncodeToString(hash[:])
|
|
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
|
|
tokenMap[userID] = &TokenInfo{
|
|
Token: token,
|
|
CreatedAt: time.Now(),
|
|
UserID: userID,
|
|
}
|
|
|
|
postLog.Debug(fmt.Sprintf("[GenerateToken] Generated token for userID %d: %s", userID, token))
|
|
return token, nil
|
|
}
|
|
|
|
func ValidateToken(userID int, token string) error {
|
|
tokenMux.RLock()
|
|
defer tokenMux.RUnlock()
|
|
|
|
tokenInfo, exists := tokenMap[userID]
|
|
if !exists {
|
|
return fmt.Errorf("Token not found for userID %d: %s", userID, token)
|
|
}
|
|
|
|
if tokenInfo.Token != token {
|
|
return fmt.Errorf("Invalid token for userID %d: %s", userID, token)
|
|
}
|
|
|
|
if time.Since(tokenInfo.CreatedAt) > tokenTTL {
|
|
return fmt.Errorf("Token expired for userID %d: %s", userID, token)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func RefreshToken(userID int) (string, error) {
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
|
|
randomBytes := make([]byte, 32)
|
|
_, err := rand.Read(randomBytes)
|
|
if err != nil {
|
|
return "", fmt.Errorf("Failed to generate random bytes: %w", err)
|
|
}
|
|
|
|
hash := sha256.Sum256(append(randomBytes, []byte(fmt.Sprintf("%d", userID))...))
|
|
token := base64.URLEncoding.EncodeToString(hash[:])
|
|
|
|
tokenMap[userID] = &TokenInfo{
|
|
Token: token,
|
|
CreatedAt: time.Now(),
|
|
UserID: userID,
|
|
}
|
|
|
|
postLog.Debug(fmt.Sprintf("[RefreshToken] Refreshed token for userID %d: %s", userID, token))
|
|
return token, nil
|
|
}
|
|
|
|
func RemoveToken(userID int) {
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
delete(tokenMap, userID)
|
|
postLog.Debug(fmt.Sprintf("[RemoveToken] Removed token for userID %d: %s", userID, tokenMap[userID].Token))
|
|
}
|
|
|
|
func GetTokenInfo(userID int) (*TokenInfo, error) {
|
|
tokenMux.RLock()
|
|
defer tokenMux.RUnlock()
|
|
|
|
tokenInfo, exists := tokenMap[userID]
|
|
if !exists {
|
|
return nil, fmt.Errorf("Token not found for userID %d", userID)
|
|
|
|
}
|
|
|
|
return tokenInfo, nil
|
|
}
|
|
|
|
func extractUserIDFromToken(token string) (int, error) {
|
|
tokenMux.RLock()
|
|
defer tokenMux.RUnlock()
|
|
for userID, tokenInfo := range tokenMap {
|
|
if tokenInfo.Token == token {
|
|
return userID, nil
|
|
}
|
|
}
|
|
return 0, fmt.Errorf("Invalid token: %s", token)
|
|
}
|
|
|
|
func CleanupExpiredTokens() {
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
|
|
for userID, tokenInfo := range tokenMap {
|
|
if time.Since(tokenInfo.CreatedAt) > tokenTTL {
|
|
delete(tokenMap, userID)
|
|
postLog.Debug(fmt.Sprintf("[CleanupExpiredTokens] Removed expired token for userID %d: %s", userID, tokenInfo.Token))
|
|
RemoveSession(userID, "")
|
|
}
|
|
}
|
|
}
|
|
|
|
func CleanupExpiredSessions() {
|
|
sessionMux.Lock()
|
|
defer sessionMux.Unlock()
|
|
|
|
for sessionID, session := range sessionMap {
|
|
if time.Now().After(session.ExpireAt) {
|
|
delete(sessionMap, sessionID)
|
|
postLog.Debug(fmt.Sprintf("[CleanupExpiredSessions] Removed expired session %s for user [%d]%s", sessionID, session.UserID, session.Username))
|
|
}
|
|
}
|
|
}
|
|
|
|
// hashPassword returns the lowercase hex encoding of the SHA-256 digest of
// password. The error result is always nil.
//
// NOTE(review): this is a single unsalted SHA-256 pass. Switching to a
// dedicated password hash would change the output and therefore no longer
// match hashes produced by this function.
func hashPassword(password string) (string, error) {
	hasher := sha256.New()
	hasher.Write([]byte(password)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
|
|
|
|
func verifyPassword(password, hashedPassword string) bool {
|
|
hash, err := hashPassword(password)
|
|
if err != nil {
|
|
postLog.Error(fmt.Sprintf("[verifyPassword] Failed to hash password: %v", err))
|
|
return false
|
|
}
|
|
return hash == hashedPassword
|
|
}
|
|
|
|
// isValidPassword reports whether password meets the complexity rules:
// at least 8 characters, and at least one ASCII uppercase letter, one ASCII
// lowercase letter, one ASCII digit and one character from specialChars.
//
// Length is counted in characters (runes), not bytes, so multi-byte
// characters cannot make a password that is too short pass the length rule.
func isValidPassword(password string) bool {
	const (
		minLength    = 8
		specialChars = "!@#$%^&*()_+-=[]{}|;:,.<>?"
	)

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	length := 0

	for _, char := range password {
		length++
		switch {
		case char >= 'A' && char <= 'Z':
			hasUpper = true
		case char >= 'a' && char <= 'z':
			hasLower = true
		case char >= '0' && char <= '9':
			hasDigit = true
		case strings.ContainsRune(specialChars, char):
			hasSpecial = true
		}
	}

	return length >= minLength && hasUpper && hasLower && hasDigit && hasSpecial
}
|
|
|
|
func ValidateTimeStamp(header http.Header) bool {
|
|
if isDebug {
|
|
return true
|
|
}
|
|
|
|
timeStampStr := header.Get("X-Timestamp")
|
|
if timeStampStr == "" {
|
|
return false
|
|
}
|
|
|
|
timeStamp, err := strconv.ParseInt(timeStampStr, 10, 64)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
currentTime := time.Now().UnixMilli()
|
|
if currentTime-timeStamp > 3000 || timeStamp-currentTime > 3000 {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func GetUserIDFromToken(token string) (int, error) {
|
|
userID, err := extractUserIDFromToken(token)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("Failed to extract userID from token: %w", err)
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
func DeleteTokenInfo(userID int) error {
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
delete(tokenMap, userID)
|
|
postLog.Debug(fmt.Sprintf("[DeleteTokenInfo] Removed token for userID %d", userID))
|
|
return nil
|
|
}
|
|
|
|
func GetUsernameByID(userID int) string {
|
|
user, err := GetUserByID(userID)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return user.Username
|
|
}
|
|
|
|
// generateSessionID returns a new session identifier of the form
// "session-<32 hex digits>-<unix nanoseconds>", where the hex digits encode
// 16 bytes read from crypto/rand.
//
// If the random source fails, the identifier falls back to
// "session-<unix nanoseconds>-<unix seconds>", which is built from the clock
// only.
func generateSessionID() string {
	var entropy [16]byte
	if _, err := rand.Read(entropy[:]); err == nil {
		return fmt.Sprintf("session-%x-%d", entropy[:], time.Now().UnixNano())
	}

	// Clock-only fallback for a failed random read.
	nanos := time.Now().UnixNano()
	seconds := fmt.Sprint(time.Now().Unix())
	return fmt.Sprintf("session-%d-%s", nanos, seconds)
}
|
|
|
|
func JoinSession(userID int, userName string, token string) error {
|
|
sessionMux.Lock()
|
|
defer sessionMux.Unlock()
|
|
|
|
sessionID := generateSessionID()
|
|
session := &Session{
|
|
ID: sessionID,
|
|
UserID: userID,
|
|
Username: userName,
|
|
ExpireAt: time.Now().Add(sessionTTL),
|
|
}
|
|
|
|
sessionMap[sessionID] = session
|
|
postLog.Info(fmt.Sprintf("[JoinSession] User [%d]%s joined session %s with token %s", userID, userName, sessionID, token))
|
|
return nil
|
|
}
|
|
|
|
func RemoveSession(userID int, token string) error {
|
|
sessionMux.Lock()
|
|
defer sessionMux.Unlock()
|
|
|
|
var sessionIDToRemove string
|
|
for sessionID, session := range sessionMap {
|
|
if session.UserID == userID {
|
|
sessionIDToRemove = sessionID
|
|
break
|
|
}
|
|
}
|
|
|
|
if sessionIDToRemove != "" {
|
|
delete(sessionMap, sessionIDToRemove)
|
|
postLog.Info(fmt.Sprintf("[RemoveSession] Removed session %s for user [%d]%s", sessionIDToRemove, userID, GetUsernameByID(userID)))
|
|
}
|
|
|
|
tokenMux.Lock()
|
|
defer tokenMux.Unlock()
|
|
|
|
if tokenInfo, exists := tokenMap[userID]; exists {
|
|
if tokenInfo.Token == token || token == "" {
|
|
delete(tokenMap, userID)
|
|
postLog.Info(fmt.Sprintf("[RemoveSession] Removed token for user [%d]%s", userID, GetUsernameByID(userID)))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ListActiveSessions() []*Session {
|
|
sessionMux.RLock()
|
|
defer sessionMux.RUnlock()
|
|
|
|
var activeSessions []*Session
|
|
for _, session := range sessionMap {
|
|
if session.ExpireAt.After(time.Now()) {
|
|
activeSessions = append(activeSessions, session)
|
|
}
|
|
}
|
|
return activeSessions
|
|
}
|